1 of 28

Transformer Feed-Forward Layers Are Key-Value Memories [1]

Nils Rethmeier

January, 2021

Key insights:

  • feed-forward (FF) layers in transformer LMs act as key-value memories
  • each key correlates with textual patterns in the training data
  • each value induces a distribution over the output vocabulary
  • lower layers favor shallow/syntactic patterns, higher layers favor ‘semantic’ patterns
  • moving upwards, the layers compose memories and increasingly refine the prediction
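As a minimal sketch of this key-value reading of an FF layer (my own illustration with assumed shapes, not code from [1]): the first FF weight matrix supplies the keys, the second supplies the values.

```python
import numpy as np

def ffn_as_memory(x, K, V):
    """Read a position-wise FF layer as a key-value memory (sketch).

    x: (d,)     hidden state of one token position
    K: (dm, d)  keys, i.e. the rows of the first FF weight matrix W1
    V: (dm, d)  values, i.e. the columns of the second FF weight matrix W2 (rows of W2^T)
    """
    m = np.maximum(x @ K.T, 0.0)  # memory coefficients: ReLU(x . k_i)
    return m @ V                  # output = coefficient-weighted sum of the values

# toy shapes: d = 8, dm = 4 * d (the usual 4x inner expansion)
rng = np.random.default_rng(0)
d, dm = 8, 32
x, K, V = rng.normal(size=d), rng.normal(size=(dm, d)), rng.normal(size=(dm, d))
print(ffn_as_memory(x, K, V).shape)  # -> (8,)
```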

2 of 28

3 of 28

[Figure: self-attention setup — 2 token embeddings ("cookie", "monster") are projected by query, key, and value weight matrices (linear layers pretrained via the SSL objective)]

4 of 28

Encoder attention (figure from https://medium.com/@b.terryjack/deep-learning-the-transformer-9ae5e9c5a190):

  • q = the current position-word vector in the input sequence → all q together make up Q
  • K = all the position-word vectors in the input sequence
  • V = all the position-word vectors in the input sequence

5 of 28

Scaled dot-product attention: scores are computed for every (q, k) combination, divided by √dk (a gradient stabilizer), and softmax-normalized into self-attention scores that weight V.

Decoder attention: Q, K, and V all come from the output sequence.
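For reference, a minimal NumPy sketch of single-head scaled dot-product attention (my own illustration; no masking or dropout):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Minimal single-head attention sketch.

    Q, K, V: (seq_len, d_k) matrices of position-word vectors.
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # scores for every (q, k) combo; sqrt(d_k) stabilizes gradients
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights = weights / weights.sum(-1, keepdims=True)  # softmax -> self-attention scores
    return weights @ V                # weighted sum of value vectors

# toy: 2 token embeddings ("cookie", "monster"), d_k = 4
rng = np.random.default_rng(0)
X = rng.normal(size=(2, 4))
print(scaled_dot_product_attention(X, X, X).shape)  # -> (2, 4)
```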

6 of 28

Multi-head attention: each self-attention (SA) head i has its own projection matrices WiQ, WiK, WiV (the subscript i indexes a single SA head).

7 of 28

[Figure: multi-head attention — the per-head outputs are concatenated and passed through a final linear layer]
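A hedged sketch of that concat-then-linear step, reusing `scaled_dot_product_attention` and the `numpy` import from the sketch above (all names and shapes here are my assumptions):

```python
def multi_head_attention(X, W_q, W_k, W_v, W_o):
    """Multi-head self-attention sketch: project per head, attend, concat, linear.

    X:             (seq_len, d_model) token representations
    W_q, W_k, W_v: lists with one (d_model, d_k) projection matrix per head
    W_o:           (num_heads * d_k, d_model) final linear layer applied to the concat
    """
    heads = [
        scaled_dot_product_attention(X @ Wq_i, X @ Wk_i, X @ Wv_i)
        for Wq_i, Wk_i, Wv_i in zip(W_q, W_k, W_v)
    ]
    return np.concatenate(heads, axis=-1) @ W_o  # concat heads, then linear
```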

8 of 28


9 of 28

Key patterns, value sub-vocabularies [1]

  • the transformer position-wise FFN layer [2]
  • the persistent-memory self-attention (which replaces the FFN) [3]

Idea: [3] showed that the FFN sublayer of [2] and the persistent memory vectors of [3] can learn the same function, and [3] is a key-value memory network. So is the transformer FFN studied in [1] a key-value memory?

[Figure: the point-wise FFN of [2] applied to the self-attention output z1 = x: linear W1 (4x the size of z1), ReLU + Dropout, linear W2 — shown next to [3]'s architecture, with panels "Transformer Self-attention" vs. "Memory Self-attention"]

[2] “Attention Is All You Need”, [3] “Augmenting Self-attention with Persistent Memory”
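A minimal sketch of the position-wise FFN of [2] as described in the figure (NumPy, my own illustration; the dropout handling is simplified):

```python
import numpy as np

def position_wise_ffn(z1, W1, b1, W2, b2, drop_p=0.0, rng=None):
    """Position-wise FFN sublayer of [2] (sketch): linear -> ReLU (+ dropout) -> linear.

    z1: (seq_len, d_model)       self-attention output (x in the slide's notation)
    W1: (d_model, 4 * d_model)   inner layer is 4x the size of z1
    W2: (4 * d_model, d_model)
    """
    h = np.maximum(z1 @ W1 + b1, 0.0)             # ReLU
    if drop_p > 0.0 and rng is not None:          # dropout mask (training only; rescaling omitted)
        h = h * (rng.random(h.shape) > drop_p)
    return h @ W2 + b2
```

Read as a memory, the columns of W1 act as keys and the rows of W2 as values (next slides).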


12 of 28

Key patterns, value sub-vocabularies

Question: is the FFN in [2] the same kind of key-value memory as [3]?

  • the values v1..vdm are distributions over output vocabulary words (topics)
  • the keys k1..kdm correlate with text patterns in the input; the corresponding value (e.g. v2) puts probability on the words that tend to follow those patterns
  • an input vector (e.g. x5) is multiplied by the keys k1..kdm to produce memory coefficients, e.g. m2 = 1.5 for value v2
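A small sketch of the memory-coefficient computation from the figure (my illustration; x5 corresponds to `x`, m2 to one entry of `m`):

```python
import numpy as np

def memory_coefficients(x, K):
    """Memory coefficients for one input vector (sketch).

    x: (d,)     input vector, e.g. x5 in the slide's example
    K: (dm, d)  keys k1..kdm as rows
    Returns m with m[i] = ReLU(x . k_{i+1}); e.g. m[1] = 1.5 would be the slide's m2 for value v2.
    """
    return np.maximum(K @ x, 0.0)

def memory_output(m, V):
    """The layer output is the coefficient-weighted sum of the values v1..vdm (V: (dm, d))."""
    return m @ V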


15 of 28

Key patterns, value sub-vocabularies

Approach (key analysis):

  • collect the text patterns that each key ki responds to
  • collect the training-set sentence prefixes S with the largest memory coefficients m for that key
  • ask humans to identify the text patterns shared by the prefixes in S

16 of 28

Key patterns, value sub-vocabularies

Setup:

  • they analyze a 16-layer WikiText-103 transformer LM (https://arxiv.org/abs/1809.10853)
  • 10 keys are sampled per layer (160 keys in total)
  • per key, the 25 prefixes xi with the highest memory coefficients ReLU(xi · ki) are collected
  • a pattern must be shared by at least 3 of the prefixes xi
  • results are compared across layers, from shallow to deeper layers
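A hedged sketch of the trigger-example retrieval (variable names are hypothetical; the prefix representations xi would come from running the LM over the training set):

```python
import numpy as np

def top_trigger_prefixes(prefix_reprs, prefixes, k_i, top_n=25):
    """Retrieve the top trigger prefixes for one key (sketch).

    prefix_reprs: (num_prefixes, d) hidden representations xi of training-set prefixes
    prefixes:     list of the corresponding prefix strings
    k_i:          (d,) one key vector of the FFN layer
    Returns the top_n prefixes with the largest memory coefficients ReLU(xi . k_i).
    """
    coeffs = np.maximum(prefix_reprs @ k_i, 0.0)
    order = np.argsort(-coeffs)[:top_n]
    return [(prefixes[j], float(coeffs[j])) for j in order]
```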

17 of 28

Key pattern results

Insight 1: keys ‘cluster’ patterns

  • human annotators found at least one pattern for every key k
  • 3.6 patterns per key on average
  • most of a key's top-25 prefixes belong to one of its patterns

Insight 2: lower layers are shallow, upper layers semantic

  • lower layers capture shallow patterns
    • the prefixes often just share their last word
  • higher layers capture more semantic patterns

18 of 28

[Figure: a per-layer breakdown of the identified patterns further confirms Insight 2]

19 of 28

Values = output vocabulary distributions

Approach (value analysis):

  • keep the ki key patterns and their trigger examples from before
  • convert each value vi to word probabilities pi = softmax(vi · E), where E is the output embedding matrix
  • get the top-ranked word v* = argmax(pi) for each memory dimension and layer
  • get the word w* that follows the top-1 trigger example (x*, ki), i.e. the prefix x* with the maximal memory coefficient m
  • measure the agreement v* = w*
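A minimal sketch of the value-to-vocabulary projection and the agreement check (my illustration; `vocab` and the helper names are assumptions):

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def value_top_prediction(v_i, E, vocab):
    """Project one value vector onto the output vocabulary (sketch).

    v_i:   (d,)             a value vector of the FFN layer
    E:     (vocab_size, d)  output embedding matrix
    vocab: list of vocabulary strings, len(vocab) == E.shape[0]
    """
    p_i = softmax(E @ v_i)                 # pi = softmax(vi . E)
    v_star = vocab[int(np.argmax(p_i))]    # top-ranked word v*
    return p_i, v_star

# agreement check: w_star is the word that follows the key's top-1 trigger prefix
def agrees(v_star, w_star):
    return v_star == w_star
```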


23 of 28

Values = output vocabulary distributions (results)

The agreement v* = w* is measured per layer (and plotted against |p| on the x-axis).

Insights:

3. in the higher layers, memory cells recall how to predict the next word: the agreement rate jumps in the upper layers
4. values v* that are more likely agree more with their key patterns k


25 of 28

Inference = composed memories

  • recall: upper layers are more semantic
  • no single memory predicts the output → memory composition is required to predict outputs
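A hedged sketch of how one could test whether a layer output is a genuine composition rather than a single memory's prediction (my illustration, not the paper's exact procedure):

```python
import numpy as np

def output_is_composition(m, V, E):
    """Check whether a layer output's top word matches any single memory's top word (sketch).

    m: (dm,)            memory coefficients for one input
    V: (dm, d)          value vectors
    E: (vocab_size, d)  output embedding matrix (argmax of E @ h equals argmax of softmax(E @ h))
    """
    y = m @ V                                        # composed layer output
    top_out = int(np.argmax(E @ y))                  # top word of the composition
    active = np.nonzero(m > 0)[0]                    # memories with non-zero coefficient
    top_vals = {int(np.argmax(E @ V[i])) for i in active}
    return top_out not in top_vals                   # True -> no single memory predicts the output
```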

26 of 28

Layer-wise prediction refinement

Residual connections r sequentially compose predictions to produce the final output o_last.

Insights:

5. layers refine predictions via the residual r
6. the hard decision on the output token is made in the upper layers
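A minimal sketch of probing the prediction after every layer through the residual stream (my illustration; it assumes the per-layer hidden states can be read off and projected with E):

```python
import numpy as np

def layer_top_predictions(residuals, ffn_outputs, E):
    """Probe the prediction after every layer via the residual stream (sketch).

    residuals:   list of (d,) residual vectors r entering each layer's FFN
    ffn_outputs: list of (d,) FFN (memory) outputs of each layer
    E:           (vocab_size, d) output embedding matrix
    Returns the top word id per layer, i.e. how the prediction is refined towards o_last.
    """
    tops = []
    for r, ffn in zip(residuals, ffn_outputs):
        o = r + ffn                        # the residual composes with the FFN output
        tops.append(int(np.argmax(E @ o)))
    return tops
```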

27 of 28

Residuals and FFN compose outputs

Measured: the % of layer outputs whose top prediction top(o_layer) matches the top prediction of:

  • the FFN alone
  • the residual alone
  • both (they agree)
  • neither of them (a true composition)

Insights:

7. the residual decides most predictions
8. the FFN's top prediction alone almost never determines the layer output
9. residual and FFN compose the remaining layer / model predictions
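A hedged sketch of this four-way breakdown for a single example (my illustration; the category names are mine):

```python
import numpy as np

def composition_breakdown(o_layer, r, ffn_out, E):
    """Classify one layer's top prediction as residual / FFN / agreement / composition (sketch).

    o_layer, r, ffn_out: (d,) layer output, residual input, and FFN output
    E: (vocab_size, d)   output embedding matrix
    """
    top = lambda h: int(np.argmax(E @ h))
    t_o, t_r, t_f = top(o_layer), top(r), top(ffn_out)
    if t_r == t_f == t_o:
        return "agreement"    # residual and FFN both already predict the output token
    if t_o == t_r:
        return "residual"     # the residual's prediction carries through
    if t_o == t_f:
        return "ffn"          # the FFN's prediction wins (rare)
    return "composition"      # neither alone matches -> composed prediction
```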

28 of 28


End