
Midtraining and Context Parallelism

LLM Reasoning Workshop

Autumn ‘25

Quentin Anthony

E-mail: qubitquentin@gmail.com

https://quentin-anthony.github.io/


Overview

  1. Midtraining
  2. Context Parallelism


Midtraining: Overview

  • Pretraining
    • Focus on breadth, so your model is ready for finetuning / continual pretraining / RLAIF across many domains
    • Breadth requires data quantity, and most domains don’t require a long context to capture, so train with a smaller sequence length (2048–8192 tokens is common)
      • Q: Why not train with a higher sequence length during pretraining? Then we wouldn’t have to extend the context later!
      • A: The pretraining budget is fixed, and longer sequences take more time per step. Since most domains don’t need long context anyway, use a shorter sequence length for more steps and extend later!
  • Midtraining
    • Focus on preparing the model for post-training (SFT/RLAIF/reasoning/etc.). Smaller domain spread.
    • Extend context to handle reasoning data
    • Large data distribution shift, so we need a higher LR (to adapt quickly), then a rapid decay (so that we settle into a good local minimum)
      • Replay some data from the pretraining phase so that we don’t forget breadth
    • Focus on higher-quality data (instruct, code, math, reasoning, synthetic); a hypothetical mixture is sketched below
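A purely illustrative example of such a mixture, expressed as sampling weights with pretraining replay; the domain names and percentages are assumptions for this sketch, not values from the deck.

    # Hypothetical midtraining data mixture (sampling weights).
    # Domains and percentages are illustrative assumptions only.
    midtrain_mixture = {
        "pretrain_replay": 0.30,  # replay pretraining data to preserve breadth
        "code":            0.20,
        "math":            0.15,
        "instruct":        0.15,
        "reasoning":       0.15,  # long chain-of-thought / reasoning traces
        "synthetic":       0.05,
    }
    assert abs(sum(midtrain_mixture.values()) - 1.0) < 1e-9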


Midtraining: Learning Rate Schedule

  • Need a high LR since we have a data distribution shift. Two options:
    • If you’re already at a high LR and use heavy replay during midtrain, no rewarming is needed (left)
    • Otherwise, you need to rewarm and then decay again (right); a minimal schedule sketch follows below
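A minimal sketch of the rewarm-then-decay option, assuming a linear rewarm into a cosine decay; the step counts and LR values are illustrative assumptions, not numbers from this deck.

    import math

    def midtrain_lr(step, rewarm_steps=1_000, decay_steps=20_000,
                    start_lr=3e-5, peak_lr=3e-4, final_lr=3e-5):
        """Linearly rewarm from the end-of-pretrain LR to a new peak, then cosine-decay.
        All hyperparameters here are illustrative assumptions."""
        if step < rewarm_steps:
            # Linear rewarm from start_lr up to peak_lr
            return start_lr + (peak_lr - start_lr) * step / rewarm_steps
        # Cosine decay from peak_lr down to final_lr
        progress = min(1.0, (step - rewarm_steps) / decay_steps)
        return final_lr + 0.5 * (peak_lr - final_lr) * (1.0 + math.cos(math.pi * progress))

    # LR at a few points along the schedule
    for s in (0, 500, 1_000, 11_000, 21_000):
        print(s, f"{midtrain_lr(s):.2e}")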


Midtraining: Long Context

  • Reasoning traces are super long!


Why Is Context Parallelism Needed?

  • Lots of applications need models to handle very long sequences of tokens:
    • ≤32k is enough for a lot of things:
      • Single-file or small-codebase coding
      • Short summarization (news articles, product reviews, etc)
    • Chats quickly grow!

  • This same need for long-context spans modalities
    • Long video/audio samples, thousands of code lines, massive images, etc

  • Both training and inference need long context
    • The current standard is context-length extension: pretrain at moderate context (~4k–8k), then ramp up linearly to the target context on a small subset of your data; a sketch of such a ramp follows below
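One way such a ramp could be written down, as a sketch; the stage boundaries and sequence lengths below are hypothetical, not a prescribed recipe.

    def context_length_at(step, ramp_start, ramp_steps, base_len=8_192, target_len=131_072):
        """Linearly ramp the training sequence length from base_len to target_len
        over ramp_steps steps, starting at ramp_start. All values are illustrative."""
        if step < ramp_start:
            return base_len
        frac = min(1.0, (step - ramp_start) / ramp_steps)
        return int(base_len + frac * (target_len - base_len))

    # Example: ramp from 8k to 128k over 1,000 steps starting at step 50,000
    for s in (0, 50_000, 50_500, 51_000, 60_000):
        print(s, context_length_at(s, ramp_start=50_000, ramp_steps=1_000))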


Aside: FLOP Analysis

  • To analyze the compute effects of long-sequence training, we should first model the FLOP requirements of transformer models.
  • MLP blocks
    • Assume a basic MLP with two linears of size (h, 4h) and (4h, h)
    • For each linear to process each token, we perform 2·4h·h operations (the 2 comes from multiply + add)
    • Two linears, so 16h² per token. Across L layers and an input of size b·s, we get 16bsLh² total FLOPs for the forward pass
    • In the backward pass, we need gradients with respect to both the inputs and the weights, so 32bsLh² FLOPs for the backward pass
  • Attention blocks
    • QKVO matrices
      • Each of the Q, K, V, O matrices is applied per token, and each matrix is of size (h, h). So per token we need 2·4·h² FLOPs (the 2 comes from multiply + add)
      • Therefore, for L layers across an input of b·s, we have 8bsLh² forward FLOPs and 16bsLh² backward FLOPs
    • Attention scores and output (ignoring softmax)
      • To compute the score matrix, we multiply Q of size (b, s, h) by Kᵀ of size (b, h, s) to produce a (b, s, s) output. The FLOP cost of this is 2bhs² per layer, and the attention output (multiplying the scores by V) is equal-cost, leading to 4bhLs² forward FLOPs and 8bhLs² backward FLOPs
  • Total FLOPs = (12hs² + 72sh²)bL
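A small sketch that evaluates this FLOP model; the function name and example shapes are assumptions for illustration.

    def transformer_flops(b, s, h, L):
        """Forward + backward FLOPs per step from the model above:
        (12*h*s**2 + 72*s*h**2) * b * L."""
        attn_scores = 12 * h * s**2 * b * L   # QK^T scores + attention output (fwd + bwd)
        dense = 72 * s * h**2 * b * L         # MLPs (48bsLh^2) + QKVO projections (24bsLh^2)
        return attn_scores + dense

    # Illustrative shapes: b=1, s=8192, h=4096, L=32
    print(f"{transformer_flops(1, 8192, 4096, 32):.3e} FLOPs per step")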


FLOP Analysis for Long Sequence Lengths

  • Total FLOPs = (12hs² + 72sh²)bL

  • As the model size increases, the attention-score share of the cost decreases
    • In simpler terms, it depends on whether the h² term or the s² term dominates the total FLOP count: the s²-dependent fraction is 12hs² / (12hs² + 72sh²) = s / (s + 6h)
    • For larger models, most of your FLOPs are spent in the MLPs and QKVO matrices anyway, so increasing sequence length is more about memory (see the sketch below)
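A quick check of that fraction, s / (s + 6h), for a few hypothetical (h, s) combinations.

    def attn_score_fraction(h, s):
        """Fraction of total FLOPs spent on attention scores/output: s / (s + 6h)."""
        return s / (s + 6 * h)

    # Hypothetical shapes: smaller vs. larger hidden size at two sequence lengths
    for h in (2048, 8192):
        for s in (8192, 131072):
            print(f"h={h:5d} s={s:6d} -> {attn_score_fraction(h, s):.1%}")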


Ring Attention

  • Recall that attention is the computation of the output matrix A = softmax(QKᵀ)V:
    • Q, K, V are of size (s, h)
    • The intermediate score matrix S = QKᵀ is of size (s, s)
  • Done in two steps:
    • Score matrix: S = QKᵀ
    • Attention output: A = softmax(QKᵀ)V
  • Therefore, our memory overhead scales like O(s²)
    • Fused kernels, online softmax, etc. don’t make the problem go away on a single GPU: the full-sequence Q, K, V (and the s² work) still live on one device

  • Goal: Parallelize so that each GPU gets an equal portion of the memory contributed by the sequence
    • To be practical, this parallelism needs to minimize total communication and try to overlap computation with communication


Ring Attention

  • Key idea for Ring Attention:
    • Split the Q, K, V matrices equally across GPUs
    • Rotate the K and V blocks between GPUs in a ring
    • Keep the Q blocks persistent on the same GPU
    • Overlap the K, V block communication with computation
    • Keep rotating K and V until every GPU has seen every block


Ring Attention: Q

  • Partition Q row-wise into a total of B_Q chunks Q_i
    • Follow the orange in the figure: each output row only depends on a single query row in Q
  • Each chunk has C_Q rows
  • B_Q = num_GPUs
  • C_Q · B_Q = s

  • Each Q_i chunk needs all of K and V, so let’s handle that next


Ring Attention: K, V

  • Imagine chunking both Q and K into chunks Q_i and K_j
  • We need to compute the softmax over full rows of QKᵀ
  • Key idea: split the softmax sum into an “online softmax” so that we don’t have to store the entire K or V matrix at once


Ring Attention: K, V

  • Partition the computation of a single (Q_i, A_i) chunk into B_KV independent sub-parts, where step j involves only a single (K_j, V_j) chunk, j ∈ {1, …, B_KV}
  • The computation of a single (Q_i, A_i) chunk then decomposes as:

      A_i = softmax(Q_i Kᵀ) V = [ Σ_j exp(Q_i K_jᵀ) V_j ] / [ Σ_j rowsum(exp(Q_i K_jᵀ)) ]

    where each term j touches only K_j and V_j, so the numerator and denominator can be accumulated chunk by chunk (with the usual running-max rescaling for numerical stability)


Ring Attention

  • Each device gets one Q_i chunk and calculates the inner loop iteratively for each (K_j, V_j).

  • At a given step j, each device only needs to keep its running accumulator for A_i, of shape (C_Q, h) (plus the running softmax max/sum), and a single (K_j, V_j) block at a time, along with its own Q_i

  • Therefore, each GPU only needs to store an equal portion of Q, K, V in its VRAM; a minimal simulation is sketched below
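A minimal single-process simulation of this loop (numpy; one head, no causal mask, no 1/√h scaling; the function name and shapes are illustrative, not from the deck). Each simulated GPU keeps its own Q_i plus running max/sum statistics while the K/V blocks rotate around the ring.

    import numpy as np

    def ring_attention_sim(Q, K, V, num_gpus):
        """Simulate ring attention: each "GPU" owns one Q chunk and one K/V chunk;
        K/V chunks rotate around the ring while each GPU accumulates its output
        with an online (running-max) softmax. A sketch, not production code."""
        s, h = Q.shape
        c = s // num_gpus                       # chunk size (assume s divisible)
        Qs = [Q[i*c:(i+1)*c] for i in range(num_gpus)]
        Ks = [K[i*c:(i+1)*c] for i in range(num_gpus)]
        Vs = [V[i*c:(i+1)*c] for i in range(num_gpus)]

        # Per-GPU running state: unnormalized output, running row-sum, running row-max
        out = [np.zeros((c, h)) for _ in range(num_gpus)]
        denom = [np.zeros((c, 1)) for _ in range(num_gpus)]
        row_max = [np.full((c, 1), -np.inf) for _ in range(num_gpus)]
        kv_idx = list(range(num_gpus))          # which K/V block each GPU currently holds

        for _ in range(num_gpus):               # num_gpus ring steps
            for g in range(num_gpus):           # "in parallel" across GPUs
                j = kv_idx[g]
                scores = Qs[g] @ Ks[j].T        # (c, c) block of Q_i K_j^T
                new_max = np.maximum(row_max[g], scores.max(axis=1, keepdims=True))
                scale = np.exp(row_max[g] - new_max)   # rescale old accumulators
                p = np.exp(scores - new_max)
                out[g] = out[g] * scale + p @ Vs[j]
                denom[g] = denom[g] * scale + p.sum(axis=1, keepdims=True)
                row_max[g] = new_max
            # "Communication": rotate K/V ownership one step around the ring
            kv_idx = [kv_idx[g - 1] for g in range(num_gpus)]

        return np.concatenate([out[g] / denom[g] for g in range(num_gpus)])

    # Check against the naive full-attention reference on random data
    rng = np.random.default_rng(0)
    Q, K, V = (rng.standard_normal((16, 8)) for _ in range(3))
    scores = Q @ K.T
    ref = np.exp(scores - scores.max(axis=1, keepdims=True))
    ref = (ref / ref.sum(axis=1, keepdims=True)) @ V
    assert np.allclose(ring_attention_sim(Q, K, V, num_gpus=4), ref)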


Ring Attention: Overlap

  • While GPU1 computes its Q1 · (K1, V1) block, communicate:
    • receive K4, V4 from GPU4
    • send K1, V1 to GPU2

  • We need to set the chunk sizes so that time(compute) = time(comms) for perfect overlap
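A rough way to check that balance, assuming per ring step each GPU does ~4c²hL forward FLOPs of attention against 2kLhc bytes of K/V traffic; the hardware numbers below are hypothetical, and the backward pass and causal masking are ignored.

    def min_chunk_for_overlap(peak_flops, link_bw_bytes_per_s, k=2):
        """Smallest chunk size c where time(compute) >= time(comms):
        4*c*c*h*L / peak_flops >= 2*k*L*h*c / link_bw  =>  c >= k * peak_flops / (2 * link_bw).
        Note h and L cancel out of the inequality."""
        return k * peak_flops / (2 * link_bw_bytes_per_s)

    # Hypothetical hardware: ~1e15 FLOP/s per GPU, ~100 GB/s usable ring bandwidth, bf16 (k=2)
    print(f"need c >= {min_chunk_for_overlap(1e15, 100e9):.0f} tokens per chunk")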


Ring Attention: Memory Analysis

  • p: Num model parameters
  • k: Bytes per element (parameters/activations)
  • d: Num GPU devices
  • s: Sequence length
  • b: Batch size
  • h: Hidden size
  • L: Num transformer layers
  • a: Num attention heads
  • c: Ring attention chunk size
  • Q_i is persistent and therefore only needs kLhc bytes per GPU
  • Since we’re overlapping K_j, V_j comms and compute, we need to keep both the current and the incoming K/V blocks in memory simultaneously
  • Therefore, at each iteration we need 4kLhc bytes for the K_j, V_j buffers
  • Together with the accumulated output A_i (another kLhc), the total attention memory overhead per GPU is therefore ~6kLhc

  • Each GPU sends its K_j, V_j block each iteration, so the per-iteration communication volume (summed over all d GPUs) is 2kLhdc
  • Across the d iterations of the ring, the total communication volume is 2kLhd²c = 2kLhds (since c·d = s); a small helper that evaluates these is sketched below
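A small helper that plugs numbers into these expressions (variable names follow the legend above; the example configuration is hypothetical).

    def ring_attention_overheads(k, d, s, h, L):
        """Per-GPU memory and communication volumes from the expressions above."""
        c = s // d                                # chunk size, with c * d = s
        mem_per_gpu = 6 * k * L * h * c           # Q_i + double-buffered K/V + accumulated output
        comm_per_iter = 2 * k * L * h * d * c     # every GPU forwards one K/V block per iteration
        comm_total = 2 * k * L * h * d * s        # summed over the d ring iterations
        return mem_per_gpu, comm_per_iter, comm_total

    # Hypothetical setup: bf16 activations (k=2), 8 GPUs, s=131072, h=4096, L=32
    mem, per_iter, total = ring_attention_overheads(k=2, d=8, s=131_072, h=4096, L=32)
    print(f"{mem / 2**30:.1f} GiB per GPU, {per_iter / 2**30:.1f} GiB per ring step, "
          f"{total / 2**30:.1f} GiB total")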


Ring Attention: Practical Details

  • Ring attention’s parallelism is orthogonal to algorithms like flash attention: ring attention splits the online softmax into sequence chunks across devices, while flash attention computes the same online softmax blockwise within a single device (parallelizing across heads and batch)
    • Flash attention is commonly used in the inner loop over Q_i, K_j, V_j
  • You can only achieve good overlap with a ring algorithm on a flat, single-level topology
    • When the ring spans nodes, you’ll be communication-bottlenecked at the slower inter-node links

Ring Attention: Summary and Takeaways

  • When sequences are long, your sequence-mixer (attention) FLOPs and memory overhead dominate the overall runtime of your model
  • Divide the context across your hardware; this is called Context Parallelism
  • A popular form of Context Parallelism is Ring Attention, where partial attention outputs are accumulated on each GPU as K/V blocks rotate around a ring
    • Overlap the point-to-point communication with the computation of each attention block
    • Ring attention is not topology-aware, so going across nodes will be comms-bottlenecked


Thank You!
