1 of 38

The Devil is in the Detail

Simple Tricks Improve Systematic Generalization of Transformers

Róbert Csordás, Kazuki Irie, Jürgen Schmidhuber

EMNLP 2021

2 of 38

Systematic generalization

Probably one of the major obstacles toward general AI

  • The ability to learn a solution that performs well on systematically different inputs, e.g.:
    • Novel combinations of known constituents (systematicity)
    • Generalization to longer problems (productivity)
  • Learning generally applicable rules instead of pure pattern matching

[Figure: known examples 1+2 → 3, 3*3 → 9, (1+1)*2 → 4; novel combination (1+2)*3 → ?]

3 of 38

How do we measure it?

  • Synthetic datasets
  • Systematically different data distributions for testing

[Figure: Train set: 1+2 → 3, 3*3 → 9, (1+1)*2 → 4. Test set: (1+2)*3 → ?]
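Such a split can be sketched in a few lines of Python (a toy construction for the arithmetic example above; the function name and operand range are illustrative):

```python
import itertools

def systematic_split(max_operand=3):
    """Build a toy compositional split for arithmetic expressions.

    Train: single operations a+b and a*b.
    Test: compositions (a+b)*c, which never appear during training,
    although all of their constituents do.
    """
    ops = range(1, max_operand + 1)
    train = [(f"{a}+{b}", a + b) for a, b in itertools.product(ops, ops)]
    train += [(f"{a}*{b}", a * b) for a, b in itertools.product(ops, ops)]
    test = [(f"({a}+{b})*{c}", (a + b) * c)
            for a, b, c in itertools.product(ops, ops, ops)]
    return train, test

train, test = systematic_split()
# Every test expression combines constituents seen in training,
# but the combination itself is novel.
```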

4 of 38

How do we measure it?

  • SCAN [Lake and Baroni, 2018]
    • jump twice → JUMP JUMP
  • CFQ [Keysers et al., 2020]
    • Q: Was M0 a director and producer of M1
    • A: SELECT count(*) WHERE {M0 ns:film.director.film M1. M0 ns:film.producer.film|ns:film.production_company.films M1}
  • PCFG [Hupkes et al., 2020]
    • reverse copy O14 O4 C12 J14 W3 → W3 J14 C12 O4 O14
  • COGS [Kim and Linzen, 2020]
    • The puppy slept. → *puppy ( x _ 1 ) ; sleep . agent ( x _ 2 , x _ 1 )
  • Mathematics dataset [Saxton et al., 2019]
    • What is -5 - 110911? → -110916

5 of 38

Existing methods

  • Simply training NNs
  • Meta-learning
  • Neuro-symbolic models

6 of 38

Existing methods

  • Plain NNs are often reported to perform very poorly

7 of 38

Existing methods

  • Meta-learning improves regular neural training, but is far from perfect
  • Underexplored

8 of 38

Existing methods

  • Neuro-symbolic methods perform perfectly on some tasks
  • But they are task-specific

9 of 38

Revisiting the basic Transformers

  • Transformers have a structure seemingly well-suited for the task
  • They should be able to build a computation graph in their layers

10 of 38

Revisiting the basic Transformers

  • Transformers have a structure seemingly well-suited for the task
  • They should be able to build a computation graph in their layers

So why do they perform so badly?

11 of 38

Do the current results reflect the full potential of NNs?

  • Many factors largely influence the performance of neural networks:

- basic model configuration

- training details, hyperparameter tuning, …

  • Very often, the default settings from machine translation benchmarks are used without modification

We revisit the basic model and training configurations.

Are the current SoTA settings optimal?

Have all relevant, existing techniques been tested?

12 of 38

Underexplored Transformer augmentations

  • Parameter sharing across layers (Universal Transformers)
  • Relative positional encodings

They are relevant for systematic generalization

13 of 38

Hypothesis 1

  • Decomposing the problem into elementary, reusable components should boost generalization

[Figure: computing (a+b)*c either with a single monolithic function M, or by composing the reusable elementary functions + and *]
14 of 38

Hypothesis 1

  • In Transformers, the output of an operation is available only to the next layer
  • For composition, all layers should have all functions available
  • This enables composing elementary functions in arbitrary order

[Figure: a stack of layers, each implementing a different operation]
15 of 38

Hypothesis 1

  • In Transformers, the output of an operation is available only to the next layer
  • For composition, all layers should have all functions available
  • This enables composing elementary functions in arbitrary order

The layers should be shared

Universal Transformers
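The effect of sharing can be sketched as follows (a toy Python illustration, not the actual Universal Transformer; a "layer" is reduced to a simple scaling function):

```python
def make_layer(weight):
    """A stand-in for a Transformer layer: here just a scaling map."""
    return lambda x: [weight * v for v in x]

def standard_transformer(x, weights):
    # Each layer has its own parameters: the function computed at
    # depth i is only available at depth i.
    for w in weights:
        x = make_layer(w)(x)
    return x

def universal_transformer(x, weight, depth):
    # One set of parameters shared across all depths: every learned
    # operation is available at every layer.
    layer = make_layer(weight)
    for _ in range(depth):
        x = layer(x)
    return x
```

Here the shared-layer model can reuse the same operation at any depth, while the standard stack binds each operation to a fixed position.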

16 of 38

Hypothesis 2

  • Long compositions are often made of multiple local compositions

17 of 38

Hypothesis 2

  • Long compositions are often made of multiple local compositions

Use relative positional encodings
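The core idea can be sketched as follows (a minimal Python illustration; the bucketing scheme and names are assumptions, not the exact encoding used in the paper):

```python
def relative_position_bucket(query_pos, key_pos, max_distance=4):
    """Map a (query, key) position pair to a bucket index that depends
    only on the clipped relative offset, not on absolute positions."""
    offset = key_pos - query_pos
    offset = max(-max_distance, min(max_distance, offset))
    return offset + max_distance  # shift into [0, 2 * max_distance]

# The same local pattern gets the same bucket anywhere in the sequence,
# so what is learned locally transfers to longer inputs:
assert relative_position_bucket(2, 3) == relative_position_bucket(10, 11)
```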

18 of 38

The EOS decision problem

  • A partial reason behind the bad performance on the SCAN length split
  • Newman et al.: The EOS Decision and Length Extrapolation (2020)
  • Training without the EOS token improves performance under oracle-length evaluation

From Newman et al., 2020
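The two evaluation modes can be sketched like this (a schematic Python illustration; the token names are hypothetical):

```python
EOS = "<eos>"

def decode_with_eos(step_outputs):
    """Standard decoding: stop at the first predicted EOS token.
    A prematurely predicted EOS truncates the output."""
    result = []
    for tok in step_outputs:
        if tok == EOS:
            break
        result.append(tok)
    return result

def decode_oracle_length(step_outputs, true_length):
    """Oracle-length evaluation: ignore EOS predictions and read out
    exactly the ground-truth number of tokens."""
    return [t for t in step_outputs if t != EOS][:true_length]

# A model that learned the task but places EOS too early:
steps = ["I_JUMP", EOS, "I_JUMP", "I_JUMP"]
decode_with_eos(steps)          # ['I_JUMP']  -- truncated
decode_oracle_length(steps, 3)  # ['I_JUMP', 'I_JUMP', 'I_JUMP']
```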

19 of 38

The EOS decision problem

  • Using relative positional encodings, Universal Transformers, and one extra layer solves the problem*

*The length cutoff of 26 is interesting because of certain biases in SCAN. See Newman et al., 2020 for more details.

20 of 38

Revise and improve the basics

  • The problem of validating models on an IID dataset

- Especially problematic for early stopping

- Weak correlation between IID validation loss and OOD test accuracy

  • Scaling of embeddings

21 of 38

  1. The IID validation set does not tell us much
  • Common practice: use an IID validation set for model selection and hyperparameter tuning
  • Sometimes even for early stopping, as for the baselines on the COGS dataset
  • But IID data usually carries only a weak signal about OOD performance

IID performance of different models on different datasets. OOD performance in parentheses

22 of 38

2. Early stopping on the IID validation set is sub-optimal

  • A particularly interesting case is early stopping, illustrated on the COGS dataset
  • Disabling it in the baseline brings the performance from 35% to 65%

This is where training would stop if early stopping were used

23 of 38

3. Loss is not a good indicator of accuracy

  • Even the OOD validation loss can be positively correlated with OOD validation/test accuracy: both grow during training

CFQ MCD 1 dataset. Color: train iteration

24 of 38

3. Loss is not a good indicator of accuracy (cont’d)

  • The trend is confirmed on the COGS dataset

25 of 38

3. Loss is not a good indicator of accuracy (cont’d)

  • Why?
  • Decompose the loss:
    • “Good” examples: classified correctly at least once during training
    • “Bad” examples: never classified correctly

[Figure: loss over training iterations, decomposed into “good” and “bad” examples]
26 of 38

3. Loss is not a good indicator of accuracy (cont’d)

  • The loss on the “bad” examples outgrows the improvement on the “good” examples
  • The model essentially overfits to certain problem types and gives up on the others

27 of 38

Validation loss is not a good indicator of accuracy

  • Early stopping based on the loss is dangerous even with validation data from the same distribution as the test set, because of epoch-wise double descent

Test loss and accuracy on PCFG

28 of 38

How can we fix this issue?

  • This calls for more careful data splits for development and validation
  • For example:
    • Use multiple difficulty levels
      • Train: easy (e.g. length 1-5)
      • Validation: medium (e.g. length 6)
      • Test: hard (e.g. length 7-9)
    • Use an ensemble of datasets targeting the same problem
      • Some for training, some for testing

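The multiple-difficulty-levels idea can be sketched as follows (the length thresholds and function names are illustrative):

```python
def split_by_difficulty(examples, length_of):
    """Assign examples to train/validation/test by problem length, so the
    validation set already probes (mild) generalization beyond training."""
    splits = {"train": [], "valid": [], "test": []}
    for ex in examples:
        n = length_of(ex)
        if n <= 5:
            splits["train"].append(ex)   # easy: lengths 1-5
        elif n == 6:
            splits["valid"].append(ex)   # medium: length 6
        else:
            splits["test"].append(ex)    # hard: length 7 and above
    return splits

splits = split_by_difficulty(["ab", "abcdef", "abcdefgh"], len)
```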

29 of 38

The scaling of embeddings

  • Relevant only for the absolute positional encoding case
  • Variants
    • Token Embedding Upscaling (TEU) - Vaswani et al. (2017)
      • Use Xavier initialization for word embeddings and scale them up by √dmodel before adding positional embeddings

    • No scaling
      • Initialize word embeddings from N(0, 1) and add positional embeddings directly

    • Position Embedding Downscaling (PED)
      • Initialize word embeddings with Kaiming initialization and scale positional embeddings by 1/√dmodel
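The three variants can be sketched as follows (a schematic Python illustration; vectors are plain lists, the random initializers are omitted, and d stands for the model dimension):

```python
import math

def embed(token_vec, pos_vec, d, variant="PED"):
    """Combine word and positional embeddings under the three scaling
    schemes (schematic: initialization itself is not shown)."""
    if variant == "TEU":
        # Token Embedding Upscaling: scale word embeddings up by sqrt(d)
        # before adding positional embeddings (Vaswani et al., 2017).
        return [t * math.sqrt(d) + p for t, p in zip(token_vec, pos_vec)]
    if variant == "none":
        # No scaling: embeddings are simply added.
        return [t + p for t, p in zip(token_vec, pos_vec)]
    # PED: keep word embeddings as-is, scale positional embeddings down.
    return [t + p / math.sqrt(d) for t, p in zip(token_vec, pos_vec)]
```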

30 of 38

The scaling of embeddings

  • We found the Position Embedding Downscaling (PED) variant to outperform the others
  • Even though it causes the states in early layers to have a standard deviation below 1

31 of 38

Putting it together

32 of 38

Putting it together

  • SCAN
    • 3-layer Universal Transformer with relative positional encoding
    • 9% → 100% with a length cutoff of 26

IN: walk twice after look opposite left OUT: I_TURN_LEFT I_TURN_LEFT I_LOOK I_WALK I_WALK

IN: run twice after look OUT: I_LOOK I_RUN I_RUN

IN: jump right and jump right OUT: I_TURN_RIGHT I_JUMP I_TURN_RIGHT I_JUMP

33 of 38

Putting it together

  • CFQ
    • Universal Transformer: 64% → 77% on the length split
    • Universal Transformer with relative positional encoding: 81%

"Did Debora Caprioglio marry The Night Heaven Fell's German writer"

"SELECT count(*) WHERE { ?x0 ns:film.writer.film ns:m.02x9q6y . ?x0 ns:people.person.nationality ns:m.0345h . FILTER ( ns:m.02qj61v != ?x0 ) . ns:m.02qj61v ns:people.person.spouse_s/ns:people.marriage.spouse|ns:fictional_universe.fictional_character.married_to/ns:fictional_universe.marriage_of_fictional_characters.spouses ?x0 }",

34 of 38

Putting it together

  • PCFG
    • ~237 epochs (300k iterations) instead of 25 epochs
      • Productivity split: 50% → 65%
      • Systematicity split: 72% → 87%
    • Universal Transformer with relative positional encoding
      • Productivity split: 85%
      • Systematicity split: 96%

IN: echo shift remove_second K15 K16 T9 , F16 A13 A2 Y6 OUT: K16 T9 K15 K15

IN: copy remove_second shift R8 N18 H14 D8 D3 , R5 R17 P8 R12 B4 OUT: N18 H14 D8 D3 R8

IN: reverse shift reverse X14 L16 O6 G3 OUT: G3 X14 L16 O6

35 of 38

Putting it together

  • COGS
    • Disabling early stopping: 35% → 65%
    • Embedding scaling, no label smoothing, fixed learning rate: 80%
    • Relative positional encoding: 81%

IN: Emma cleaned the boy .

OUT: * boy ( x _ 3 ) ; clean . agent ( x _ 1 , Emma ) AND clean . theme ( x _ 1 , x _ 3 )

IN: A dog scoffed .

OUT: dog ( x _ 1 ) AND scoff . agent ( x _ 2 , x _ 1 )

36 of 38

Putting it together

  • Mathematics dataset
    • Universal Transformer with relative positional encoding
      • add_or_sub: 91% → 97%
      • place_value: 69% → 71%
    • Note: the scores are not directly comparable, because in the original paper all modules are trained together, benefiting from transfer learning
      • Our baselines trained on single modules usually perform worse than the original baseline
      • However, the relative Universal Transformer still performs better

IN: Subtract -0.01 from -0.055165. OUT: -0.045165

IN: What is the distance between -1296 and 4? OUT: 1300

IN: What is the tens digit of 88? OUT: 8

IN: What is the thousands digit of 1814? OUT: 1

37 of 38

Concluding remarks

  • Use relative positional encodings
  • Use shared layers
  • Use an OOD validation set
  • Be careful with early stopping
  • Embedding scaling is important


38 of 38

Thank you for your attention!

Please stay tuned for our future work on systematic generalization with Transformers.

More to come soon...

/robertcsordas/transformer_generalization