1 of 73

10-605 / 10-805

Machine Learning from Large Datasets

2 of 73

Announcements

  • HW 5 status
  • Guest lectures: come if you can!
    • Wed 11/12: John Wieting, Google/Gemini: Contrastive Learning and Retrieval-Augmented Generation (RAG)
    • Mon 11/17: Michiel de Jong, Cursor: Optimizing Transformer Architectures for RAG

3 of 73

Outline

  • Recap of KV caching (so far)
    • Examples of using ML for ML Sys
  • Some recent papers
    • Cross-layer attention
    • Kimi Mooncake LLM serving system
  • Broader recap on model compression
    • tying together the post-break lectures
  • Brief motivation of retrieval augmentation

4 of 73

Outline

  • Recap of KV caching (so far)
    • Examples of using ML for ML Sys
  • Some recent papers
    • Cross-layer attention
    • Kimi Mooncake LLM serving system
  • Broader recap on model compression
    • tying together the post-break lectures
  • Brief motivation of retrieval augmentation

5 of 73

KV CACHING RECAP

6 of 73

Recap: Key-Value Caching

  • In key-value caching with a bounded cache, we restrict attention after training by evicting entries from the cache
  • This is a crucial method for long inputs
    • or outputs, e.g. reasoning models (O1,…)

  • Big question: what is the eviction policy?

7 of 73

Recap: “Heavy Hitters” in KV Caches

  • Proposed method
    • Fix the cache size to K
    • Generate tokens i = 1, 2, …
      • Track score (cumulative attention) of each token in the cache
        • “local H2 statistic” – cumulative attention up to position i
      • When the cache is full, overwrite the token with lowest local H2 score
        • (Attention to recent-past tokens keeps tokens from being evicted too quickly)
      • Called an “eviction policy” (see the sketch below)
    • Local H2 statistic is close to full oracle H2 statistic that looks at future tokens
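
Below is a minimal, illustrative sketch of an H2O-style eviction step (my paraphrase of the policy above, not the authors' implementation; the function name and the size of the protected recent window are assumptions):

```python
import numpy as np

def h2o_step(scores, attn_row, cache_size, n_recent_protected=8):
    """One decoding step of a hypothetical H2O-style cache manager.

    scores:    cumulative attention ("local H2 statistic") per cached position
    attn_row:  attention weights of the newly generated token over the cache
    Returns the updated scores and the index to evict (or None if not full).
    """
    scores = scores + attn_row                      # accumulate attention mass
    if len(scores) < cache_size:
        return scores, None                         # cache not yet full
    # protect the most recent positions so they aren't evicted too quickly
    candidates = np.arange(len(scores) - n_recent_protected)
    victim = int(candidates[np.argmin(scores[candidates])])
    return scores, victim                           # overwrite cache slot `victim`
```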

8 of 73

Recap: “Heavy Hitters” in KV Caches

9 of 73

Recap: StreamingLLM

Why are initial positions getting so much attention?

10 of 73

Recap: StreamingLLM

  • Proposed method
    • Fix the cache size to K
    • Generate tokens i = 1, 2, …
      • Keep the first 4 tokens and the most recent K-4

red: StreamingLLM

orange: sliding window attention

blue: dense attention

11 of 73

Recap: SnapKV

  • Data: UltraChat
    • 1.4M sequences, selected dialog turns with prompt > 3k tokens, and response > 512 tokens
  • New observation
    • The attention pattern for prompt tokens is consistent throughout decoding
    • The attention pattern for prompt tokens is predictable from LLM behavior before decoding starts

12 of 73

Recap: SnapKV

  • Example: Attn overlap between decoding and 128-token windows from the prompt
  • Each line is a layer
  • There is high overlap with the last window in the prompt

13 of 73

Recap: SnapKV Details

  • Observation window: the n-th window in the prompt
  • Prefix: first n-1 windows in the prompt
  • SumAttn[head h, prefix position j]: attention weights
    • for head h to position j
    • summed over all positions in observation window
      • (Wobs in paper)
  • p: compression ratio, (e.g., 0.9)
  • Caching strategy: evict the KV cache entries for the bottom p*|Prefix| positions (for each head)
    • Refinement: also keep positions near a top (1-p) position (i.e., max-pool SumAttn locally)
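
A hedged sketch of the selection step described above (my reading of the slide, not the paper's code; the pooling width is an assumed hyperparameter). Here `attn` holds attention from observation-window queries to prefix positions:

```python
import numpy as np

def snapkv_keep(attn, p=0.9, pool=7):
    """attn: [n_heads, window_len, prefix_len]; returns kept prefix positions per head."""
    n_heads, _, prefix_len = attn.shape
    n_keep = max(1, int(round((1.0 - p) * prefix_len)))     # keep top (1-p)|Prefix|
    kept = []
    for h in range(n_heads):
        sum_attn = attn[h].sum(axis=0)                      # SumAttn[h, j] over Wobs
        # refinement: local max-pool so neighbours of a hot position also survive
        pooled = np.array([sum_attn[max(0, j - pool // 2): j + pool // 2 + 1].max()
                           for j in range(prefix_len)])
        kept.append(np.sort(np.argsort(pooled)[-n_keep:]))  # evict everything else
    return kept
```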

14 of 73

Recap: Needle in a Haystack: Background

15 of 73

Recap: Needle in a Haystack: SnapKV

16 of 73

Recap: Predictive KV Caching: FastGen

  • Caching policies explored in past work
    • Locality (Clocal)
      • hyperparam: rlocal = (#local tokens)/(cache size)
    • Heavy Hitter (Cfrequent)
      • hyperparam: rfrequent = (#HH tokens)/(cache size)
    • Special tokens (Cspecial)
      • Examples: start-of-sequence <s>, …
      • hyperparam: none
    • Punctuation (Cpunct)
      • Examples: period (.), colon (:), question mark (?), …
      • hyperparam: none
    • No eviction (Cfull)
    • Hybrid strategies: e.g., Cspecial + Clocal
      • union of the two caches

17 of 73

Recap: Predictive KV Caching: FastGen

Which works best? It depends….

Different heads at layer 20

18 of 73

Recap: Predictive KV Caching: FastGen

  • Which caching strategy to use?
    • Pick a policy for each head/layer based on the prefilling
    • Use that policy to initialize the KV cache, and keep using it during decoding (sketch below)
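
A simplified sketch of the per-head policy choice (hypothetical helper names and thresholds; the real FastGen profiling differs in detail). The idea: during prefilling, pick the cheapest policy whose kept positions still recover most of the head's attention mass:

```python
import numpy as np

def recovered(attn, keep):
    """Fraction of one head's prefill attention mass that lands on kept positions."""
    return attn[:, keep].sum() / attn.sum()

def pick_policy(attn, special_pos, punct_pos, r_local=0.3, thresh=0.95):
    """attn: [T, T] causal prefill attention for a single head."""
    T = attn.shape[0]
    spec = np.zeros(T, bool); spec[special_pos] = True
    punc = np.zeros(T, bool); punc[punct_pos] = True
    local = np.zeros(T, bool); local[-max(1, int(r_local * T)):] = True
    # cheapest-first hybrids, e.g. C_special, then C_special + C_punct, ...
    for name, keep in [("special", spec),
                       ("special+punct", spec | punc),
                       ("special+punct+local", spec | punc | local)]:
        if recovered(attn, keep) >= thresh:
            return name          # use this eviction policy for the head
    return "full"                # fall back to no eviction
```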

19 of 73

Outline

  • Recap of KV caching (so far)
    • Examples of using ML for ML Sys
  • Some recent papers
    • Cross-layer attention - simple
    • Kimi Mooncake LLM serving
  • Broader recap on model compression
    • tying together the post-break lectures
  • Brief motivation of retrieval augmentation

20 of 73

RECENT KV-CACHING PAPERS

21 of 73

NeurIPS 2024 – 80 cites

22 of 73

Recap: What architectural changes impact KV caching?

  • Grouped query and multi-query attention!

Recall: we can re-use keys and values in “groups”

23 of 73

Idea: some layers re-use KV cache from a previous layer

Like GQA except grouping is across layers instead of inside a single layer
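
A toy sketch of the idea (single head, no causal mask, made-up layer format): layers whose `kv_source` points to an earlier layer skip their own K/V projections and reuse that layer's cache, the way GQA lets several query heads share one K/V head:

```python
import numpy as np

def attention(q, k, v):
    w = np.exp(q @ k.T / np.sqrt(q.shape[-1]))       # toy softmax attention
    return (w / w.sum(-1, keepdims=True)) @ v

def forward(x, layers):
    """layers: list of dicts with 'Wq', 'Wk', 'Wv' and optional 'kv_source' index."""
    h, kv_cache = x, {}
    for i, layer in enumerate(layers):
        q = h @ layer["Wq"]
        src = layer.get("kv_source")                 # earlier layer to borrow KV from
        if src is None:
            kv_cache[i] = (h @ layer["Wk"], h @ layer["Wv"])   # own KV, gets cached
            src = i
        k, v = kv_cache[src]                         # sharing layers store nothing new
        h = h + attention(q, k, v)                   # residual connection
    return h
```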

24 of 73

Good results on a 1B-sized model

Improvements less clear at 3B

25 of 73

Outline

  • Recap of KV caching (so far)
    • Examples of using ML for ML Sys
  • Some recent papers
    • Cross-layer attention - simple
    • Kimi Mooncake LLM serving – not simple
  • Broader recap on model compression
    • tying together the post-break lectures
  • Brief motivation of retrieval augmentation

26 of 73

Feb 2025 – Kimi/Tsinghua – 140 cites

27 of 73

Some takeaways

  • Claim: effective workload is doubled
  • Serving system is centered on KV-cache management
  • Prefill and decoding are different
    • Use different resource pools for each stage
    • Use continuous batching during decoding
      • batch is decoding requests from different users, updated after each token
    • Look for caches with prefix matches in prefilling
      • batch is sequential tokens from same prompt
  • Memory resources are different but all useful
    • CPU DRAM, SSD, RDMA (GPU🡪GPU transfer)
    • Cache is distributed over different memories

rebatch frequently
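
A minimal illustration of the prefix-matching idea mentioned above (not Mooncake's code; the block size and hashing scheme are assumptions): prompts are hashed in fixed-size token blocks so prefill can skip any leading blocks whose KV is already cached somewhere in the pool:

```python
import hashlib

BLOCK = 512            # tokens per cached block (assumed)
kv_store = {}          # block-hash chain -> KV tensors (DRAM / SSD / remote GPU)

def block_keys(token_ids):
    keys, h = [], hashlib.sha256()
    n_full = len(token_ids) - len(token_ids) % BLOCK
    for start in range(0, n_full, BLOCK):
        h.update(str(token_ids[start:start + BLOCK]).encode())
        keys.append(h.hexdigest())      # each key depends on the whole prefix so far
    return keys

def longest_cached_prefix(token_ids):
    hit = 0
    for key in block_keys(token_ids):
        if key not in kv_store:
            break
        hit += 1
    return hit * BLOCK                  # prefill only recomputes from this position on
```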

28 of 73

Inference workflow

29 of 73

MoonCake architecture

TTFT = “time to first token”

TBT = “time between tokens”

SLO = “service level objectives”

MFU = “model FLOP utilization”

DRAM: CPU memory; VRAM: GPU memory

30 of 73

10 prefill nodes in cluster, each has 3M token cache

31 of 73

KV CACHING COMPARED TO….

32 of 73

Outline

  • Recap of KV caching (so far)
    • Examples of using ML for ML Sys
  • Some recent papers
    • Cross-layer attention
    • Kimi Mooncake LLM serving system
  • Broader recap on model compression
    • tying together the post-break lectures
  • Brief motivation of retrieval augmentation

33 of 73

How should we think about these?

  • Quantization
  • Pruning
  • Key-Value Cache Management
    • …which is also pruning of a kind

Heuristic: find a survey paper!

34 of 73

How should we think about these?

TACL 2024

35 of 73

36 of 73

To reduce the cost of the weight-update process required by SparseGPT, Wanda (Sun et al., 2024) achieves model sparsity by pruning the weights with the smallest magnitudes multiplied by the norm of the corresponding input activations, without the need for retraining or weight updates.

37 of 73

Recap: making a large model smaller

  • Distillation
    • briefly discussed
  • Quantization
  • Pruning
  • Key-Value Cache Management
    • …which is also pruning

Quantization, Pruning, and KV Cache Management are really learning problems … with constraints on the model learned

Optimization with discrete constraints is harder—so in practice people approximate

38 of 73

Recap: making a large model smaller

  • Quantization
  • Pruning
  • Key-Value Cache Management

Old-school ML in practice:

  1. examine the data
  2. select/engineer features
  3. pick a simple learning/optimization method that worked for that task
    • sometimes invent a new optimization method
    • sometimes invent a new representation for models

39 of 73

Recap: making a large model smaller

  • Distillation
    • briefly discussed
  • Quantization
  • Pruning
  • Key-Value Cache Management
    • …which is also pruning

Quantization, Pruning, and KV Cache Management are really learning problems … with constraints on the model learned

Optimization with discrete constraints is harder—you need to find the best way to satisfy the discrete constraints and the best parameters

Work often looks like searching the discrete part of the space manually (or greedily or …)

  1. Manual rules for satisfying constraints
    • better rules based on data analysis
  2. Greedy search
    • better metrics
  3. More complex search
  4. Simple learning

40 of 73

Quantization: Beyond 16 bits

  • Going from fp32 🡪 int8 or int4 requires some tricks:
    • Quantize different parts of the model differently
      • quantize(x; S, Z) = round(x/S + Z)
        • S is “scale”
        • Z is “zero point”
        • S, Z are set for this model slice to get the needed range [α, β] (worked sketch below)
      • do as much as you can in quantized space instead of quantizing and de-quantizing frequently (the FMA trick)
        • complete matmul, conv2D-BatchNorm-ReLU
    • Quantize dynamically
      • set S, Z after you see the activation values
    • Don’t quantize everything
      • e.g., bias parameters aren’t usually quantized

Recap

Optimize parameter values via gradient descent (as usual) subject to constraint:

  • groups of parameters are stored in int8/int4
  • multiple S,Z optimized independently by group
  • groups to be quantized: by manual rules
  1. Optimize weights by training LLM
  2. Optimize quantization next (don’t retrain)
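
A worked sketch of the affine scheme on this slide, quantize(x; S, Z) = round(x/S + Z), with S and Z chosen so the slice's observed range [α, β] maps onto int8 (illustrative only; real kernels add per-channel and fused-op details):

```python
import numpy as np

def calibrate(alpha, beta, qmin=-128, qmax=127):
    S = (beta - alpha) / (qmax - qmin)          # scale
    Z = round(qmin - alpha / S)                 # zero point
    return S, Z

def quantize(x, S, Z, qmin=-128, qmax=127):
    return np.clip(np.round(x / S + Z), qmin, qmax).astype(np.int8)

def dequantize(q, S, Z):
    return (q.astype(np.float32) - Z) * S

x = np.random.uniform(-1.5, 2.5, size=1000).astype(np.float32)
S, Z = calibrate(x.min(), x.max())              # "dynamic": set S, Z after seeing values
err = np.abs(dequantize(quantize(x, S, Z), S, Z) - x).max()   # roughly S/2
```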

41 of 73

Quantization: LLM.int8()

  • Feature: any single dimension of a token representation
  • Outlier feature: magnitude >= 6 in 25% of the layers and 6% of “sequence dimensions”
    • sequence dimension: token positions
  • Observation: outliers are
    • probabilistic in smaller models (in some layers and not others)
    • systematic in larger models (most layers)

Hypothesis: outlier features are why int8 stops working!

Recap

42 of 73

Quantization: Beyond 16 bits

  • Going from fp32 🡪 int8 or int4 requires some tricks:
    • Mixed-precision quantization – keep large “outlier” features (which are consistently in fixed locations)

Recap

Optimize parameter values via gradient descent (as usual) subject to constraint:

  • groups of parameters are stored in int8/int4
  • multiple S,Z optimized independently by group
  • groups to be quantized: by manual rules

subset to be quantized is chosen with manually fixed rules (based on activation values)

43 of 73

Quantization: Performance Tradeoffs

2024

Recap

44 of 73

Quantization: Beyond 4 bits

  • Quantization-aware training
      • (de)quantization operations are not differentiable
      • partial solutions
        • approximate the ops
        • use “straight-through estimation” (STE) of gradients (sketch below)
      • extreme quantization requires methods to quantize vectors by combining vectors from a “codebook”
        • similar to product quantization
        • e.g., “additive quantization”
      • learning should ideally jointly optimize the continuous parts (codebook values) and discrete parts (assignment of values to code)
        • this can’t be done with gradient descent only

Recap

  1. Optimize weights by training LLM
  2. Optimize quantization next
  3. Then retrain to jointly optimize weights and quantization strategy
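
A small PyTorch-style sketch of the straight-through trick mentioned above (generic fake-quantization, not any specific paper's recipe): the forward pass uses the rounded value, while the backward pass treats rounding as the identity so gradients still flow:

```python
import torch

def fake_quant(x, S, Z=0, qmin=-128, qmax=127):
    q = torch.clamp(torch.round(x / S + Z), qmin, qmax)
    x_hat = (q - Z) * S                    # dequantized value used by later layers
    # straight-through estimator: forward = x_hat, gradient behaves as if x_hat == x
    return x + (x_hat - x).detach()

w = torch.randn(16, 16, requires_grad=True)
loss = fake_quant(w, S=0.05).pow(2).sum()
loss.backward()                            # w.grad exists despite the round()
```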

45 of 73

Recap: making a large model smaller

  • Distillation
    • briefly discussed
  • Quantization
  • Pruning
  • Key-Value Cache Management
    • …which is also pruning

Spectrum from simple quantization rules to complex rules and learning:

  • “32 bits everywhere”
  • “16 bits except for bias”
  • int8() except for outliers, defined as ….
  • jointly optimize codebook/codes and where to quant while doing gradient updates on weights

46 of 73

Recap: making a large model smaller

  • Distillation
    • briefly discussed
  • Quantization
  • Pruning
  • Key-Value Cache Management
    • …which is also pruning

47 of 73

Unstructured pruning

  • Observations (on CNNs):
    • pruning only occasionally causes major degradation in performance
    • performance recovers quickly with re-training
    • performance only slightly degrades even with lots of sparsity

  • For each layer to prune:
    • Repeat for i = 1, …, t:
      • Sort weights by |w_j|
      • Zero* out weights to reach sparsity s_i
        • *by adjusting a binary mask
      • Re-train for n steps (also adjusting the learning rate); see the sketch below

Recap

Optimize parameters while minimizing the number of nonzero (NNZ) weights by alternating:

  1. reducing NNZ weights (greedily)
  2. optimizing NZ weights (gradient)
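
A sketch of one pruning iteration of the loop above (the commented `retrain` step is a hypothetical stand-in for gradient fine-tuning of the surviving weights):

```python
import numpy as np

def prune_to_sparsity(W, mask, target_sparsity):
    """Zero the smallest-magnitude surviving weights until target_sparsity is reached."""
    scores = np.abs(W) * mask                       # already-pruned weights stay pruned
    k = int(target_sparsity * W.size)               # desired number of zeros
    threshold = np.sort(scores, axis=None)[k]       # k-th smallest magnitude
    return (scores >= threshold).astype(W.dtype)    # new binary mask

# schedule: gradually raise sparsity s_i, re-training for n steps in between, e.g.
#   for s in [0.2, 0.4, 0.6, 0.8]:
#       mask = prune_to_sparsity(W, mask, s)
#       W = retrain(W * mask, steps=n)              # hypothetical fine-tuning step
```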

48 of 73

Unstructured pruning: Wanda

  • Design decisions:
    • Which subset to prune?
    • How much to re-train, and on what data?
    • Do you iterate or not?
      • When do you stop?

Recap

Optimize parameters while minimizing NNZ weights by alternating (sketch below):

  1. reducing NNZ weights (greedily, with weights times activations)
  2. optimizing NZ weights (gradient)
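
A hedged sketch of the Wanda score (weight magnitude times the norm of the corresponding input activation, pruned within each output row, no retraining or weight update); illustrative, not the released code:

```python
import numpy as np

def wanda_mask(W, X, sparsity=0.5):
    """W: [d_out, d_in] layer weights; X: [n_tokens, d_in] calibration activations."""
    score = np.abs(W) * np.linalg.norm(X, axis=0)    # broadcast ||X_j||_2 over rows
    k = int(sparsity * W.shape[1])                   # weights to drop per output row
    mask = np.ones_like(W)
    drop = np.argsort(score, axis=1)[:, :k]          # smallest scores within each row
    np.put_along_axis(mask, drop, 0.0, axis=1)
    return mask                                      # apply as W * mask
```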

49 of 73

What sparse operations are supported?

  • Nvidia supports 2:4 sparse matrices
    • General version is “N:M”
    • In each contiguous block of 4 weights, at least 2 are zero (sketch below).
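
A small sketch of enforcing 2:4 semi-structured sparsity: in every contiguous group of 4 weights, zero the 2 with the smallest magnitude (here on a toy matrix whose width is a multiple of 4):

```python
import numpy as np

def to_2_4(W):
    groups = W.reshape(-1, 4).copy()                    # contiguous blocks of 4 weights
    drop = np.argsort(np.abs(groups), axis=1)[:, :2]    # 2 smallest per block
    np.put_along_axis(groups, drop, 0.0, axis=1)
    return groups.reshape(W.shape)

W = np.random.randn(8, 8).astype(np.float32)
assert (to_2_4(W).reshape(-1, 4) == 0).sum(axis=1).min() >= 2   # every block is 2:4-sparse
```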

Recap

50 of 73

What sparse operations are supported?

  • Nvidia supports 2:4 sparse matrices
    • General version is “N:M”
    • In each contiguous block of 4 weights at least 2 are zero.

Speeds for Wanda using semi-structured N:M pruning

Recap

Optimize parameters while minimizing NNZ weights by alternating:

  1. reducing NNZ weights (greedily, with weights times activations, in a comparison group)
  2. optimizing NZ weights (gradient)

51 of 73

Structured Pruning: Shortened Llama

  • Approach
    • For each layer, compute increase in PPL on a calibration set if removed
    • Also use a gradient*weight metric, Taylor+
    • Prune layers one-shot to get target size
    • Retrain with LoRA

Recap

Optimize parameters while reducing the number of layers (sketch below):

  1. reducing layers (greedily)
  2. optimizing remaining layers (LoRA)
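
A hedged sketch of the PPL-based layer scoring (the calibration evaluator `eval_ppl` is a caller-supplied, hypothetical stand-in; the Taylor+ variant and the LoRA recovery step are omitted):

```python
def depth_prune(blocks, eval_ppl, n_drop):
    """blocks: list of transformer blocks; eval_ppl(blocks) -> perplexity on a calibration set."""
    base = eval_ppl(blocks)
    scores = []
    for i in range(len(blocks)):
        trial = blocks[:i] + blocks[i + 1:]            # skip block i
        scores.append(eval_ppl(trial) - base)          # PPL increase if block i is removed
    keep = sorted(range(len(blocks)), key=lambda i: scores[i])[n_drop:]
    return [blocks[i] for i in sorted(keep)]           # one-shot prune; then retrain with LoRA
```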

52 of 73

Structured Pruning: Sheared Llama

  • Some key ideas:
    • Prune to a particular target architecture
      • Optimize a loss that includes PPL on a validation set + loss indicating how masks are mismatched to target architecture
      • Use Lagrange multipliers and soft approximation of the mask in training (slow!)

Recap

53 of 73

Structured Pruning: Sheared Llama

  • Some key ideas:
    • Prune to a particular target architecture
      • Pruning and retraining are done in the same optimization pass
      • Retrain to a “reference loss” predicted by scaling laws.
      • In pruning/retraining, focus on high-loss slices of the data.

Recap

Use gradients with Lagrange multipliers to jointly optimize perplexity and distance to a valid mask

54 of 73

Recap: making a large model smaller

  • Distillation
    • briefly discussed
  • Quantization
  • Pruning
  • Key-Value Cache Management
    • …which is also pruning

Spectrum from simple to complex optimization (slide annotations):

  • Wanda: greedy weight pruning
  • Shortened Llama: greedy layer pruning + LoRA
  • Magnitude pruning: greedy weight pruning alternating with gradients
  • Sheared Llama: jointly optimize weights and a continuous approximation to the mask

Important note: complicated ≠ good!

55 of 73

Recap: making a large model smaller

  • Distillation
    • briefly discussed
  • Quantization
  • Pruning
  • Key-Value Cache Management
    • …which is also pruning
      • … pruning parts of the key-value cache

56 of 73

H2O: “Heavy Hitters” in KV Caches

  • Proposed method
    • Fix the cache size to K
    • Generate tokens i = 1, 2, …
      • Track score (cumulative attention) of each token in the cache
        • “local H2 statistic” – cumulative attention up to position i
      • When the cache is full, overwrite the token with lowest local H2 score
        • (Attention to recent-past tokens keeps tokens from being evicted too quickly)
      • This is the “eviction policy”
    • Local H2 statistic is close to full oracle H2 statistic that looks at future tokens

Recap

Greedy pruning (eviction) based on running estimate of cumulative attention

57 of 73

StreamingLLM

  • Proposed method
    • Fix the cache size to K
    • Generate tokens i = 1, 2, …
      • Keep the first 4 tokens and the most recent K-4

red: StreamingLLM

orange: sliding window attention

blue: dense attention

Recap

Eviction is FIFO with an exception for first few tokens

58 of 73

StreamingLLM and H2O

  • Proposed method
    • Fix the cache size to K
    • Generate tokens i = 1, 2, …
      • Keep the first 4 tokens; manage the remaining K-4 slots with the H2O eviction rule

Recap

Eviction is H2O rule with an exception for first few tokens

59 of 73

SnapKV: Details

  • Observation window: the n-th window in the prompt
  • Prefix: first n-1 windows in the prompt
  • SumAttn[head h, prefix position j]: attention weights
    • for head h to position j
    • summed over all positions in observation window
      • (Wobs in paper)
  • p: compression ratio, (e.g., 0.9)
  • Caching strategy: evict the KV cache entries for the bottom p*|Prefix| positions (for each head)
    • Refinement: also keep positions near a top (1-p) position (i.e., max-pool SumAttn locally)

Recap

Eviction for a prompt is based on a score learned from the observation window for that prompt

60 of 73

Predictive KV Caching: FastGen

2024

Recap

61 of 73

Predictive KV Caching: FastGen

  • Caching policies explored in past work
    • Locality (Clocal)
      • hyperparam: rlocal = (#local tokens)/(cache size)
    • Heavy Hitter (Cfrequent)
      • hyperparam: rfrequent = (#HH tokens)/(cache size)
    • Special tokens (Cspecial)
      • Examples: start-of-sequence <s>, …
      • hyperparam: none
    • Punctuation (Cpunct)
      • Examples: period (.), colon (:), question mark (?), …
      • hyperparam: none
    • No eviction (Cfull)
    • Hybrid strategies: e.g., Cspecial + Clocal
      • union of the two caches

Prior analysis of BERT showed it often attends to separators

Also SepLLM showed separators + initial tokens are useful for KV-caching

Recap

62 of 73

Predictive KV Caching: FastGen

  • Which caching strategy to use?
    • Pick a policy for each head/layer based on the prefilling
    • Use that policy to initialize the KV cache, and keep using it during decoding

Recap

Eviction strategy for a prompt is learned based on the strategy’s effectiveness in prefilling

…and head identity and layer are also features (i.e., each can have a different strategy)

63 of 73

Recap: making a large model smaller

  • Distillation
    • briefly discussed
  • Quantization
  • Pruning
  • Key-Value Cache Management

Spectrum from simple to complex optimization (slide annotations):

  • Score by recency (FIFO)
    • + start token
    • + separator
  • H2O: score by running accumulated attention
  • SnapKV: adaptively pick specific KV positions based on the observation window in prefilling
  • FastGen: pick a strategy per head and layer based on prefilling

64 of 73

Recap: making a large model smaller

  • Quantization
  • Pruning
  • Key-Value Cache Management

Old-school ML in practice:

  1. examine the data
  2. select/engineer features
  3. pick a simple learning/optimization method that worked for that task
    • sometimes invent a new method

65 of 73

What’s next?

  • Quantization
  • Pruning
  • Key-Value Cache Management
  • Simple combinations:
    • it works to sequentially prune then quantize
    • it works to quantize a KV cache (sketch below)
  • Jointly learning to optimize together?
    • prune, quantize, and restrict attention so that KV cache is more effective?
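
As a small example of the "simple combinations" point, here is a hedged sketch of per-token int8 quantization of KV cache entries (an assumption-laden illustration, not any specific system's scheme); eviction policies like those above can then run over a cache that is several times smaller:

```python
import numpy as np

def quantize_kv(kv):
    """kv: [n_tokens, head_dim] fp32 keys or values; per-token symmetric int8."""
    scale = np.abs(kv).max(axis=1, keepdims=True) / 127.0 + 1e-8
    return np.round(kv / scale).astype(np.int8), scale.astype(np.float32)

def dequantize_kv(q, scale):
    return q.astype(np.float32) * scale      # done just before the attention matmul
```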

66 of 73

RETRIEVAL AUGMENTED LLMS: INTRO

67 of 73

Anthropic Claude 3.7

68 of 73

Gemini Advanced 2.0 Flash

69 of 73

Anthropic Claude 3.7

Mostly wrong

70 of 73

Gemini Advanced 2.0 Flash

71 of 73

Gemini Advanced 2.0 Flash

Irrelevant from 2017

https://www.cs.cmu.edu/~wcohen/ has all this information

copy of ICML 2008 info

from 2007

72 of 73

73 of 73

Questions

  • How to retrieve? Can we learn to retrieve better?
  • How to generate? Are there problems here specific to RAG?
  • Do we learn to retrieve jointly with learning to generate, or separately?

Is the source of information reliable?

What if documents contradict each other?

What if information needed is spread across many documents?