Graph transformers
Oct 11th, 2022
BMI/CS 775 Computational Network Biology, Fall 2022
Anthony Gitter
Topics in this section
Goals for today
Supervised graph prediction task
Limitations of graph neural networks: under-reach
Image adapted from Wu et al. 2019 A Comprehensive Survey on Graph Neural Networks
Requires k graph layers to share information between nodes k edges apart
Cannot learn long-distance relationships
[Figure: information from nodes k edges away reaches node i only after k GNN layers; layers 1 and 2 shown]
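The under-reach point can be checked numerically. A minimal sketch (illustrative, not from the lecture): with one-hot node features and mean aggregation, a 2-layer GNN on a path graph receives no information from nodes more than 2 hops away.

```python
# Sketch (not from the slides): a k-layer mean-aggregation GNN only sees
# nodes within k hops, illustrating the under-reach limitation.
import numpy as np

def mean_aggregate(adj, features):
    """One message-passing layer: average each node's own and neighbor features."""
    n = len(adj)
    out = np.zeros_like(features)
    for i in range(n):
        neigh = [i] + [j for j in range(n) if adj[i][j]]
        out[i] = features[neigh].mean(axis=0)
    return out

# Path graph 0-1-2-3-4: node 4 is 4 edges away from node 0.
adj = np.zeros((5, 5), dtype=bool)
for a, b in [(0, 1), (1, 2), (2, 3), (3, 4)]:
    adj[a][b] = adj[b][a] = True

x = np.eye(5)           # one-hot features identify each node
h = x
for _ in range(2):      # 2 layers -> information travels at most 2 hops
    h = mean_aggregate(adj, h)

print(h[0, 4])  # → 0.0  (node 0's embedding has no contribution from node 4)
```

After two layers, node 0's embedding mixes in nodes 1 and 2 but has exactly zero weight on node 4, four edges away.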
Limitations of graph neural networks: over-smoothing
[Figure: a four-node input graph; after GNN layers 1 and 2 the nodes still have distinct feature vectors, but by GNN layer 50 repeated aggregation has made every node's feature vector identical (6.1 | 3.2 | 3.0 | 7.8)]
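Over-smoothing is easy to reproduce with a toy calculation (an illustrative sketch; the adjacency matrix and features below are made up, not the slide's): repeatedly averaging over neighborhoods on a connected graph drives every node's features toward the same vector.

```python
# Sketch: 50 layers of pure neighborhood averaging collapse all node
# embeddings to (numerically) identical rows -- over-smoothing.
import numpy as np

rng = np.random.default_rng(0)
adj = np.array([[0, 1, 1, 0],
                [1, 0, 1, 1],
                [1, 1, 0, 1],
                [0, 1, 1, 0]], dtype=float)
# Include self-loops and row-normalize so each layer averages over {node} ∪ neighbors.
norm = adj + np.eye(4)
norm /= norm.sum(axis=1, keepdims=True)

h = rng.random((4, 4))          # random initial node features
for _ in range(50):             # 50 "GNN layers" of averaging
    h = norm @ h

spread = h.max(axis=0) - h.min(axis=0)
print(spread.max() < 1e-6)      # → True  (all rows are numerically identical)
```

Real GNN layers also apply learned weights and nonlinearities, but the averaging component still pushes deep stacks toward this collapse.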
Limitations of graph neural networks: over-squashing
Also known as bottlenecks
Image from Alon and Yahav ICLR 2021
Label green node from {A, B, C} based on number of neighbors
General graph
Limitations of graph neural networks: over-squashing
Images from Alon and Yahav ICLR 2021
Test case: binary tree of depth r; predict the label of the root
Information from all leaves must be aggregated in the top node's representation
Evaluate different graph neural network variants
Accuracy drops for r > 4 or 5
Limitations of graph neural networks: expressiveness
Image from PubChem
Image from GraphGPS blog post
Decalin molecule
Molecular graph
Standard graph neural network representation
Ideal representation
Improving graph neural networks
GraphGPS design
Image from GraphGPS blog post
Derived node and edge features identifying positions of graph elements
Global attention
Layers combining graph neural network updates with global attention
Attention Is All You Need
Attention concepts: queries, keys, values
[Figure: query word mammal compared against candidate words lizard, salmon, whale, wolf, and kitten, described by values such as scaly, slippery, huge, ferocious, and furry]
Illustrative attention weights: lizard 0.01, salmon 0.02, whale 0.29, wolf 0.33, kitten 0.35
Image from Fat Bear Week 2022
[Figure: query, key, and value embedding matrices for the same words]
Query-key dot products: 62.53 | 42.50 | 90.20 | 93.66 | 90.98
Scaled dot products: 36.10 | 24.54 | 52.08 | 54.07 | 50.53
Softmax attention weights: 0.00 | 0.00 | 0.10 | 0.74 | 0.16
Output: the attention-weighted sum of the value vectors, 2.3 | 8.0 | 7.5 | 2.7
Attention so far
Scaled dot-product attention
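The computation walked through above is scaled dot-product attention from Attention Is All You Need: softmax(QKᵀ/√d_k)V. A minimal NumPy sketch, with random matrices standing in for the word embeddings:

```python
# Minimal scaled dot-product attention: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V.
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)    # scaled query-key similarities
    weights = softmax(scores)          # each row is a distribution over keys
    return weights @ V, weights

rng = np.random.default_rng(1)
Q = rng.random((1, 4))    # one query (e.g. "mammal")
K = rng.random((5, 4))    # five keys (e.g. the candidate words)
V = rng.random((5, 3))    # one value vector per key

out, w = attention(Q, K, V)
print(round(float(w.sum()), 6))  # → 1.0  (attention weights sum to 1)
```

Dividing by √d_k keeps the dot products from growing with the embedding dimension, which would otherwise saturate the softmax.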
Self-attention
[Figure: each word in the sentence "student sits in the classroom" attends to every word in the sentence, including itself]
Self-attention
Attention matrix for the sentence "student sits in the classroom" (rows and columns ordered sits, in, the, classroom, student):
0.51 | 0.12 | 0.03 | 0.01 | 0.33
0.12 | 0.62 | 0.19 | 0.03 | 0.04
0.03 | 0.19 | 0.59 | 0.13 | 0.06
0.01 | 0.03 | 0.13 | 0.68 | 0.15
0.33 | 0.04 | 0.06 | 0.15 | 0.42
Self-attention
[Figure: the same attention matrix, highlighting the row for student]
The updated embedding for student is its row of the attention matrix times the current value embeddings
Self-attention still isn’t enough
Multi-headed attention
Multi-headed attention
Scaled dot-product attention
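Multi-headed attention can be sketched as follows (a simplified, hedged version of the paper's formulation: one set of weight matrices is split into per-head slices rather than using separately learned per-head projections):

```python
# Sketch: h attention heads operate on h lower-dimensional slices of the
# projected inputs; their outputs are concatenated and projected back.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, heads):
    n, d = X.shape
    dh = d // heads                              # per-head dimension
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    outs = []
    for h in range(heads):
        s = slice(h * dh, (h + 1) * dh)
        scores = Q[:, s] @ K[:, s].T / np.sqrt(dh)
        outs.append(softmax(scores) @ V[:, s])   # each head attends differently
    return np.concatenate(outs, axis=1) @ Wo     # concatenate heads, project back

rng = np.random.default_rng(2)
n, d, heads = 5, 8, 2          # 5 words, embedding dim 8, 2 heads
X = rng.random((n, d))
Wq, Wk, Wv, Wo = (rng.random((d, d)) for _ in range(4))
print(multi_head_attention(X, Wq, Wk, Wv, Wo, heads).shape)  # → (5, 8)
```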
Multi-headed self-attention
Attention matrix for one head (rows and columns ordered sits, in, the, classroom, student):
0.51 | 0.01 | 0.03 | 0.01 | 0.44
0.05 | 0.86 | 0.02 | 0.03 | 0.04
0.01 | 0.10 | 0.76 | 0.07 | 0.06
0.12 | 0.03 | 0.02 | 0.68 | 0.15
0.55 | 0.04 | 0.01 | 0.02 | 0.38
Multi-headed self-attention
Attention matrix for a second head (same word ordering):
0.08 | 0.91 | 0.00 | 0.00 | 0.01
0.02 | 0.92 | 0.02 | 0.03 | 0.01
0.01 | 0.02 | 0.87 | 0.07 | 0.03
0.05 | 0.03 | 0.02 | 0.90 | 0.00
0.05 | 0.10 | 0.01 | 0.02 | 0.82
Multi-headed self-attention
Attention mechanism operates on sets
[Figure: permuting the word order to the, in, student, sits, classroom permutes the rows of the attention matrix, but each word's attention weights, and therefore its updated embedding, are unchanged]
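This set property can be checked directly (a small NumPy sketch, not from the slides): permuting the inputs to self-attention permutes the outputs the same way, so without positional information word order is invisible to the mechanism.

```python
# Sketch: self-attention is permutation-equivariant -- shuffling input rows
# shuffles output rows identically, carrying no order information.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(X):
    # Simplified: queries, keys, and values are all X itself (no projections).
    return softmax(X @ X.T / np.sqrt(X.shape[1])) @ X

rng = np.random.default_rng(3)
X = rng.random((5, 4))                 # 5 "words", no positional encoding
perm = np.array([2, 0, 4, 1, 3])       # shuffle the word order

out = self_attention(X)
out_perm = self_attention(X[perm])
print(np.allclose(out[perm], out_perm))  # → True
```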
Positional encodings
[Figure: sinusoidal positional encoding values plotted by position and embedding index; image from Wikipedia (Cosmia Nebula)]
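The sinusoidal encodings pictured above follow the formula from Attention Is All You Need, PE(pos, 2i) = sin(pos/10000^(2i/d)) and PE(pos, 2i+1) = cos(pos/10000^(2i/d)); a short sketch:

```python
# Sinusoidal positional encodings: each position gets a unique pattern of
# sines and cosines at geometrically spaced frequencies.
import numpy as np

def positional_encoding(n_positions, d_model):
    pos = np.arange(n_positions)[:, None]        # (n, 1) positions
    i = np.arange(0, d_model, 2)[None, :]        # (1, d/2) frequency indices
    angles = pos / np.power(10000.0, i / d_model)
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles)                 # even embedding indices
    pe[:, 1::2] = np.cos(angles)                 # odd embedding indices
    return pe

pe = positional_encoding(5, 8)    # one encoding row per word position
print(pe[0, 0], pe[0, 1])         # → 0.0 1.0  (sin 0 and cos 0 at position 0)
```

These rows are added to the word embeddings so that attention, which otherwise operates on sets, can distinguish positions.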
Building the complete Transformer encoder
Initial shallow embedding
Positional encoding
Multi-headed self-attention
Residual connections help training
Layer norm helps training
Update embedding for each word
Repeat N of these blocks
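The block structure listed above can be written out as a toy implementation (a minimal single-head sketch with made-up random weights; a real encoder uses multi-headed attention, learned layer-norm parameters, and dropout):

```python
# Sketch of one Transformer encoder block: self-attention, residual + layer
# norm, position-wise feed-forward, residual + layer norm, repeated N times.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    sd = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sd + eps)

def encoder_block(X, W1, W2):
    attn = softmax(X @ X.T / np.sqrt(X.shape[1])) @ X   # (single-head) self-attention
    X = layer_norm(X + attn)                            # residual + layer norm
    ff = np.maximum(0, X @ W1) @ W2                     # position-wise feed-forward
    return layer_norm(X + ff)                           # residual + layer norm

rng = np.random.default_rng(4)
X = rng.random((5, 8))            # shallow embedding + positional encoding
W1 = rng.random((8, 16))
W2 = rng.random((16, 8))
for _ in range(3):                # repeat N blocks (here N = 3)
    X = encoder_block(X, W1, W2)
print(X.shape)  # → (5, 8)
```

The residual connections and layer norm are what make stacking many such blocks trainable in practice.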
Adapting Transformers for graphs
Positional and structural encodings
Positional encodings
Structural encodings
Where am I?
What is my neighborhood like?
Positional and structural encodings
Image from GraphGPS blog post
Positional and structural encodings
Encoding | Example
Local positional | Distance between node and cluster center
Global positional | Distance from graph center
Relative positional | Shortest-path distance between nodes
Local structural | Node degree
Global structural | Number of graph edges
Relative structural | Boolean indicating two nodes are in the same substructure
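Two of the simpler encodings in the table can be computed with plain breadth-first search (an illustrative sketch on a made-up toy graph, not the GraphGPS encoders): node degree as a local structural encoding and shortest-path distance as a relative positional encoding.

```python
# Sketch: compute node degree (local structural encoding) and BFS
# shortest-path distances (relative positional encoding) on a toy graph.
from collections import deque

edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0), (1, 4)]
n = 5
adj = {i: [] for i in range(n)}
for a, b in edges:
    adj[a].append(b)
    adj[b].append(a)

degree = {i: len(adj[i]) for i in range(n)}   # local structural encoding

def bfs_distances(src):
    """Shortest-path distances from src (relative positional encoding)."""
    dist = {src: 0}
    q = deque([src])
    while q:
        u = q.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                q.append(v)
    return dist

print(degree[1], bfs_distances(0)[3])  # → 3 2
```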
GPS layers
Image from GraphGPS blog post
Message passing neural network (MPNN)
Recall: graph convolutional neural network
Edge features in graph convolutions
Modular graph convolutions
Graph Transformer
Attention matrix size
[Figure: the 5 × 5 attention matrix for "student sits in the classroom"; an input with n elements requires an n × n attention matrix, so memory and compute grow quadratically with input size]
Scaling global attention with Performer
Image from Performer blog post
Scaling global attention with Performer
Image from Performer blog post
Original attention
Performer approximation
Scaling global attention with Performer
Image from Performer blog post
The standard Transformer runs out of memory on a V100 GPU
Constructing the GPS layer
A message passing neural network branch and a global attention branch each compute intermediate updated node features
Each branch has a residual connection followed by batch normalization
The two branches' intermediate node features are added
A multi-layer perceptron, again with a residual connection and batch normalization, produces the layer output
Same operations can be written as a sequence of functions
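That sequence of functions can be sketched as follows (a simplified stand-in, not the GraphGPS implementation: a toy mean-aggregation MPNN, toy global attention, and a layer-norm-style normalization in place of batch norm):

```python
# Sketch of a GPS layer: parallel MPNN and global-attention branches, each
# with residual + normalization, summed and passed through an MLP.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def norm(x, eps=1e-5):
    # Layer-norm-style stand-in for the batch normalization in the slides.
    return (x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + eps)

def mpnn(X, A):
    """Toy message-passing branch: mean over self and neighbors."""
    P = A + np.eye(len(A))
    return (P / P.sum(axis=1, keepdims=True)) @ X

def global_attention(X):
    """Toy global-attention branch: every node attends to every node."""
    return softmax(X @ X.T / np.sqrt(X.shape[1])) @ X

def gps_layer(X, A, W1, W2):
    Xm = norm(X + mpnn(X, A))                     # MPNN branch: residual + norm
    Xt = norm(X + global_attention(X))            # attention branch: residual + norm
    h = Xm + Xt                                   # add intermediate node features
    return norm(h + np.maximum(0, h @ W1) @ W2)   # MLP with residual + norm

rng = np.random.default_rng(5)
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 1],
              [1, 1, 0, 1],
              [0, 1, 1, 0]], dtype=float)
X = rng.random((4, 6))
W1, W2 = rng.random((6, 12)), rng.random((12, 6))
print(gps_layer(X, A, W1, W2).shape)  # → (4, 6)
```

The two branches make the layer modular: the MPNN and the attention mechanism can each be swapped for different variants, which is the design point of GraphGPS.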
Evaluating GraphGPS
One example dataset
[Figure: molecular graph with ten carbon atoms]
Evaluation results
Colors indicate first, second, and third best results
Insights from evaluations
Challenges of GraphGPS
GraphGPS recap
Image from GraphGPS blog post
Conclusions
References
Resources