1 of 53

PROGRESS MEASURES FOR GROKKING VIA

MECHANISTIC INTERPRETABILITY

Yorguin-José Mantilla-Ramos and Yousef Kotp

2 of 53

The overall problem…

3 of 53

Abrupt Emergent Behaviours

  • Emergent behavior in AI: new capabilities arise from scaling up (size, data, training time).
  • Some of these emergent behaviors appear abruptly at some point as these parameters are scaled.
  • How can we understand these abrupt changes better?

[6]

order parameter: quantifies the degree of organization of the system

4 of 53

Methodological Argument of the paper

Progress Measures: "metrics that precede and are causally linked to the phase transition, and which vary more smoothly"

The concept seems to have originated in this paper, and to be more empirically motivated than the usual dynamical-systems concepts from physics.

[2]

By finding “progress measures” through mechanistic interpretability we can gain insights on discontinuous emergent behavior.

5 of 53

Methodological Argument

By finding “progress measures” through mechanistic interpretability we can gain insights on discontinuous emergent behavior.

A case study: Grokking in the Modular Addition Task

How to defend this claim?

Illustrate that on an exemplar of emergent behavior (e.g. grokking), this methodology arrives at meaningful insights.

Grokking serves as a case study of the overall methodological argument of the paper.

6 of 53

The specific problem

How does grokking occur in the modular addition task learned by transformers?

7 of 53

What is Grokking?

The phenomenon where:

“models abruptly transition to a generalizing solution after a large number of training steps, despite initially overfitting” [1]

“... sudden generalization after delayed memorization” [3]

TLDR: generalization after overfitting

Follow-up: How can we attempt to understand grokking?

What motivates the model to keep changing if it already has the (presumably) best training accuracy?

→ regularization in the cost function?

(This serves as a case study of understanding “emergent phenomena” in ML.)

[1]


8 of 53

What is modular addition?

Before that: What is Modular Arithmetic? ⇒ Going around the clock

How do we know that 17:00 is 5 pm?

17 mod 12 = 5

So basically, A mod B = the remainder of A / B

9 of 53

What, then, is (A + B) mod C? → Modular Addition

It is possible for humans to find the relation (e.g. by composing simpler modular additions, as in the example below).

How do transformers do it?

Do they implement it in a similar way?

Do they grok to find this?

Core idea: we can compose modular addition…

17 mod 12 = (15 + 2) mod 12
15 mod 12 = 3,   2 mod 12 = 2
(3 + 2) mod 12 = 5 mod 12 = 5

* this is obviously not the best way to decompose this addition (e.g. 12 + 5 is better)...
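As a minimal sketch (not from the paper), the decomposition above is just the identity (A + B) mod C = ((A mod C) + (B mod C)) mod C:

```python
# Minimal check of the compositionality used above:
# (a + b) mod p == ((a mod p) + (b mod p)) mod p
p = 12
a, b = 15, 2
assert (a + b) % p == ((a % p) + (b % p)) % p == 5   # 17:00 is 5 o'clock
```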

10 of 53

What is the model used in this paper?

Usual schematic of a transformer

[4]

Too complex to study

Let’s simplify it

[5]

11 of 53

The simplified transformer circuit

[5]

Residual stream: the sum of the outputs of all the previous layers and the original embedding.

Embedding: includes the token and positional embeddings (t is the vector of tokens).

Attention heads: each head extracts relevant information for a token, based on the context from other tokens, via self-attention (may be repeated).

MLP: refines the embedding space of the model, without interaction between tokens.

Unembedding: maps the embedding space to a probability distribution over the vocabulary; in the paper it is also referred to as WL.


17 of 53

The Specific Model and Task

Input: [a, b, P] — e.g. [5, 7, 113] — read as “(a + b) mod P =”. P is held constant and effectively encodes the “=” token.

Task: compute (a + b) mod P, where P is prime (here P = 113).

Model (one-layer transformer):

  • Embedding of dim. 128 (a and b are used as indices into a token-to-embedding matrix)
  • 4 attention heads of dim. 32
  • A single MLP layer with 512 units
  • Vocabulary of size P

We take the logits of the final token (always the “=”) as the predicted answer.

Data split: 30% train / 70% test, drawn from all possible pairs of numbers up to P − 1 (sketched below).
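As a rough illustration of the setup (a hypothetical sketch, not the authors’ code): the full dataset is every pair (a, b) with 0 ≤ a, b < P, labeled with (a + b) mod P, of which a random 30% is used for training.

```python
import numpy as np

# Hypothetical sketch of the dataset described above:
# all 113 x 113 pairs, labeled with (a + b) mod P, 30/70 train/test split.
P = 113
pairs = np.array([(a, b) for a in range(P) for b in range(P)])   # 12,769 pairs
labels = (pairs[:, 0] + pairs[:, 1]) % P

rng = np.random.default_rng(0)
idx = rng.permutation(len(pairs))
n_train = int(0.3 * len(pairs))
X_train, y_train = pairs[idx[:n_train]], labels[idx[:n_train]]
X_test,  y_test  = pairs[idx[n_train:]], labels[idx[n_train:]]
```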

18 of 53

Hypothesis of how the algorithm works

Based on the insights gained through mechinterp

19 of 53

Hypothesis algorithm: Addition through Fourier Multiplication

20 of 53

Hypothesis algorithm: Addition through Fourier Multiplication

Here we don't have the answer, just its cosine…

21 of 53

Hypothesis algorithm: Addition through Fourier Multiplication

Here we don't have the answer, just its cosine…

We can get the “hour” by subtracting each possible hour c (i.e. each number in the vocabulary) and seeing where we land at 0 ⇒ there, cos(w(a + b − c)) = 1!

Note: the c in yellow here is not the answer, just an arbitrary one; the correct c is shown in green.

22 of 53

Problem: The cosine is periodic and comes close to 1 in many places…

Solution found by the transformer → Use wave interference.

If you sum cosines of different frequencies (wk), as functions of (a + b − c), they all share the correct token c as a maximum (value 1).

Constructive interference at the answer → big logit there
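A minimal numeric sketch of the constructive-interference idea (the frequencies below are illustrative; the actual set is learned and varies by run): summing cos(wk(a + b − c)) over a handful of frequencies peaks only at c = (a + b) mod P.

```python
import numpy as np

# Sketch of Fourier multiplication (illustration, not the trained model).
# Each cosine equals 1 only at c = (a + b) mod P, so the sum interferes
# constructively there and the argmax recovers the answer.
P = 113
key_freqs = [14, 35, 41, 42, 52]          # illustrative frequencies
a, b = 17, 99

c = np.arange(P)
logits = np.zeros(P)
for k in key_freqs:
    w_k = 2 * np.pi * k / P
    logits += np.cos(w_k * (a + b - c))   # = cos(wk(a+b))cos(wk c) + sin(wk(a+b))sin(wk c)

print(logits.argmax(), (a + b) % P)       # both print 3
```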

23 of 53

Ok, so how did they reverse engineer it to arrive at this hypothesis?

24 of 53

Lines of evidence for the Fourier Algorithm

  • Surprising periodicity in:
    • Embedding
    • Attention Heads
    • Neuron Activations of the MLP
    • Logits
  • Many quantities are well approximated by the appropriate trigonometric terms:
    • The unembedding matrix WL
    • Outputs of the MLP
    • Logits
  • Ablations based on the hypothesized algorithm behave as expected

25 of 53

How to provide evidence for the hypothesized algorithm

  • The authors suspect that the model learns a structured, frequency-based representation rather than a purely arbitrary mapping of tokens.
  • Given that the task is modular addition, which has an inherent cyclic structure, it makes sense that a Fourier-based representation could be useful.
  • By applying a Fourier transform to matrices like the embedding matrix WE, they can check whether the model primarily represents inputs in terms of a few key frequencies.

26 of 53

Surprising Periodicity: Embedding Layer

  • They apply a Fourier transform along the input dimension of the embedding matrix WE, then compute the L2 norm along each column.
  • They found 6 key frequencies (a rough sketch of the check follows).
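A rough sketch of this check (W_E below is a random stand-in for the trained embedding; with a trained model only a few components would dominate): project the embedding onto the 1D Fourier basis over the P input tokens and look at the norm of each component.

```python
import numpy as np

# Sketch of the Fourier analysis of the embedding (random stand-in for W_E).
P, d_model = 113, 128
W_E = np.random.randn(P, d_model)              # rows indexed by input token

# 1D Fourier basis over the token dimension: constant, cos(wk t), sin(wk t)
t = np.arange(P)
basis = [np.ones(P)]
for k in range(1, P // 2 + 1):
    basis.append(np.cos(2 * np.pi * k * t / P))
    basis.append(np.sin(2 * np.pi * k * t / P))
F = np.stack([v / np.linalg.norm(v) for v in basis])     # (P, P) orthonormal basis

W_E_fourier = F @ W_E                          # W_E expressed in the Fourier basis
component_norms = np.linalg.norm(W_E_fourier, axis=1)    # L2 norm per Fourier component
print(component_norms.shape)                   # (113,) — peaks would mark key frequencies
```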

27 of 53

Surprising Periodicity: Neuron-Logits map

  • They applied the same Fourier analysis to the neuron-logits map WL.
  • They found a periodic structure there too.
  • They found 5 key frequencies.

28 of 53

Surprising Periodicity

29 of 53

Surprising Periodicity: Attention Heads and MLP

  • They plot the attention of head 0 and the activation of MLP neuron 0 for every combination of the two inputs.
  • The periodicity corresponds to k = 35 for the attention head and k = 42 for the MLP neuron.

30 of 53

Surprising Periodicity: Logits

  • They represent the logits in the 2D Fourier basis over the inputs (a, b), then take the L2 norm over the output dimension. The key frequencies show up as five 2 × 2 blocks.

31 of 53

Trigonometric Identity Learning

  • The assumption is that the functions cos(wk(a + b)) and sin(wk(a + b)) are linearly represented in the MLP activations.
  • They showed that WL (the matrix mapping MLP activations to logits) is (approximately) rank 10 and is well approximated by the cosine and sine directions of the five key frequencies.
  • Second, note that their model computes the logits for (a, b) by applying WL to the MLP activations MLP(a, b).

uk and vk are the learned weight vectors in this approximation (reconstructed below).
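The formulas on this slide were shown as images; a hedged reconstruction from the surrounding text (with wk = 2πk/P the key frequencies, c indexing the output logits, and uk, vk the learned weight vectors) is roughly:

W_L \approx \sum_{k \in \text{key freqs}} \cos(w_k c)\, u_k^{\top} + \sin(w_k c)\, v_k^{\top}

so that, by the product-to-sum identity,

\cos(w_k(a+b))\cos(w_k c) + \sin(w_k(a+b))\sin(w_k c) = \cos\big(w_k(a+b-c)\big),

the logits end up concentrated on c = (a + b) mod P.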

32 of 53

Trigonometric Identity Learning

  • They check empirically that the terms uk · MLP(a, b) and vk · MLP(a, b) are approximate multiples of cos(wk(a + b)) and sin(wk(a + b)) (> 90% of variance explained).

33 of 53

Logits Approximation by Weighted Sum

  • They approximated the output logits as a weighted sum of cos(wk(a + b − c)) over the five key frequencies, fitting the coefficients αk via ordinary least squares.
  • This approximation explains 95% of the variance in the original logits.
  • This is surprising: the output logits are a 113-dimensional vector for each of the 113 × 113 input pairs (12,769 pairs), yet they are well approximated by just these 5 directions (a sketch of the fit follows).
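A hedged sketch of this kind of fit (random logits stand in for the model's; with the trained model this fit reportedly explains ~95% of the variance): build one cos(wk(a + b − c)) feature per key frequency and regress the flattened logits on them with ordinary least squares.

```python
import numpy as np

# Sketch of the weighted-sum approximation of the logits (stand-in data).
P = 113
key_freqs = [14, 35, 41, 42, 52]            # illustrative; the learned set varies
a = np.arange(P)[:, None, None]             # shape (P, 1, 1)
b = np.arange(P)[None, :, None]             # shape (1, P, 1)
c = np.arange(P)[None, None, :]             # shape (1, 1, P)

# Design matrix: one cos(wk (a + b - c)) feature per key frequency
features = np.stack(
    [np.cos(2 * np.pi * k * (a + b - c) / P).ravel() for k in key_freqs], axis=1
)                                           # shape (P*P*P, 5)

logits = np.random.randn(P, P, P).ravel()   # stand-in for the model's logits
alpha, *_ = np.linalg.lstsq(features, logits, rcond=None)
print(alpha)                                # fitted coefficients alpha_k
```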

34 of 53

Ablation: Effect of Ablating Frequencies for Logits

  • Working with the final logits in Fourier space over all 113 × 113 input pairs, they tried ablating the various frequencies.
  • Ablating everything except the key frequencies actually improved the results!

35 of 53

Ablation: Direction of WL

  • They found that WL is well approximated by the 10 directions corresponding to the cosine and sine of the key frequencies.
  • If the MLP activations are projected onto these 10 directions, the loss decreases by 50%.
  • If the MLP activations are instead projected onto the nullspace of these 10 directions, the loss becomes worse than the uniform (random) baseline.
  • This suggests that the network achieves low loss using these and only these 10 directions.
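A minimal sketch of this kind of projection test (random stand-ins for the MLP activations and the 10 directions): project the activations onto the span of the directions, or onto its nullspace, before mapping them to logits with WL.

```python
import numpy as np

# Sketch of the projection ablation (random stand-ins; in the paper the 10 directions
# are the cos/sin-of-key-frequency directions recovered from WL).
d_mlp, n_dirs = 512, 10
rng = np.random.default_rng(0)
directions = rng.standard_normal((n_dirs, d_mlp))   # stand-in for the 10 directions

Q, _ = np.linalg.qr(directions.T)                   # orthonormal basis of their span
P_span = Q @ Q.T                                    # projector onto the 10 directions
P_null = np.eye(d_mlp) - P_span                     # projector onto their nullspace

acts = rng.standard_normal((32, d_mlp))             # stand-in MLP activations (batch of 32)
acts_kept    = acts @ P_span                        # keep only the 10 directions
acts_removed = acts @ P_null                        # ablate them instead
```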

36 of 53

Understanding Grokking using Progress Measures

37 of 53

Grokking Results

  • Loss/accuracy results on modular addition (5 random seeds)

38 of 53

1- Restricted loss

  • Progress measure inspired by the ablation experiments
  • They ablate every non-key frequency
  • They perform a 2D DFT on the logits to write them as a linear combination of waves in a and b, and set to 0 all terms besides the constant term and the 20 terms corresponding to cos(wk(a + b)) and sin(wk(a + b)) for the five key frequencies (a simplified sketch follows).
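A simplified sketch of the restricted-logits idea (stand-in logits, and a direct projection onto the retained directions rather than the paper's full 2D DFT): keep only the constant term and the cos(wk(a + b)) / sin(wk(a + b)) components of the logits.

```python
import numpy as np

# Simplified sketch of restricted logits (stand-in data, not the model's logits).
P = 113
key_freqs = [14, 35, 41, 42, 52]                   # illustrative key frequencies
logits = np.random.randn(P, P, P)                  # stand-in for logits[a, b, c]

a = np.arange(P)[:, None]
b = np.arange(P)[None, :]
basis = [np.ones((P, P))]                          # constant term
for k in key_freqs:
    w = 2 * np.pi * k / P
    basis.append(np.cos(w * (a + b)))
    basis.append(np.sin(w * (a + b)))
B = np.stack([v.ravel() / np.linalg.norm(v) for v in basis], axis=1)   # (P*P, 11)

flat = logits.reshape(P * P, P)                    # one logit column per output c
restricted = (B @ (B.T @ flat)).reshape(P, P, P)   # keep only the retained directions
```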

39 of 53

2- Excluded loss

  • Progress measure inspired by the ablation experiments
  • They ablate all key frequencies
  • Intuition:

“The idea is that the memorizing solution should be spread out in the Fourier domain, so that ablating a few directions will leave it mostly unaffected, while the generalizing solution will be hurt significantly”

40 of 53

3- Gini Coefficient of L2 Norm of WE and WL

  • They measured the Gini coefficient of the norms of the Fourier components of WE and WL (the embedding layer and the neuron-logits map).
  • This measures the sparsity of WE and WL in the Fourier basis.
  • Intuition: more memorization → less sparsity; more generalization → more sparsity (a small sketch of the Gini computation follows).
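For reference, a small sketch of a Gini coefficient computed over a vector of non-negative norms (the paper applies this to the Fourier-component norms of WE and WL); higher values mean the norm is concentrated in fewer components, i.e. a sparser representation.

```python
import numpy as np

# Gini coefficient of a vector of non-negative values (sketch).
def gini(x: np.ndarray) -> float:
    x = np.sort(np.asarray(x, dtype=float))        # sort ascending
    n = x.size
    cum = np.cumsum(x)
    # G = (n + 1 - 2 * sum_i(cumulative) / total) / n
    return (n + 1 - 2 * cum.sum() / cum[-1]) / n

print(gini(np.ones(113)))                          # ~0.00 (norms spread out evenly)
print(gini(np.r_[np.zeros(112), 1.0]))             # ~0.99 (all mass in one component)
```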

41 of 53

4- L2 Norm of Weights

  • They argue that the L2 norm of the weights can also measure progress toward grokking.
  • Once the train loss is near zero, weight decay should push the validation loss down.

42 of 53

Phases of Grokking

  • They discovered that grokking happens in three phases:
    • M: Memorization
    • CF: Circuit Formation
    • C: Cleanup

(Plot labels: M, CF, C, followed by SS = Steady State)

43 of 53

Progress Measures of Grokking: Memorization

  • Excluded loss declines
  • Restricted loss remains high
  • Train loss declines
  • Test loss remains high
  • Gini coefficient stays relatively flat
  • L2 norm stays relatively high

44 of 53

Progress Measures of Grokking: Circuit Formation

This suggests that the model's behavior on the train set transitions smoothly from the memorizing solution to the Fourier multiplication algorithm.

  • Excluded loss rises
  • Restricted loss starts to fall
  • Train loss stays flat
  • Test loss stays flat
  • Gini coefficient starts to rise
  • L2 norm falls

45 of 53

Progress Measure of Grokking: Cleanup

  • Excluded loss plateaus
  • Restricted loss continues to drop
  • Train loss remains at its plateau
  • Test loss suddenly drops
  • Gini coefficient rises sharply
  • L2 norm drops sharply

46 of 53

L2 norm Analysis

  • Each grokking phase transition corresponds to an inflection point in the L2 norm.
  • This suggests that the L2 norm is a key component of grokking.
  • Models with smaller L2 regularization take longer to grok!
  • Their networks do not grok on the modular arithmetic task without weight decay or some other form of regularization → regularization is required.

47 of 53

Data fraction Analysis

  • The amount of training data affects grokking
  • Models trained on smaller fractions of data take longer to grok
  • When networks are provided with enough data, there is no longer a gap between the train and test losses
    • Both decline sharply some number of epochs into training

48 of 53

Outro

49 of 53

Conclusion

  • In this work, they studied the behavior of small transformers on a simple algorithmic task, solved by a single circuit.
    • Larger models, on the other hand, use larger and more numerous circuits to solve significantly harder tasks.
    • Their method does not directly generalize to such models.
  • Methods for automating the analysis and finding task-independent progress measures seem necessary to scale to larger models.
  • We lack a general notion of criticality that would allow us to predict when the phase transition will happen.
    • We only see that progress measures change relatively smoothly before the phase transition.

50 of 53

Take-home messages

  • We can use mechanistic interpretability to gain intuition about emergent behaviors in models, in particular by trying to design progress measures based on our insights from mechinterp.
  • Regularization seems to be necessary for grokking, and its strength modulates grokking speed.
  • Models trained on smaller fractions of data take longer to grok.
  • There seem to be 3 clear distinct grokking phases: memorization, circuit formation and cleanup.

51 of 53

Original Grokking Paper

Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets

Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin @ OpenAI

Vedant Misra @ Google

Will be presented by: Yuxing Tian and Zibo Shang

4/21/2025

52 of 53

References

[0] N. Nanda, L. Chan, T. Lieberum, J. Smith, and J. Steinhardt, “Progress measures for grokking via mechanistic interpretability.” arXiv, 2023. doi: 10.48550/ARXIV.2301.05217. Available: https://arxiv.org/abs/2301.05217

[1] A. Power, Y. Burda, H. Edwards, I. Babuschkin, and V. Misra, “Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets.” arXiv, 2022. doi: 10.48550/ARXIV.2201.02177. Available: https://arxiv.org/abs/2201.02177

[2] B. Barak, B. L. Edelman, S. Goel, S. Kakade, E. Malach, and C. Zhang, “Hidden Progress in Deep Learning: SGD Learns Parities Near the Computational Limit.” arXiv, 2022. doi: 10.48550/ARXIV.2207.08799. Available: https://arxiv.org/abs/2207.08799

[3] K. Clauw, S. Stramaglia, and D. Marinazzo, “Information-Theoretic Progress Measures reveal Grokking is an Emergent Phase Transition.” arXiv, 2024. doi: 10.48550/ARXIV.2408.08944. Available: https://arxiv.org/abs/2408.08944

[4] A. Vaswani et al., “Attention Is All You Need.” arXiv, 2017. doi: 10.48550/ARXIV.1706.03762. Available: https://arxiv.org/abs/1706.03762

[5] N. Elhage et al., “A Mathematical Framework for Transformer Circuits.” Transformer Circuits Thread, 2021. Available: https://transformer-circuits.pub/2021/framework/index.html

[6] E. F. W. Heffern, H. Huelskamp, S. Bahar, and R. F. Inglis, “Phase transitions in biology: from bird flocks to population dynamics,” Proceedings of the Royal Society B: Biological Sciences, vol. 288, no. 1961, The Royal Society, Oct. 20, 2021. doi: 10.1098/rspb.2021.1111. Available: http://dx.doi.org/10.1098/rspb.2021.1111

53 of 53

Q & A