Transformer Feed-Forward Layers Are Key-Value Memories [1]
Nils Rethmeier
January, 2021
Key insights: summarized as Insights 1-9 below.
Recall Self-attention: http://jalammar.github.io/illustrated-transformer/
[Figure: 2 token embeddings ("cookie", "monster") are projected through the linear weight matrices for query, key and value; these linear layers are pretrained via the SSL objective. W_i^Q/K/V: the index i denotes a single self-attention head.]
[Figure: query-key dot products give scores for the (q, k) combinations; scaling by 1/sqrt(d_k) acts as a gradient stabilizer; a softmax over the scaled scores yields the self-attention scores that weight the values; the per-head outputs are concatenated and passed through a final linear layer.]
Encoder Attention: https://medium.com/@b.terryjack/deep-learning-the-transformer-9ae5e9c5a190
Decoder Attention: Q, K and V come from the output sequence.
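A minimal numpy sketch of the recap above, assuming illustrative names and shapes (X, W_q, W_k, W_v, d_head) rather than the slides' exact notation: one head of scaled dot-product self-attention over two token embeddings; in a full Transformer the per-head outputs would then be concatenated and passed through a final linear layer.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention.
    X: (seq_len, d_model) token embeddings, e.g. "cookie" and "monster"."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v     # linear projections: weight query / key / value
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)         # scores for (q, k) combos; sqrt(d_k) stabilizes gradients
    attn = softmax(scores, axis=-1)         # self-attention scores
    return attn @ V                         # attention-weighted sum of values

rng = np.random.default_rng(0)
d_model, d_head = 8, 4
X = rng.normal(size=(2, d_model))                                   # 2 token embeddings
W_q, W_k, W_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))
Z = self_attention(X, W_q, W_k, W_v)        # (2, d_head); one head of a multi-head block
print(Z.shape)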
Key patterns, value sub-vocabs [1]
Idea: [3] showed that the point-wise FFN layer of [2] and the persistent-memory self-attention FFN of [3] can learn the same function, and [3] is a key-value memory network. So is the Transformer FFN studied in [1] a key-value memory?
[Figure: the self-attention output z1 = x enters a linear layer W1 (4x the size of z1), then ReLU + Dropout, then a linear layer W2; a second panel contrasts "Transformer Self-attention" with "Memory Self-attention".]
[2] “Attention Is All You Need”, [3] “Augmenting Self-attention with Persistent Memory”
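A minimal numpy sketch of the correspondence this idea rests on, under assumed names and shapes (d, d_m, W1, W2): the point-wise FFN of [2] and an unnormalized key-value memory compute the same thing once the columns of W1 are read as keys and the rows of W2 as values (biases and dropout omitted for clarity).

import numpy as np

rng = np.random.default_rng(0)
d, d_m = 8, 32                              # d_m is ~4x the model dim, as in [2]
x  = rng.normal(size=(d,))                  # self-attention output z1 = x
W1 = rng.normal(size=(d, d_m))              # first linear layer
W2 = rng.normal(size=(d_m, d))              # second linear layer

relu = lambda a: np.maximum(a, 0.0)

# Standard point-wise FFN of [2]:
ffn_out = relu(x @ W1) @ W2

# Key-value memory reading of the same weights:
K = W1.T                                    # (d_m, d): one key per memory cell
V = W2                                      # (d_m, d): one value per memory cell
m = relu(x @ K.T)                           # memory coefficients, one per cell
mem_out = m @ V                             # coefficient-weighted sum of values

print(np.allclose(ffn_out, mem_out))        # True: same computation, two readings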
Key patterns, value sub-vocabulary
Is the FFN of [2] the same kind of key-value memory as [3]?
[Figure: the prefix representation x is multiplied by the keys k1..k_dm to produce memory coefficients, e.g. m2 = 1.5 for value v2; the coefficients weight the corresponding values.]
[2] “Attention Is All You Need”, [3] “Augmenting Self-attention with Persistent Memory”
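A toy worked example of the figure above, with hand-picked 3-dimensional vectors (assumptions for illustration, not numbers from [1]): the prefix representation x only "opens" memory cell 2, with coefficient m2 = 1.5, so the FFN output is dominated by value v2.

import numpy as np

x  = np.array([1.0, 0.5, -0.5])             # prefix representation entering the FFN
k1 = np.array([-1.0, 0.0, 1.0])             # key 1
k2 = np.array([1.0, 1.0, 0.0])              # key 2
v1 = np.array([0.2, 0.2, 0.2])              # value 1
v2 = np.array([0.0, 1.0, 0.0])              # value 2

relu = lambda a: np.maximum(a, 0.0)
m1 = relu(x @ k1)                           # = 0.0 -> memory cell 1 stays silent
m2 = relu(x @ k2)                           # = 1.5 -> memory cell 2 fires, as in the figure
out = m1 * v1 + m2 * v2                     # FFN output = coefficient-weighted sum of values
print(m1, m2, out)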
Key patterns results
Approach: they analyze a 16-layer WikiText-103 transformer LM (https://arxiv.org/abs/1809.10853) and sample 10 keys per layer (160 keys in total); per key, they collect the 25 training prefixes xi with the highest memory coefficients ReLU(xi * k); human experts then annotate patterns, where a pattern must be shared by at least 3 prefixes xi.
Insight 1: keys 'cluster' patterns.
Insight 2: deeper layers capture more semantic, less surface-level patterns.
[2] “Attention Is All You Need”, [3] “Augmenting Self-attention with Persistent Memory”
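A sketch of the prefix-retrieval step, with assumed names (prefix_reprs, prefixes, top_n); [1] uses the training-set prefixes of the WikiText-103 LM and top_n = 25.

import numpy as np

def top_trigger_prefixes(prefix_reprs, key, prefixes, top_n=25):
    """Rank training prefixes by their memory coefficient ReLU(x_i . k) for one key.

    prefix_reprs: (N, d) prefix representations x_i at the FFN input
    key:          (d,)   one key vector (a column of W1)
    prefixes:     list of N prefix strings aligned with prefix_reprs
    """
    coeffs = np.maximum(prefix_reprs @ key, 0.0)   # memory coefficients ReLU(x_i . k)
    order = np.argsort(-coeffs)[:top_n]            # indices of the top-N coefficients
    return [(prefixes[i], float(coeffs[i])) for i in order]

# Usage sketch: experts then look for a pattern shared by >= 3 of the returned prefixes.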
Values = output vocab distributions
Approach: cast each value vi as a distribution over the output vocabulary (via the output embedding) and compare its top token with the next word that follows the key's top trigger prefixes; this further confirms the key-pattern findings.
[Figure: agreement rate vs. |p| (x-axis).]
Insight 3: memory cells recall how to predict the next word in higher layers.
Insight 4: likely values v* agree more with their key patterns k.
[2] “Attention Is All You Need”, [3] “Augmenting Self-attention with Persistent Memory”
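A sketch of the value analysis under assumed names (E for the output embedding matrix, next_token_ids for the tokens that follow each key's top trigger prefix): each value row is scored against the vocabulary and its top token is compared with that next token.

import numpy as np

def value_agreement(values, E, next_token_ids):
    """Agreement rate between each value's top vocab token and the next token
    of the corresponding key's top trigger prefix.

    values:         (d_m, d) rows of W2, one value per memory cell
    E:              (vocab, d) output embedding matrix
    next_token_ids: (d_m,) token ids that follow each key's top trigger prefix
    """
    logits = values @ E.T                             # (d_m, vocab): each value scored against the vocab
    top_tokens = logits.argmax(axis=1)                # top token of each value's distribution
    return float((top_tokens == next_token_ids).mean())

# In [1] this agreement rate rises sharply in the upper layers (Insight 3).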
Inference = composed memories
Recall: upper layers are more semantic.
No single memory predicts the output on its own → memory composition is required to predict outputs.
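A sketch of that composition check, with assumed names (m, V, E): compare the top token induced by the composed FFN output against the top tokens of the individual values that actually fired.

import numpy as np

def composition_check(m, V, E):
    """Does any single memory cell already predict the FFN output's top token?

    m: (d_m,)     memory coefficients ReLU(x . k_i) for one prefix
    V: (d_m, d)   values
    E: (vocab, d) output embedding matrix
    """
    composed_logits = (m @ V) @ E.T                   # vocab scores of the composed FFN output
    cell_tops = (V @ E.T).argmax(axis=1)              # top token of each individual value
    active = m > 0                                    # only cells that fired
    return composed_logits.argmax() in set(cell_tops[active])

# In [1] this is rarely the case: most FFN outputs are compositions, not single-memory lookups.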
Layer-wise prediction refinement
Residual connections r sequentially compose predictions to produce the final output o_last.
Insight 5: layers refine predictions via r.
Insight 6: hard decisions are made in the upper layers.
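A sketch of how such refinement can be measured, under assumed inputs (residuals as the per-layer residual-stream vectors, E as the output embedding): track the probability assigned to the final top token across layers.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def refinement_curve(residuals, E):
    """Probability of the model's final top token, measured at each layer's residual r.

    residuals: list of (d,) residual-stream vectors r after each layer
    E:         (vocab, d) output embedding matrix
    """
    final_token = (residuals[-1] @ E.T).argmax()      # token decided at the last layer
    return [float(softmax(r @ E.T)[final_token]) for r in residuals]

# If layers refine predictions via r (Insight 5), this curve rises with depth,
# with the sharpest "hard decisions" in the upper layers (Insight 6).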
Residuals and FFN compose outputs
% of layer-output top predictions top(o_layer) that match the top prediction of the residual r, of the FFN output, or of neither:
Insight 7: the residual decides most predictions.
Insight 8: the FFN on its own has almost no influence.
Insight 9: residual and FFN compose the rest of the layer/model predictions.
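A sketch of the bookkeeping behind these percentages, with assumed names (r, ffn_out, E); the analysis in [1] aggregates the three cases over many examples.

import numpy as np

def classify_layer_output(r, ffn_out, E):
    """Classify one layer's top prediction: carried by the residual, by the FFN, or composed.

    r:       (d,) residual input to the layer
    ffn_out: (d,) FFN output of the layer
    E:       (vocab, d) output embedding matrix
    """
    top = lambda h: int((h @ E.T).argmax())
    o = r + ffn_out                                   # layer output fed to the next block
    if top(o) == top(r):
        return "residual"                             # Insight 7: the most common case
    if top(o) == top(ffn_out):
        return "ffn"                                  # Insight 8: rare
    return "composition"                              # Insight 9: residual and FFN compose the rest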
End