Pretraining Large Language Models
Leandro von Werra
Plan for today
State of LLMs
scaling large and smol
LLM families
closed model APIs
open model weights
fully open model
model weights not available
no access to training data or code
full access to model/code/data
Trends: train longer
Trends: train larger
Trends: more context
Trends: more compute
compute ≈ data x model size
| | Dataset (Billion Tokens) | Model size (Billion Parameters) |
| GPT-1 | 1-2 | 0.11 |
| GPT-2 | 10-20 | 1.4 |
| GPT-3 | 300 | 175 |
| GPT-4 | 10’000 | 1’800 |
Compute increase (data × model size): ~100x (GPT-1→2), ~2000x (GPT-2→3), ~300x (GPT-3→4)
GPT-4 training cost: ~$100M
Compute:
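As a rough sanity check of compute ≈ data × model size, a common rule of thumb is C ≈ 6·N·D training FLOPs for N parameters and D tokens. The sketch below applies it to the GPT-3 row of the table; the result is an estimate, not an official figure.

```python
# Back-of-the-envelope training compute via the common approximation C ≈ 6 * N * D.
N = 175e9          # GPT-3 parameters
D = 300e9          # GPT-3 training tokens
C = 6 * N * D      # ≈ 3.15e23 FLOPs
print(f"GPT-3 training compute ≈ {C:.2e} FLOPs")
```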
Trends: more compute
Trends: smol models
Trends: why? Scaling Laws!
Can we extrapolate to the performance of …
Scaling Laws
predictable scaling returns
Scaling laws: Predictable returns
[Figure (Kaplan et al.): test loss falls as a power law in compute, dataset size, and model size]
https://arxiv.org/abs/2001.08361
Scaling laws: Compute optimal
[Figure: loss vs. compute for models of different sizes under a fixed compute budget]
Too small: loss already flattened out
Optimal: lowest loss at current compute budget
Too large: not yet through steep loss zone
Scaling laws: Downstream performance
https://arxiv.org/abs/2303.08774
Scaling laws: Chinchilla fix
https://arxiv.org/abs/2203.15556
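A minimal sketch of how the Chinchilla result is often used in practice: split a compute budget with the ~20 tokens-per-parameter rule of thumb (via C ≈ 6·N·D) and estimate the loss with the paper's parametric fit. The helper names are made up and the coefficients are only approximately the reported values.

```python
# Hedged sketch of Chinchilla-style compute-optimal sizing; coefficients are illustrative.
def chinchilla_optimal(C):
    """Split a compute budget C (FLOPs) into parameters N and tokens D with D ≈ 20 * N."""
    N = (C / (6 * 20)) ** 0.5   # from C = 6 * N * (20 * N)
    D = 20 * N
    return N, D

def loss(N, D, E=1.69, A=406.4, B=410.7, alpha=0.34, beta=0.28):
    """Parametric loss estimate L(N, D) = E + A / N**alpha + B / D**beta."""
    return E + A / N**alpha + B / D**beta

N, D = chinchilla_optimal(1e24)   # e.g. a 1e24 FLOP budget
print(f"N ≈ {N:.2e} params, D ≈ {D:.2e} tokens, loss ≈ {loss(N, D):.2f}")
```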
Scaling laws: Chinchilla fix
WAT?!
Llama-3 8B trained on 15T tokens
https://arxiv.org/abs/2203.15556
Scaling laws: Inference
Chinchilla-optimal models are only compute-optimal for training and ignore inference compute
Scaling laws: Harm’s law
Dataset
aka the secret sauce
aka 90% of all the work
Dataset: the secret workhorse of LLMs
https://nonint.com/2023/06/10/the-it-in-ai-models-is-the-dataset/
Dataset: goal of pretraining
Train a general-purpose model → maximal coverage
Requires:
Challenges:
Dataset: where to find data
Dataset: FineWeb
https://hf.co/spaces/HuggingFaceFW/blogpost-fineweb-v1
Dataset: the average web
… is mostly garbage:
If we want a high-quality model, we need to clean it up!
Dataset: filtering pipeline
Dataset: filtering pipeline
Dataset: general advice
(manually, clustering, tokenizing etc)
Dataset: language filtering
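A sketch of what a language-filtering step can look like with a fastText language-ID model; the model file (lid.176.bin) has to be downloaded separately, and the 0.65 threshold is illustrative.

```python
# Language filtering with fastText language identification.
import fasttext

lid = fasttext.load_model("lid.176.bin")

def keep_english(text, threshold=0.65):
    # fastText expects single-line input, so strip newlines before predicting.
    labels, probs = lid.predict(text.replace("\n", " "))
    return labels[0] == "__label__en" and probs[0] >= threshold

print(keep_english("The quick brown fox jumps over the lazy dog."))
```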
Dataset: quality heuristics
Dataset: quality heuristics
Advantages:
Drawbacks:
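A minimal sketch of what such heuristic filters look like in code; the specific rules and thresholds below are illustrative, not the values used in any production pipeline.

```python
# Gopher/FineWeb-style quality heuristics (illustrative rules and thresholds).
def passes_heuristics(text):
    words = text.split()
    if not (50 <= len(words) <= 100_000):           # document length bounds
        return False
    mean_word_len = sum(len(w) for w in words) / len(words)
    if not (3 <= mean_word_len <= 10):              # filters gibberish / symbol dumps
        return False
    lines = [l for l in text.splitlines() if l.strip()]
    if lines and sum(l.strip().endswith("...") for l in lines) / len(lines) > 0.3:
        return False                                # too many truncated lines
    if sum(c.isalpha() for c in text) / max(len(text), 1) < 0.6:
        return False                                # low alphabetic-character ratio
    return True
```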
Dataset: quality filtering - ML
Given a set of examples of good/bad documents:
(see https://github.com/kpu/kenlm)
→ Filter based on a threshold
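A sketch of perplexity-based filtering with a KenLM n-gram model trained on a "good" reference corpus (e.g. Wikipedia); the model path and threshold are placeholders.

```python
# Perplexity filtering with a KenLM n-gram model.
import kenlm

lm = kenlm.Model("wiki_5gram.arpa")   # path to a pretrained n-gram model (placeholder)

def keep_by_perplexity(text, max_ppl=1000.0):
    # Lower perplexity ⇒ the document looks more like the reference corpus.
    return lm.perplexity(text.replace("\n", " ")) <= max_ppl
```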
Dataset: FineWeb-Edu
Below is an extract from a web page. Evaluate whether the page has a high educational value and could be useful in an educational setting for teaching from primary school to grade school levels using the additive 5-point scoring system described below. Points are accumulated based on the satisfaction of each criterion:
The extract:
<EXAMPLE>.
After examining the extract:
Annotate: Llama 3 70B scores 500K samples
Train: a small Transformer classifier on the annotations
Infer: run the classifier at scale over FineWeb → FineWeb-Edu
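A sketch of running such a classifier over new documents with transformers, assuming the released fineweb-edu-classifier checkpoint (a single regression output giving a ~0-5 educational score; FineWeb-Edu keeps documents scoring 3 or higher).

```python
# Scoring documents with an educational-value classifier (checkpoint name is an assumption).
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

name = "HuggingFaceFW/fineweb-edu-classifier"
tok = AutoTokenizer.from_pretrained(name)
clf = AutoModelForSequenceClassification.from_pretrained(name)

def edu_score(text):
    inputs = tok(text, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        return clf(**inputs).logits.squeeze().item()   # single regression output ≈ 0-5

# Keep documents with a score of 3 or higher.
```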
Dataset: FineWeb-Edu
Dataset: notes on filtering
Taking care of domain specificities
Deterministic vs. stochastic selection
Dataset: deduplication
Fuzzy
Exact
time/memory consumption
counter-intuitive results
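A minimal sketch of fuzzy deduplication with MinHash + LSH using the datasketch library; shingle size and similarity threshold are illustrative.

```python
# Fuzzy deduplication with MinHash signatures and locality-sensitive hashing.
from datasketch import MinHash, MinHashLSH

def minhash(text, num_perm=128, shingle=5):
    words = text.lower().split()
    m = MinHash(num_perm=num_perm)
    for i in range(max(len(words) - shingle + 1, 1)):
        m.update(" ".join(words[i:i + shingle]).encode("utf-8"))
    return m

lsh = MinHashLSH(threshold=0.8, num_perm=128)
kept = {}
for doc_id, text in enumerate(["a web page ...", "another web page ..."]):
    m = minhash(text)
    if not lsh.query(m):            # no near-duplicate seen so far → keep the document
        lsh.insert(str(doc_id), m)
        kept[doc_id] = text
```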
Dataset: evaluate data quality
Small model trainings: train 1-2B parameter models on 30GT (30B tokens, roughly Chinchilla-optimal)
Dataset: The Stack
Dataset: The Stack v2
Dataset: Cosmopedia
Distributed Training
simple things get complicated
Training: strategy
Compute budget is an external constraint
Compute cluster and model size determine the training topology
Training: basic training step
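A minimal sketch of a single training step, assuming a Hugging Face-style causal LM that returns a loss when labels are passed.

```python
# One basic pretraining step: forward, backward, optimizer update.
def train_step(model, batch, optimizer):
    # The model shifts the labels internally for next-token prediction.
    outputs = model(input_ids=batch["input_ids"], labels=batch["input_ids"])
    loss = outputs.loss
    loss.backward()         # compute gradients
    optimizer.step()        # update weights
    optimizer.zero_grad()   # reset gradients for the next step
    return loss.item()
```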
Training: anatomy of memory
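A rough rule of thumb for static memory in BF16 mixed-precision Adam training (weights + gradients + optimizer states); activations come on top and depend on batch size, sequence length, and recomputation.

```python
# Per-parameter memory: bf16 weights + bf16 grads + fp32 master copy + Adam momentum + variance.
def model_memory_gb(n_params):
    bytes_per_param = 2 + 2 + 4 + 4 + 4    # = 16 bytes per parameter
    return n_params * bytes_per_param / 1e9

print(f"8B model: ~{model_memory_gb(8e9):.0f} GB before activations")   # ~128 GB
```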
Training: activation recomputation
Training: activation recomputation
[Figure: activation memory vs. sequence length]
Selective: store activations of specific operations → 2-3% slowdown
Full: only store activations at layer level → 30% slowdown
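A sketch of full recomputation with torch.utils.checkpoint: each layer's internal activations are recomputed during the backward pass instead of being stored.

```python
# Full activation recomputation: only layer inputs/outputs are kept in memory.
import torch
from torch.utils.checkpoint import checkpoint

def forward_with_recomputation(layers, hidden_states):
    for layer in layers:
        # Intermediate activations inside `layer` are recomputed in the backward pass.
        hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
    return hidden_states
```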
Training: gradient accumulation
Split the global batch into micro-batches to save memory
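A sketch of the accumulation loop; the loss is scaled so the summed gradient matches the gradient of the full global batch.

```python
# Gradient accumulation: several micro-batch backward passes per optimizer step.
def train_step_with_accumulation(model, micro_batches, optimizer):
    optimizer.zero_grad()
    for micro_batch in micro_batches:
        out = model(input_ids=micro_batch["input_ids"], labels=micro_batch["input_ids"])
        # Scale so the accumulated gradient equals the gradient of the full global batch.
        (out.loss / len(micro_batches)).backward()
    optimizer.step()
```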
Now let’s add more GPUs!
Training: Data Parallelism - 1D
Distribute micro-batches across GPUs
all_reduce
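A sketch of what happens under the hood of data parallelism: after the local backward pass, each rank's gradients are averaged with an all_reduce (frameworks like DDP do this in buckets, overlapped with compute).

```python
# Manual gradient averaging across data-parallel ranks.
import torch.distributed as dist

def all_reduce_gradients(model):
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size   # average so the update matches single-GPU training
```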
Training: Overlap Communication + Computation
https://siboehm.com/articles/22/data-parallel-training
Training: Tensor Parallel - 2D
What if the model still doesn’t fit? Split matrix multiplications:
Training: Tensor Parallel - 2D
What if the model still doesn’t fit? Split matrix multiplications:
Training: Tensor Parallel - 2D
What if the model still doesn’t fit? Split matrix multiplications:
Training: Tensor Parallel - 2D
Which one to use? Let’s look at the feedforward layers:
Training: Tensor Parallel - 2D
Which one to use? Let’s look at the feedforward layers:
We can save two communication steps!
Training: Tensor Parallel - 2D
What about multi-head attention?
Column-parallel ←→ Each worker processes a subset of heads
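A single-process sketch of the feedforward case above: a column-parallel first projection followed by a row-parallel second projection, so the only communication per MLP block is one all-reduce (the final sum). Shapes and the two-way split are illustrative.

```python
# Simulating tensor parallelism for an MLP block on one process.
import torch

d_model, d_ff, tp = 512, 2048, 2
x = torch.randn(4, d_model)

W1 = torch.randn(d_model, d_ff)        # up projection
W2 = torch.randn(d_ff, d_model)        # down projection

W1_shards = W1.chunk(tp, dim=1)        # column-parallel: split output features
W2_shards = W2.chunk(tp, dim=0)        # row-parallel: split matching input features

# Each "rank" computes a partial output; summing them plays the role of the all_reduce.
partials = [torch.relu(x @ W1_shards[r]) @ W2_shards[r] for r in range(tp)]
y_tp = sum(partials)

y_ref = torch.relu(x @ W1) @ W2
print(torch.allclose(y_tp, y_ref, atol=1e-4, rtol=1e-4))
```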
Training: Sequence Parallel - 3D
Training: Going beyond 1 node
Intra-node connect (NVSwitch): 900 GB/s
Inter-node connect (InfiniBand): 50 GB/s
Node: 4-8 GPUs
Cluster: thousands of nodes
→ TP generally doesn’t scale beyond one node
Training: Pipeline Parallelism - 3D
Share layers across GPUs (e.g. layers 1-4 on one GPU). Naive PP:
[Figure: naive pipeline schedule with a large idle-time bubble]
Naive PP is very inefficient, with GPUs idling most of the time
Training: Pipeline Parallelism - 3D
Microbatches
AFAB: All Forward - All Backward
Training: Pipeline Parallelism - 3D
Microbatches
1F1B: 1 Forward - 1 Backward
Training: Pipeline Parallelism - 3D
Microbatches
Interleaved 1F1B
Training: Context Parallelism - 4D
How do you train with 1M context?
Ring Attention!
[Ring Attention animation: each GPU g keeps its query block Qg and starts with Kg, Vg. At every step the K/V blocks are passed one position around the ring (GPU 1 → 2 → 3 → 4 → 1) while the Q blocks stay put, so after 4 steps every query block has attended to all key/value blocks without any GPU ever holding the full sequence.]
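A single-process NumPy simulation of the rotation above: each "GPU" keeps its Q block, the K/V blocks travel around the ring, and partial softmax statistics are accumulated until they match full attention. Real implementations use a numerically stable online softmax and overlap the K/V communication with compute.

```python
# Simulating Ring Attention with 4 "GPUs" (shapes are illustrative).
import numpy as np

n_gpus, block, d = 4, 8, 16
rng = np.random.default_rng(0)
Q = rng.normal(size=(n_gpus, block, d))
K = rng.normal(size=(n_gpus, block, d))
V = rng.normal(size=(n_gpus, block, d))

num = np.zeros((n_gpus, block, d))    # running numerator:   sum_j exp(scores_j) @ V_j
den = np.zeros((n_gpus, block, 1))    # running denominator: sum_j exp(scores_j)
kv = np.arange(n_gpus)                # index of the K/V block each GPU currently holds

for step in range(n_gpus):
    for g in range(n_gpus):
        scores = Q[g] @ K[kv[g]].T / np.sqrt(d)
        e = np.exp(scores)            # numerically naive; real kernels use an online softmax
        num[g] += e @ V[kv[g]]
        den[g] += e.sum(axis=-1, keepdims=True)
    kv = (kv - 1) % n_gpus            # pass K/V blocks to the next GPU in the ring

out_ring = num / den

# Reference: full attention over the whole concatenated sequence.
K_all, V_all = K.reshape(-1, d), V.reshape(-1, d)
for g in range(n_gpus):
    s = Q[g] @ K_all.T / np.sqrt(d)
    out_ref = (np.exp(s) / np.exp(s).sum(axis=-1, keepdims=True)) @ V_all
    assert np.allclose(out_ring[g], out_ref)
```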
Training: Context Parallelism - 4D
ZigZag Ring attention: making sure all GPUs do equal work!
Training: 4D parallelism
All four parallelism dimensions are combinable and complementary:
Training: ZeRO
ZeRO (Zero Redundancy Optimizer):
K = 12 bytes of optimizer state per parameter for Adam (FP32 weight copy + momentum + variance)
ZeRO stage 3 has ~50% more communication
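A minimal sketch of ZeRO-3-style sharding using PyTorch FSDP (one of several implementations); it assumes torch.distributed has already been initialized, e.g. via torchrun.

```python
# ZeRO-3-style sharding: params, grads and optimizer states are sharded across ranks.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def shard_model(model: torch.nn.Module) -> FSDP:
    # Each rank stores only its shard and gathers full weights on demand
    # during forward/backward, freeing them again afterwards.
    return FSDP(model)
```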
Training: Putting all together
Training: Flash Attention + Fused Kernels
Standard Attention
Training: Flash Attention + Fused Kernels
https://arxiv.org/pdf/2205.14135
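A sketch contrasting standard attention (which materializes the seq × seq score matrix in HBM) with PyTorch's fused scaled_dot_product_attention, which can dispatch to a FlashAttention-style kernel; sizes are illustrative and a GPU is assumed.

```python
# Standard vs. fused attention: same math, very different memory behavior.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# Standard attention: O(seq^2) memory for the score matrix.
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
out_standard = torch.softmax(scores, dim=-1) @ v

# Fused/flash path: computed tile by tile in SRAM, no seq x seq matrix in HBM.
out_fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_standard, out_fused, atol=1e-2, rtol=1e-2))
```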
Training: Mixed Precision Training
Training: Mixed Precision Training
Recipe for BF16/FP16 mixed precision training:
Speed: Operations in lower precision are faster!
FP8: still experimental but we have some promising approaches
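A sketch of the BF16 part of the recipe with torch.autocast: master weights stay in FP32 while forward/backward matmuls run in BF16 (with FP16 one would additionally use a loss scaler such as torch.cuda.amp.GradScaler).

```python
# BF16 mixed-precision training step with autocast.
import torch

def mixed_precision_step(model, batch, optimizer):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(input_ids=batch["input_ids"], labels=batch["input_ids"]).loss
    loss.backward()          # backward is called outside the autocast context
    optimizer.step()         # FP32 master weights are updated
    optimizer.zero_grad()
    return loss.item()
```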
Training: Learning rate schedules
Moving from Cosine to Warmup-Stable-Decay (WSD):
More flexibility, e.g. data stages!
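A sketch of a WSD schedule as a function of the step; phase lengths and the linear decay shape are illustrative.

```python
# Warmup-Stable-Decay: linear warmup, long constant plateau, short final decay.
def wsd_lr(step, max_lr, total_steps, warmup_steps, decay_steps, min_ratio=0.1):
    if step < warmup_steps:                              # warmup: 0 -> max_lr
        return max_lr * step / warmup_steps
    if step < total_steps - decay_steps:                 # stable: constant max_lr
        return max_lr
    progress = (step - (total_steps - decay_steps)) / decay_steps
    return max_lr * (1 - progress * (1 - min_ratio))     # decay: max_lr -> min_ratio * max_lr
```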
Training: Data Stages
Hugging Face: Tools
Questions?
GitHub/HF Hub/X: lvwerra