Misleading Endpoints
Lessons from LLM Training Dynamics
Angelica Chen
2nd Workshop on High-dimensional Learning Dynamics (HiLD)
ICML 2024
How LLMs are commonly evaluated/analyzed
… evaluation metrics for the final or best checkpoint (e.g., AlpacaEval, Chatbot Arena)
… interpretability artifacts for the final or best checkpoint (e.g., attention visualizations)
2
What does this approach miss?
How the model develops during training, what it learns, and how this affects future model performance.
Models with similar final test metrics may take different paths to get there.
Sudden improvement in in-context learning abilities (In-context Learning and Induction Heads, Olsson et al.)
3
What does this approach miss?
Studying only the endpoints both misses key information about the model and may mislead us into making false conclusions.
In this talk: what we miss (discrete phases of training and phase transitions), and how we may be misled (how analyzing the endpoint of training may mislead us about what the model actually learns)
4
Learning often occurs discontinuously and in discrete phases.
Some examples
Grokking of a 1-layer transformer on a modular addition task.
From "Progress measures for grokking via mechanistic interpretability," Nanda et al.
Distinct memorization and compression phases during BERT-Base pre-training.
From "Sudden Drops in the Loss: Syntax Acquisition, Phase Transitions, and Simplicity Bias in MLMs" Chen et al.
Figure phase labels: circuit formation, cleanup (grokking); memorization, compression (BERT pre-training)
6
Can we learn the latent phases of training?
We extract metrics from the training trajectory,
Delays, Detours, and Forks in the Road: Latent State Models of Training Dynamics, Hu et al. (and presented at HiLD 2023!)
7
Can we learn the latent phases of training?
We extract metrics from the training trajectory, train a Hidden Markov Model (HMM) on the metrics,
Delays, Detours, and Forks in the Road: Latent State Models of Training Dynamics, Hu et al. (and presented at HiLD 2023!)
8
Can we learn the latent phases of training?
We extract metrics from the training trajectory, train a Hidden Markov Model (HMM) on the metrics, and use the HMM to label discrete phases during training.
Delays, Detours, and Forks in the Road: Latent State Models of Training Dynamics, Hu et al. (and presented at HiLD 2023!)
9
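A minimal sketch of this kind of pipeline (not the exact setup of Hu et al.): fit a Gaussian HMM to per-checkpoint training metrics and read off a discrete phase label per checkpoint. The metric matrix, number of latent states, and hmmlearn usage below are illustrative assumptions.

```python
# Illustrative sketch: label latent training phases by fitting an HMM to
# per-checkpoint metrics (e.g., train loss, gradient norm, weight norm).
# Assumes hmmlearn is installed; shapes and hyperparameters are placeholders.
import numpy as np
from hmmlearn import hmm

rng = np.random.default_rng(0)
metrics = rng.normal(size=(500, 3))  # stand-in for (num_checkpoints, num_metrics)

model = hmm.GaussianHMM(n_components=4, covariance_type="full",
                        n_iter=200, random_state=0)
model.fit(metrics)                   # learn the latent states from the trajectory
phases = model.predict(metrics)      # one discrete phase label per checkpoint

# Phase transitions = checkpoints where the inferred state changes.
boundaries = np.flatnonzero(np.diff(phases)) + 1
print(phases[:20], boundaries[:10])
```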
Model training is path dependent
Across different random seeds, some models generalize quickly…
10
Model training is path dependent
Across different random seeds, some models generalize quickly… while others take thousands more epochs! But they end with the same validation loss.
11
Model training is path dependent
Modeling the trajectory of training allows us to identify detour states, or states that occur only in trajectories where model generalization is slow.
12
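As a hedged illustration of that definition (the data structures and function below are hypothetical, not the code of Hu et al.), a detour state is an HMM state that is visited only by slow-generalizing runs:

```python
# Hypothetical sketch: flag HMM states that appear only in runs that were slow
# to generalize ("detour states" in the sense described above).
def detour_states(phases_by_run, slow_runs):
    """phases_by_run: run id -> sequence of HMM state labels; slow_runs: set of run ids."""
    slow = set().union(*(set(p) for r, p in phases_by_run.items() if r in slow_runs))
    fast = set().union(*(set(p) for r, p in phases_by_run.items() if r not in slow_runs))
    return slow - fast  # states never visited by fast-generalizing runs
```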
Phases are often bookended by steep phase transitions
internal syntax onset
During BERT-Base pre-training, internal syntax (measured by unlabeled attachment score, UAS) arises abruptly at the start of training.
"Sudden Drops in the Loss: Syntax Acquisition, Phase Transitions, and Simplicity Bias in MLMs" Chen et al.
13
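For reference, UAS itself is a simple metric: the fraction of words whose predicted syntactic head matches the gold head. A minimal sketch follows; how heads are predicted from the model (e.g., from attention maps, as in Chen et al.) is a separate step not shown here.

```python
# Unlabeled attachment score (UAS): fraction of words whose predicted head index
# matches the gold head index. Inputs are illustrative.
def uas(pred_heads, gold_heads):
    assert len(pred_heads) == len(gold_heads) and gold_heads
    correct = sum(p == g for p, g in zip(pred_heads, gold_heads))
    return correct / len(gold_heads)

print(uas([2, 0, 2, 2], [2, 0, 2, 1]))  # 3 of 4 heads correct -> 0.75
```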
Phases are often bookended by steep phase transitions
capabilities onset
internal syntax onset
"Sudden Drops in the Loss: Syntax Acquisition, Phase Transitions, and Simplicity Bias in MLMs" Chen et al.
This is immediately followed by a second phase transition: the onset of external linguistic capabilities.
14
Phases are often bookended by steep phase transitions
"Sudden Drops in the Loss: Syntax Acquisition, Phase Transitions, and Simplicity Bias in MLMs" Chen et al.
These two phase transitions decompose the initial loss drop into two phases – internal syntax acquisition and external linguistic capabilities acquisition.
15
What do phase transitions teach us?
"Sudden Drops in the Loss: Syntax Acquisition, Phase Transitions, and Simplicity Bias in MLMs" Chen et al.
These are phase transitions not just in internal representation and external capabilities, but also in model complexity! Reminiscent of the information bottleneck theory – memorization, then compression.
16
What do phase transitions teach us?
"Sudden Drops in the Loss: Syntax Acquisition, Phase Transitions, and Simplicity Bias in MLMs" Chen et al.
Is this phase transition necessary for learning?
Surprisingly, if we suppress internal syntax, linguistic capabilities decline in the long term but loss drops faster. The model learns an alternative strategy in the absence of internal syntax!
17
Analyzing only the endpoint of training (either theoretical or empirical) leads to misleading conclusions.
Understanding training dynamics can help rectify these misunderstandings.
Background: preference learning algorithms
Prompt: Give me some recommendations for my day trip to New York City. I particularly like outdoor attractions and would prefer to take public transport whenever possible.
Dispreferred response: I’m sorry, I cannot assist with this request.
Preferred response: For a one-day itinerary that highlights outdoor attractions and accessibility via public transport, I would recommend taking the 2/3 trains to Grand Army Plaza and having a picnic in Prospect Park, followed by visits to the Brooklyn Botanical Garden and Prospect Park Zoo.
19
Why do these algorithms work?
"Intuitively, the DPO update increases the relative log probability of preferred to dispreferred response"
- "Direct Preference Optimization: Your Language Model is Secretly a Reward Model," Rafailov et al.
"We recommend a simple recipe: … calibrate the model with rank loss…and KL divergence regularization."
- "Calibrating Sequence likelihood Improves Conditional Language Generation," Zhao et al.
Conventional wisdom and past literature suggest that improving ranking accuracy also improves generation.
20
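As a hedged sketch of the DPO objective behind that intuition (written over token-summed sequence log-probs; variable names and the beta value are illustrative, not the exact code of Rafailov et al.):

```python
# Per-example DPO loss: -log sigma(beta * implicit reward margin), where the margin
# is the policy's log-prob advantage on y_w over y_l, relative to the reference model.
import torch
import torch.nn.functional as F

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """All inputs are tensors of token-summed log-probs: log pi_theta(y|x) and log pi_ref(y|x)."""
    margin = (logp_w - ref_logp_w) - (logp_l - ref_logp_l)
    return -F.logsigmoid(beta * margin)

# Toy usage: margin = 0.5 - (-0.5) = 1.0, so loss = -log sigma(0.1) ~= 0.644.
print(dpo_loss(torch.tensor(-10.0), torch.tensor(-9.0),
               torch.tensor(-10.5), torch.tensor(-8.5)))
```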
Does the final model actually learn to rank?
Ranking accuracies of reference models and preference-tuned models: not much better than random chance! (X marks the random-chance accuracy.)
21
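Concretely, ranking accuracy here means the fraction of preference pairs on which the model assigns higher likelihood to the preferred response; a minimal sketch with illustrative inputs:

```python
# Ranking accuracy: fraction of pairs where the model's (token-summed) log-likelihood
# of the preferred response exceeds that of the dispreferred one.
def ranking_accuracy(logp_preferred, logp_dispreferred):
    correct = sum(w > l for w, l in zip(logp_preferred, logp_dispreferred))
    return correct / len(logp_preferred)

print(ranking_accuracy([-12.3, -8.1, -20.0], [-11.9, -9.4, -25.2]))  # -> 2/3
```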
What does the theoretical endpoint of training suggest?
Intuition: We can calculate the ranking accuracy of an optimal RLHF or DPO model. We call this the idealized ranking accuracy.
22
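A hedged sketch of the reasoning, using the standard closed form of the optimal RLHF/DPO policy (the paper's exact definition may differ in details):

```latex
\[
\pi^{*}(y \mid x) \propto \pi_{\mathrm{ref}}(y \mid x)\,\exp\!\left(\tfrac{1}{\beta}\, r(x, y)\right)
\;\Longrightarrow\;
\pi^{*}(y_w \mid x) > \pi^{*}(y_l \mid x)
\iff
r(x, y_w) - r(x, y_l) > \beta \log\frac{\pi_{\mathrm{ref}}(y_l \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} .
\]
```

The idealized ranking accuracy is then the fraction of preference pairs on which this inequality holds.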
The Idealized Ranking Accuracy
The observed ranking accuracies are significantly lower than the idealized ranking accuracies! The results from the empirical endpoint do not match the results predicted by the theoretical endpoint.
23
Understanding DPO Training
But our theorem predicted that ranking accuracy should be high for a perfectly trained model. So what’s happening?
24
Understanding DPO Training
But our theorem predicted that ranking accuracy should be high for a perfectly trained model. So what’s happening?
The portion of training before the model overfits.
25
Understanding DPO Training
But our theorem predicted that ranking accuracy should be high for a perfectly trained model. So what’s happening?
For most of the training data, loss decreases.
26
Understanding DPO Training
But our theorem predicted that ranking accuracy should be high for a perfectly trained model. So what’s happening?
But very few of the incorrectly ranked pairs are flipped to correct rankings before the point of overfitting (the share of incorrectly ranked pairs only falls from ~40% to ~37%).
27
Understanding DPO Training
But our theorem predicted that ranking accuracy should be high for a perfectly trained model. So what’s happening?
So how is DPO decreasing the loss if not by improving ranking accuracy? By increasing reward margins instead! (Recall that, in DPO, reward accuracy != ranking accuracy.)
28
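To make that distinction concrete (an illustrative sketch, not the paper's code): the implicit DPO reward compares log-prob ratios against the reference model, while ranking accuracy compares the policy's log-probs directly.

```python
# In DPO, the implicit reward is r_theta(x, y) = beta * (log pi_theta(y|x) - log pi_ref(y|x)).
# "Reward accuracy" and "ranking accuracy" therefore check different inequalities.
def reward_correct(logp_w, logp_l, ref_logp_w, ref_logp_l):
    return (logp_w - ref_logp_w) > (logp_l - ref_logp_l)  # implicit reward prefers y_w

def ranking_correct(logp_w, logp_l):
    return logp_w > logp_l                                # the policy itself prefers y_w
```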
Understanding DPO Training
Why is it so difficult for DPO to flip rankings?
29
Understanding DPO Training
Why is it so difficult for DPO to flip rankings?
In other words, if the reference model is ill-conditioned (i.e. has poor ranking accuracy), then DPO must decrease the loss to a very small value in order to flip the ranking on a particular example.
30
Understanding DPO Training
In other words, if the reference model is ill-conditioned (i.e. has poor ranking accuracy), then DPO must decrease the loss to a very small value in order to flip the ranking on a particular example.
To flip a pair that the reference model ranks incorrectly, the implicit reward margin log(𝜋θ(y_w|x)/𝜋_ref(y_w|x)) − log(𝜋θ(y_l|x)/𝜋_ref(y_l|x)) must exceed log(𝜋_ref(y_l|x)/𝜋_ref(y_w|x)) = c > 0, i.e. the per-example DPO loss must drop below −log 𝜎(𝛽c).
31
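A worked example with hypothetical numbers (the reference model prefers y_l by c = 50 nats, 𝛽 = 0.1):

```latex
\[
-\log \sigma(\beta c) = -\log \sigma(0.1 \times 50) = -\log \sigma(5) \approx 0.0067
\quad \text{vs.} \quad
-\log \sigma(0) = \log 2 \approx 0.693 \ \text{at initialization.}
\]
```

So flipping even one strongly misranked pair requires driving its per-example loss far below where training typically stops.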
Understanding DPO Training
Since training is usually stopped well before the training loss becomes very small (to avoid overfitting), the losses of most incorrectly ranked pairs never fall below the −log 𝜎(𝛽c) threshold, so their rankings are never corrected in the final model.
32
Reconciling apparent contradictions about DPO…
33
Takeaways
What are some valuable insights that LLM training dynamics have taught us?
34
Thank you!
Collaborators:
Naomi Saphra
Michael Hu
Sadhika Malladi
Matthew Leavitt
Xinyi Chen
Lily H. Zhang
Qiuyi Zhang
Ravid Shwartz-Ziv
Kyunghyun Cho
Rajesh Ranganath
35