1 of 63

Graph transformers

Oct 11th, 2022

BMI/CS 775 Computational Network Biology, Fall 2022

Anthony Gitter

https://compnetbiocourse.discovery.wisc.edu

2 of 63

Topics in this section

  • Representation learning on graphs
  • Graph neural networks
  • Graph transformers
  • Generative graph models

3 of 63

Goals for today

  • Limitations of graph neural networks
  • Transformers and self-attention
  • Adapting Transformers for graphs
  • General, Powerful, Scalable Graph Transformer

4 of 63

Supervised graph prediction task

  • Given:
    • Graph structure G_i = (V_i, E_i)
    • Optional features on the nodes V_i
    • Optional features on the edges E_i
  • Do:
    • Predict a label yi for the graph (or nodes or edges)
    • Can be discrete (classification) or continuous (regression)

  • Goal:
    • Use the node relationships to influence the predictions
    • Overcome limitations of graph neural networks

5 of 63

Limitations of graph neural networks: under-reach

Image adapted from Wu et al. 2019 A Comprehensive Survey on Graph Neural Networks

Requires k graph layers to share information between nodes k edges apart

Cannot learn long-distance relationships

(Figure: node i can only receive information from nodes one edge away after layer 1 and two edges away after layer 2.)

6 of 63

Limitations of graph neural networks: over-smoothing

(Figure: a 4-node input graph with distinct initial node features. After GNN layers 1 and 2 the node embeddings are still distinguishable, but by layer 50 every node has converged to the same embedding, here [6.1, 3.2, 3.0, 7.8]: the representations have been over-smoothed.)
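The over-smoothing effect is easy to reproduce. A minimal sketch (the 4-node path graph and feature values are made up for illustration) in which each simplified GNN layer just averages a node's features with its neighbors':

```python
# Toy demonstration of over-smoothing: repeated neighbor averaging
# makes all node embeddings converge to the same vector.
import numpy as np

# Adjacency matrix with self-loops for the path graph 1-2-3-4
A = np.array([[1, 1, 0, 0],
              [1, 1, 1, 0],
              [0, 1, 1, 1],
              [0, 0, 1, 1]], dtype=float)
A_hat = A / A.sum(axis=1, keepdims=True)  # row-normalized mean aggregation

# Made-up initial node features, one row per node
H = np.array([[0.1, 7.8, 3.4],
              [2.2, 9.4, 1.1],
              [4.1, 8.7, 8.3],
              [2.1, 1.5, 3.3]])

for layer in range(50):
    H = A_hat @ H  # one simplified GNN layer: neighbor averaging, no weights

# After 50 layers, the embeddings of all four nodes are nearly identical
spread = H.max(axis=0) - H.min(axis=0)
print(spread)  # close to zero in every dimension
```

Real GNN layers also apply learned weights and nonlinearities, but the averaging step that causes the convergence is the same.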

7 of 63

Limitations of graph neural networks: over-squashing

Also known as bottlenecks

Image from Alon and Yahav ICLR 2021

Label green node from {A, B, C} based on number of neighbors

General graph

8 of 63

Limitations of graph neural networks: over-squashing

Images from Alon and Yahav ICLR 2021

Test case: binary tree of depth r; predict the label of the root

Information from all leaves must be aggregated in the top node's representation

Evaluate different graph neural network variants

Accuracy drops for r > 4 or 5

9 of 63

Limitations of graph neural networks: expressiveness

Image from PubChem

Image from GraphGPS blog post

Decalin molecule

Molecular graph

Standard graph neural network representation

Ideal representation

10 of 63

Improving graph neural networks

  • Desired attributes of an improved graph neural network
    • Avoid over-smoothing and over-squashing
    • Improve expressivity
    • Modularity to adapt to different tasks
    • Scalability to 1,000s of nodes

  • General, Powerful, Scalable Graph Transformer (GraphGPS)

11 of 63

GraphGPS design

Image from GraphGPS blog post

Derived node and edge features identifying positions of graph elements

Global attention

Layers combining graph neural network updates with global attention

12 of 63

Attention Is All You Need

  • 2017 paper introduced the Transformer architecture
  • Originally for machine translation
    • “Transform” an English sentence to German
  • Unlike recurrent neural networks, does not process the input text word by word
    • Models word-word dependencies in one parallel pass
  • Success of self-attention has spread to other tasks and domains
    • Text generation
    • Biological sequences
    • Vision
    • Speech
    • Graphs

13 of 63

Attention concepts: queries, keys, values

  • Attention function maps query to an updated output
    • Uses key-value pairs to produce the output
    • Compare query to keys and compute similarities
    • Produce output as weighted sum of respective values
  • Useful for learning meaningful word embeddings based on relationships to other words
  • Building intuition for attention
    • Words
    • Word embeddings (vectors)
    • Graph node embeddings
  • See Jupyter notebook for example calculations
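In the spirit of the notebook, a minimal sketch of the query-key-value calculation (all embeddings here are made up for illustration, not the notebook's values): the query mammal is compared to each key, the similarities are softmax-normalized into weights, and the output is the weighted sum of the values.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())  # subtract max for numerical stability
    return e / e.sum()

# Made-up 3-d key embeddings and 4-d value embeddings
keys = {"lizard": [0.1, 0.2, 9.0], "salmon": [0.3, 0.1, 8.5],
        "whale":  [8.0, 1.0, 0.5], "wolf":   [7.5, 6.0, 0.3],
        "kitten": [6.0, 7.0, 0.2]}
values = {"lizard": [1.0, 0.0, 0.0, 9.0], "salmon": [0.5, 0.2, 9.5, 0.3],
          "whale":  [9.0, 0.5, 8.0, 0.1], "wolf":   [2.0, 9.0, 7.0, 0.2],
          "kitten": [3.0, 8.5, 0.5, 0.3]}

query = np.array([7.0, 5.0, 0.1])         # embedding of "mammal"
K = np.array([keys[w] for w in keys])     # stack key embeddings
V = np.array([values[w] for w in keys])   # values in the same order

scores = K @ query / np.sqrt(len(query))  # scaled dot-product similarity
weights = softmax(scores)                 # attention weights sum to 1
output = weights @ V                      # updated embedding for "mammal"
print(dict(zip(keys, weights.round(2))), output.round(2))
```

With these made-up vectors, the mammal-like keys (wolf, kitten) receive nearly all of the weight, so the output is dominated by their values.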

14 of 63

Attention concepts: queries, keys, values

  • Building intuition for attention
    • Words: update the representation of mammal
  • Have a key-value pair dictionary, one or many queries

 

Query: mammal

Key → value pairs: lizard → scaly, salmon → slippery, whale → huge, wolf → ferocious, kitten → furry

15 of 63

Attention concepts: queries, keys, values

Comparing the query mammal to each key gives attention weights that sum to 1: lizard 0.01, salmon 0.02, whale 0.29, wolf 0.33, kitten 0.35


17 of 63

Attention concepts: queries, keys, values

(Figure: each word is now a numeric embedding: the query mammal and the five keys are 3-dimensional vectors, and the five values are 4-dimensional vectors.)

18 of 63

Attention concepts: queries, keys, values

(Figure: dot products between the query mammal and each key embedding: 62.53, 42.50, 90.20, 93.66, 90.98.)

19 of 63

Attention concepts: queries, keys, values

(Figure: the dot products scaled by √d_k = √3: 36.10, 24.54, 52.08, 54.07, 50.53.)

20 of 63

Attention concepts: queries, keys, values

(Figure: softmax turns the scaled scores into attention weights that sum to 1: 0.00, 0.00, 0.10, 0.74, 0.16.)

 

21 of 63

Attention concepts: queries, keys, values

(Figure: the output for the query mammal is the attention-weighted sum of the value embeddings: [2.3, 8.0, 7.5, 2.7].)

22 of 63

Attention so far

  • Can calculate output values for a query as a weighted combination of existing values based on similarity to their keys

  • Not all that useful yet
  • Need two extensions
    • Self-attention
    • Multi-head attention

Scaled dot-product attention
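Scaled dot-product attention, Attention(Q, K, V) = softmax(QKᵀ / √d_k) V, is only a few lines of linear algebra (a generic sketch, not tied to any particular library):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Q: (n_q, d_k), K: (n_kv, d_k), V: (n_kv, d_v) -> (n_q, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)               # pairwise query-key similarities
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the keys
    return weights @ V                            # weighted sum of values

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(2, 3)), rng.normal(size=(5, 3)), rng.normal(size=(5, 4))
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # (2, 4)
```

Each output row is a convex combination of the value rows, so every output coordinate stays within the range of the corresponding value coordinates.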

23 of 63

Top Hat question

24 of 63

Self-attention

In self-attention the queries, keys, and values all come from the same sentence: each word in student sits in the classroom is compared, as a query, against every word of the sentence as a key, and its updated embedding is a weighted sum over every word's value.

25 of 63

Self-attention

Self-attention matrix for student sits in the classroom (each row is a query word; rows sum to 1):

            sits   in    the   classroom  student
sits        0.51  0.12  0.03  0.01       0.33
in          0.12  0.62  0.19  0.03       0.04
the         0.03  0.19  0.59  0.13       0.06
classroom   0.01  0.03  0.13  0.68       0.15
student     0.33  0.04  0.06  0.15       0.42

26 of 63

Self-attention

 

 

(Figure: the same 5 × 5 attention matrix, with the sentence words labeling both the rows and the columns.)

The updated embedding for student is its row of the attention matrix times the current value embeddings.

27 of 63

Self-attention still isn’t enough


28 of 63

Multi-headed attention


29 of 63

Multi-headed attention

Each head applies scaled dot-product attention with its own learned query, key, and value projections; the head outputs are concatenated and projected back to the model dimension.

Scaled dot-product attention

 

30 of 63

Multi-headed attention
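A sketch of the whole mechanism, with made-up random matrices standing in for trained projection weights: each head runs its own scaled dot-product attention, and the head outputs are concatenated and projected.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo):
    """X: (n, d_model); Wq/Wk/Wv: per-head projection lists; Wo: (h*d_head, d_model)."""
    heads = []
    for Wq_h, Wk_h, Wv_h in zip(Wq, Wk, Wv):
        Q, K, V = X @ Wq_h, X @ Wk_h, X @ Wv_h          # project into this head's space
        A = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))      # per-head attention matrix
        heads.append(A @ V)
    return np.concatenate(heads, axis=-1) @ Wo           # concatenate heads, project back

# Random stand-ins for trained weights, just to show shapes and data flow
rng = np.random.default_rng(1)
n, d_model, h, d_head = 5, 8, 2, 4
X = rng.normal(size=(n, d_model))
Wq = [rng.normal(size=(d_model, d_head)) for _ in range(h)]
Wk = [rng.normal(size=(d_model, d_head)) for _ in range(h)]
Wv = [rng.normal(size=(d_model, d_head)) for _ in range(h)]
Wo = rng.normal(size=(h * d_head, d_model))
out = multi_head_attention(X, Wq, Wk, Wv, Wo)
print(out.shape)  # (5, 8)
```

Because each head has its own projections, different heads can attend to different word-word relationships, as the next slides illustrate.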

31 of 63

Multi-headed self-attention

One head's self-attention matrix (rows sum to 1); this head links student and sits strongly to each other:

            sits   in    the   classroom  student
sits        0.51  0.01  0.03  0.01       0.44
in          0.05  0.86  0.02  0.03       0.04
the         0.01  0.10  0.76  0.07       0.06
classroom   0.12  0.03  0.02  0.68       0.15
student     0.55  0.04  0.01  0.02       0.38

32 of 63

Multi-headed self-attention

A second head's self-attention matrix; this head mostly maps each word to itself, except sits, which attends to in:

            sits   in    the   classroom  student
sits        0.08  0.91  0.00  0.00       0.01
in          0.02  0.92  0.02  0.03       0.01
the         0.01  0.02  0.87  0.07       0.03
classroom   0.05  0.03  0.02  0.90       0.00
student     0.05  0.10  0.01  0.02       0.82

33 of 63

Multi-headed self-attention

  • Now have a powerful system for learning trainable word embeddings
  • Updates embeddings based on many types of word-word relationships
  • Don’t need to know them in advance; they are discovered from the training data

  • However, still operating on sets, not sequences

34 of 63

Attention mechanism operates on sets

  • Reorder the query
  • The rows of the attention matrix are reordered but nothing changes
  • Not using sequence relationships

 

 

(Figure: the same attention computation with the query words given in a different order: the, in, student, sits, classroom. Each row of the attention matrix moves with its query word, but every value in the matrix is unchanged.)

35 of 63

Positional encodings

(Figure: sinusoidal positional encoding values as a function of position in the sequence and index within the encoding vector.)
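The sinusoidal encodings pictured on this slide can be generated directly from their defining formula (the original Transformer's choice; learned positional embeddings are an alternative):

```python
import numpy as np

def sinusoidal_positional_encoding(n_positions, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(same angle)."""
    pos = np.arange(n_positions)[:, None]
    i = np.arange(d_model // 2)[None, :]
    angles = pos / np.power(10000.0, 2 * i / d_model)  # one frequency per index pair
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles)  # even indices: sine
    pe[:, 1::2] = np.cos(angles)  # odd indices: cosine
    return pe

pe = sinusoidal_positional_encoding(5, 8)  # one row per word position
# The encoding is added to each word's shallow embedding before attention,
# so otherwise identical words at different positions get distinct inputs.
```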

36 of 63

Building the complete Transformer encoder

Initial shallow embedding

Positional encoding

Multi-headed self-attention

Residual connections help training

Layer norm helps training

Update embedding for each word

Repeat N of these blocks
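The components above compose into one encoder block; a schematic sketch in which `self_attention` and `feed_forward` are stand-ins for the trained sublayers:

```python
import numpy as np

def layer_norm(X, eps=1e-5):
    # Simplified layer norm: normalize each token's embedding (no learned scale/shift)
    mu = X.mean(axis=-1, keepdims=True)
    sigma = X.std(axis=-1, keepdims=True)
    return (X - mu) / (sigma + eps)

def encoder_block(X, self_attention, feed_forward):
    """One Transformer encoder block: each sublayer is wrapped in a
    residual connection followed by layer norm."""
    X = layer_norm(X + self_attention(X))  # residual + norm around attention
    X = layer_norm(X + feed_forward(X))    # residual + norm around the MLP
    return X

# Identity stand-ins just to show the data flow; the real sublayers are trained
X = np.random.default_rng(2).normal(size=(5, 8))
out = encoder_block(X, self_attention=lambda X: X, feed_forward=lambda X: X)
```

The full encoder is then: shallow embedding plus positional encoding, followed by N such blocks.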

37 of 63

Adapting Transformers for graphs

  • What will be the positional encodings?
  • How will we use the edges?
  • How will we scale to large(r) graphs?

38 of 63

Positional and structural encodings

  • GraphGPS uses two types of encodings

Positional encodings

Structural encodings

Where am I?

What is my neighborhood like?

39 of 63

Positional and structural encodings

  • Positional and structural encodings can be global, local, or relative

Image from GraphGPS blog post

40 of 63

Positional and structural encodings

  • GraphGPS supports many encodings
  • Focus on the simplest ones that do not require eigenvectors and eigenvalues

Encoding            | Example
--------------------|------------------------------------------------
Local positional    | Distance between node and cluster center
Global positional   | Distance from graph center
Relative positional | Shortest path node distance
Local structural    | Node degree
Global structural   | Number of graph edges
Relative structural | Boolean indicating nodes are in same substructure
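The simplest entries in the table can be computed straight from the adjacency structure; a sketch (on a made-up toy graph) of node degree (local structural) and breadth-first shortest-path distances (relative positional):

```python
from collections import deque

# Adjacency list for a small undirected graph (made up for illustration)
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}

# Local structural encoding: node degree
degree = {v: len(nbrs) for v, nbrs in adj.items()}

# Relative positional encoding: shortest-path distance from a source node
def bfs_distances(adj, source):
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist

print(degree)                 # {0: 2, 1: 2, 2: 3, 3: 1}
print(bfs_distances(adj, 0))  # {0: 0, 1: 1, 2: 1, 3: 2}
```

The eigenvector-based encodings GraphGPS also supports (e.g. Laplacian eigenvectors) require more machinery but plug into the model the same way.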

41 of 63

GPS layers

  • Update node and edge embeddings by combining graph neural network and transformer concepts

Image from GraphGPS blog post

Message passing neural network (MPNN)

42 of 63

Recall: graph convolutional neural network
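The update rule recalled here can be stated compactly; assuming the standard Kipf and Welling graph convolution covered earlier in the course, one layer computes:

```latex
H^{(l+1)} = \sigma\left( \tilde{D}^{-1/2} \, \tilde{A} \, \tilde{D}^{-1/2} \, H^{(l)} \, W^{(l)} \right),
\qquad \tilde{A} = A + I, \quad \tilde{D}_{ii} = \sum_j \tilde{A}_{ij}
```

Each node's new embedding is a degree-normalized average over itself and its neighbors, transformed by a learned weight matrix W^{(l)} and a nonlinearity σ.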

 

 

 

43 of 63

Edge features in graph convolutions

  • Extend graph neural network update to include edge features and additional weight matrix

 

 

44 of 63

Modular graph convolutions

  • GraphGPS supports multiple specific types of graph neural network updates
    • GatedGCN
    • GINE
    • PNA

45 of 63

Graph Transformer

  • No explicit knowledge of the graph edges
  • Implicitly aware of them through the node embeddings
  • Theoretical argument why this works

  • However, Transformers don’t scale well to 1,000s of nodes

46 of 63

Attention matrix size

(Figure: the 5 × 5 self-attention matrix from the earlier sentence example. The attention matrix has one row and one column per token, so its size grows quadratically with the number of tokens, or nodes in a graph.)

47 of 63

Scaling global attention with Performer

Image from Performer blog post

48 of 63

Scaling global attention with Performer

Image from Performer blog post

Original attention

Performer approximation
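The core of the Performer-style approximation is algebraic: with a feature map φ in place of the softmax, (φ(Q)φ(K)ᵀ)V can be computed as φ(Q)(φ(K)ᵀV), which never materializes the N × N attention matrix. A sketch with a simple stand-in φ (not Performer's actual random-feature map):

```python
import numpy as np

def linear_attention(Q, K, V, phi=lambda X: np.maximum(X, 0) + 1e-6):
    """Kernelized attention in O(N * d^2) time instead of O(N^2 * d)."""
    Qp, Kp = phi(Q), phi(K)
    KV = Kp.T @ V               # (d, d_v): summarize keys and values once
    Z = Qp @ Kp.sum(axis=0)     # per-query normalizer
    return (Qp @ KV) / Z[:, None]

rng = np.random.default_rng(3)
N, d = 6, 4
Q, K, V = rng.normal(size=(N, d)), rng.normal(size=(N, d)), rng.normal(size=(N, 3))
out = linear_attention(Q, K, V)
# Each output row is still a normalized, similarity-weighted sum of the
# value rows, but no N-by-N attention matrix was ever built.
print(out.shape)  # (6, 3)
```

Performer's contribution is a random-feature φ whose kernel provably approximates the softmax kernel; the associativity trick above is what makes the linear scaling possible.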

49 of 63

Scaling global attention with Performer

Image from Performer blog post

Transformer runs out of memory on V100 GPU

50 of 63

Constructing the GPS layer

 

 

51 of 63

Constructing the GPS layer

Batch normalization

Residual connection

Intermediate updated node features

52 of 63

Constructing the GPS layer

Residual connection

Intermediate updated node features

Batch normalization

53 of 63

Constructing the GPS layer

Residual connection

Batch normalization

Add intermediate node features

Multi-layer perceptron

54 of 63

Constructing the GPS layer

Same operations as a sequence of functions
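That sequence of functions might look as follows (a schematic sketch with toy stand-ins for the trained MPNN, attention, and MLP modules, not the actual GraphGPS code):

```python
import numpy as np

def batch_norm(X, eps=1e-5):
    # Simplified, non-trainable batch normalization over the node dimension
    return (X - X.mean(axis=0)) / (X.std(axis=0) + eps)

def gps_layer(X, A, mpnn, global_attention, mlp):
    """One GPS layer: local message passing and global attention run in
    parallel, each with a residual connection and batch norm, and their
    intermediate node features are added and passed through an MLP."""
    X_local = batch_norm(X + mpnn(X, A))            # MPNN branch: residual + norm
    X_global = batch_norm(X + global_attention(X))  # attention branch: residual + norm
    X_sum = X_local + X_global                      # add intermediate node features
    return X_sum + mlp(X_sum)                       # MLP with a final residual

# Toy stand-in modules just to show the data flow on a 4-node path graph
rng = np.random.default_rng(4)
A = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], float)
X = rng.normal(size=(4, 8))
out = gps_layer(X, A,
                mpnn=lambda X, A: A @ X,       # neighbor sum as a toy MPNN
                global_attention=lambda X: X,  # identity stand-in for attention
                mlp=np.tanh)
```

Stacking these layers gives the full model, with the MPNN choice (GatedGCN, GINE, PNA) and the attention choice (Transformer, Performer) as interchangeable modules.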

55 of 63

Evaluating GraphGPS

  • Graph datasets from two benchmarks
    • Benchmarking GNNs
    • Open Graph Benchmark
  • Compare to many types of prior models
    • Graph neural networks
    • Graph transformers
  • Ablation study to assess the impact of model components
    • Global attention module
    • Message passing module
    • Positional and structural encodings

56 of 63

One example dataset

  • ZINC dataset of molecular graphs
    • Nodes are heavy atoms, edges are bonds
    • 12,000 undirected graphs with 9-37 nodes
    • Regression task: estimated constrained solubility
    • Evaluate with Mean Absolute Error (MAE)

(Figure: an example molecular graph of ten carbon atoms.)

57 of 63

Evaluation results

Colors indicate first, second, and third best results

58 of 63

Insights from evaluations

  • GraphGPS gives top or near top performance on many benchmarking tasks
  • Transformer almost always helps, except for ZINC
    • Score only depends on local structure
  • Message passing module is essential, PNA works well
  • Best encodings are dataset dependent
  • On 5,000-node graphs, the Transformer is slightly more accurate than the Performer but twice as slow

59 of 63

Challenges of GraphGPS

  • Complex models are more compute intensive
  • Longer to train but also more hyperparameters to tune

60 of 63

GraphGPS recap

Image from GraphGPS blog post

61 of 63

Conclusions

  • GraphGPS combines many modern deep learning ideas for learning on graphs
  • Transformer and suitable encodings can overcome expressivity limitations of graph neural networks
    • Self-attention has become a very powerful tool across many domains
  • GraphGPS has excellent performance across graph benchmarks
  • The power and modularity bring some practical challenges in training and hyperparameter tuning
    • These are surmountable

62 of 63

References

63 of 63

Resources