1 of 78

10-605 / 10-805

Machine Learning from Large Datasets

2 of 78

Announcements

  • Since Monday 11/3
    • Structured miniproject details are out
    • Unstructured miniproject also kicks off
      • Finalize teams by today
    • HW5 is out
  • Happening later
    • Writing session for HW4: Nov 14
    • Writing session for HW5: Nov 21

3 of 78

Outline and Relation to Prior Lectures

  • Compressing models
    • Distillation
    • Quantization
    • Pruning
    • Approximating Attention by Restricting the Key-Value Cache
      • Recap of Key-Value Cache
        • Prefilling vs Decoding
      • Origin story
        • Long-context Transformers with sparse attention
      • Techniques for pruning a key-value cache
        • Heavy hitters – H2O
        • Attention sinks – StreamingLLM
        • Predictive caching – SnapKV

4 of 78

KEY-VALUE CACHING FOR TRANSFORMERS

5 of 78

A quick recap of KV caching

6 of 78

Masked self-attention is used in the decoder

In masked self-attention, the query for the token at position i can only attend to the keys for tokens at positions 0, 1, …, i

This is done with a mask

Prediction for token at i is not affected by “future” tokens at i+1,i+2, …

Recap: masked self-attention
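
Concretely, masked self-attention fits in a few lines; below is a minimal single-head numpy sketch (the dimensions and the additive −inf mask are illustrative assumptions, not the course's reference implementation):

```python
import numpy as np

def masked_self_attention(Q, K, V):
    """Single-head causal self-attention: position i attends only to positions 0..i."""
    T, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                      # (T, T) raw attention scores
    future = np.triu(np.ones((T, T), dtype=bool), k=1) # True above the diagonal = "future" keys
    scores[future] = -np.inf                           # masked positions get zero weight after softmax
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                                 # (T, d) outputs

# Toy check: the output at position i is unchanged if we perturb "future" tokens.
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))
out = masked_self_attention(X, X, X)
X2 = X.copy(); X2[3:] = rng.normal(size=(2, 8))        # change the tokens at positions 3 and 4
out2 = masked_self_attention(X2, X2, X2)
assert np.allclose(out[:3], out2[:3])                  # first three outputs are identical
```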

7 of 78

GPT

Recap: masked self-attention is always used in decoder-only Transformers

8 of 78

Observation

[Diagram: the model inputs at decoding steps 3 and 4]

Third output depends only on the first two tokens

9 of 78

Key-Value Caching

[Diagram: the model inputs at decoding step 4, with keys and values computed at earlier steps shown in green boxes]

Third output depends only on the first two tokens

Fourth output depends only on the first three tokens

Stuff in green boxes has not changed since step 3

The keys and values can be cached and re-used in forward pass
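
A minimal numpy sketch of this idea, with made-up dimensions and projection matrices: at each step only the newest token's query, key, and value are computed, the new key/value pair is appended to the cache, and attention runs over the whole cache.

```python
import numpy as np

d = 8
rng = np.random.default_rng(0)
Wq, Wk, Wv = [rng.normal(size=(d, d)) for _ in range(3)]  # illustrative projection weights

def decode_step(x_new, K_cache, V_cache):
    """One decoding step: project only the newest token, reuse the cached K/V for all others."""
    q, k, v = x_new @ Wq, x_new @ Wk, x_new @ Wv
    K_cache = np.vstack([K_cache, k])       # the "green boxes": keys/values from earlier steps
    V_cache = np.vstack([V_cache, v])       # are never recomputed, only appended to
    scores = K_cache @ q / np.sqrt(d)       # attend over every cached position
    w = np.exp(scores - scores.max()); w /= w.sum()
    return w @ V_cache, K_cache, V_cache

K_cache, V_cache = np.zeros((0, d)), np.zeros((0, d))
for step in range(4):                       # steps 1..4 from the slides
    x = rng.normal(size=d)                  # embedding of the token fed in at this step
    out, K_cache, V_cache = decode_step(x, K_cache, V_cache)
print(K_cache.shape)                        # (4, 8): one cached key per token seen so far
```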

10 of 78

How important is KV caching?

  • Inference with an LLM has two phases
    • Prefilling
      • process the prompt and context
      • this can be very long!
    • Decoding
      • generate the response

11 of 78

How important is KV caching?

  • Inference with an LLM has two phases
    • Prefilling
      • process the prompt and context
      • this can be very long!
      • caching the system prompt alone is very helpful
    • Decoding
      • generate the response

12 of 78

How important is KV caching?

13 of 78

Prefilling: system prompt for Claude 3.7

14–18 of 78

Prefilling: system prompt for Claude 3.7 (screenshots continue)

… about 5 more pages of the system prompt …

19 of 78

How important is KV caching?

  • Inference with an LLM has two phases
    • Prefilling
      • process the prompt and context
      • this can be very long!
      • caching the system prompt alone is very helpful
      • caching repeated parts of a user’s input can save up to 90% of inference costs
    • Decoding
      • generate the response

20 of 78

A “Chain of Thought” Prompt

one of three “in context” examples

21 of 78

A “Chain of Thought” Prompt

22 of 78

A “Chain of Thought” Prompt

23 of 78

How important is KV caching?

  • Inference with an LLM has two phases
    • Prefilling
      • process the prompt and context
      • this can be very long!
      • caching the system prompt alone is very helpful
      • caching repeated parts of a user’s input can save up to 90% of inference costs
        • this requires knowing the user is likely to repeat parts of their queries, usually via API call params
    • Decoding
      • generate the response

24 of 78

How important is KV caching?

  • Inference with an LLM has two phases
    • Prefilling
      • process the prompt and context
      • this can be very long!
      • caching the system prompt alone is very helpful
      • caching repeated parts of a user’s input can save up to 90% of inference costs
        • this requires knowing the user is likely to repeat parts of their queries, usually via API call params
    • Decoding
      • generate the response

Time to first token

25 of 78

Prefilling

All the inputs needed in pre-filling are immediately available.

There are opportunities for parallelism

26 of 78

How important is KV caching?

  • Inference with an LLM has two phases
    • Prefilling
      • process the prompt and context
      • this can be very long!
      • caching the system prompt alone is very helpful
      • caching repeated parts of a user’s input can save up to 90% of inference costs
    • Decoding
      • generate the response, word by word
      • this is sequential and uses the “prefilled” cache

27 of 78

GPT

Recap: generation is auto-regressive

28 of 78

How important is KV caching?

  • Inference with an LLM has two phases
    • Prefilling
      • process the prompt and context
      • some ways to parallelize
    • Decoding
      • generate the response, word by word
      • this is sequential and uses the “prefilled” cache
      • for each decoding step you generate new queries and “search” the KV cache (with a dot product)
  • Decoding time: O(N_output tokens × max KV-cache size)
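
A toy cost model (illustrative only, counting query-key dot products rather than real FLOPs) shows the asymmetry between the two phases:

```python
# Toy cost model: count query-key dot products in each phase.
def generation_cost(n_prompt, n_output):
    prefill = n_prompt * n_prompt       # every prompt position attends to every other position;
                                        # all inputs are available at once, so this parallelizes well
    decode, cache = 0, n_prompt
    for _ in range(n_output):           # strictly sequential: one new token per step
        cache += 1
        decode += cache                 # the new query is compared against the whole KV cache
    return prefill, decode

print(generation_cost(n_prompt=4096, n_output=256))
# ~16.8M dot products in prefill (done once, in parallel) vs ~1.1M in decode (one step at a time)
```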

29 of 78

How important is KV caching?

  • Example: Llama 7B
    • 7B model weights size = 7B * sizeof(float16) 🡺 14 GB memory
    • KV cache size is the product of
      • batch size = 1
      • sequence length = 4096
      • number of layers = 32
      • attention hidden size = 4096
      • sizeof(float16) = 2 bytes
      • 2 (for keys + values)
      • 🡺 2 GB
    • KV cache size grows linearly with sequence length
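
The 2 GB figure is just the product of the factors above; a quick sketch of the arithmetic (assuming fp16, i.e. 2 bytes per element):

```python
batch_size   = 1
seq_len      = 4096
n_layers     = 32
hidden_size  = 4096      # attention hidden size per layer
bytes_per_el = 2         # fp16
kv           = 2         # one tensor for keys, one for values

cache_bytes = batch_size * seq_len * n_layers * hidden_size * bytes_per_el * kv
print(cache_bytes / 2**30, "GiB")   # 2.0 GiB; grows linearly with seq_len (and batch_size)
```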

30 of 78

Architecture impacts KV caching

  • Grouped query and multi-query attention!

Recall: we can re-use keys and values in “groups”
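
A sketch of why this matters for the cache: keys and values are stored once per KV head, so the cache scales with the number of KV heads rather than query heads. The head counts below are illustrative, not the exact configuration of any particular model.

```python
def kv_cache_gib(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_el=2, batch=1):
    # keys + values are stored once per KV head; query heads in a group share them
    return batch * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_el * 2 / 2**30

# Multi-head attention: every query head has its own KV head (32 x 128 = 4096 hidden size)
print(kv_cache_gib(4096, 32, n_kv_heads=32, head_dim=128))   # 2.0 GiB
# Grouped-query attention: 32 query heads share 8 KV heads
print(kv_cache_gib(4096, 32, n_kv_heads=8, head_dim=128))    # 0.5 GiB
# Multi-query attention: all query heads share a single KV head
print(kv_cache_gib(4096, 32, n_kv_heads=1, head_dim=128))    # 0.0625 GiB
```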

31 of 78

SPEEDUPS BASED ON KEY-VALUE CACHING

32 of 78

Why is this a lecture?

  • Attention is all you need….
    • …but you don’t need all the attention
  • Attention that’s quadratic in sequence length doesn’t scale to long sequences
    • So we need to approximate somehow…
    • What attention do we really need?
    • What can we ignore?
  • In decoder-only Transformers manipulating the KV cache is an important way of approximating attention

33 of 78

Approximating quadratic attention with KV caching

  • In key-value caching we restrict attention after training by evicting from the cache
  • This is a crucial method for long inputs
    • or outputs, e.g. reasoning models (O1,…)

  • Big question: what is the eviction policy?
    • It needs to be (1) fast and (2) robust

34 of 78

Outline and Relation to Prior Lectures

  • Compressing models
    • Distillation
    • Quantization
    • Pruning
    • Approximating Attention by Restricting the Key-Value Cache
      • Recap of Key-Value Cache
        • Prefilling vs Decoding
      • Origin story
        • Long-context Transformers with sparse attention
      • Techniques for pruning a key-value cache
        • Heavy hitters – H2O
        • Attention sinks – StreamingLLM
        • Predictive caching – SnapKV

35 of 78

Sparse Attention in Transformers

36 of 78

Approximating quadratic attention

  • An early paper/system: ETC (2020)

37 of 78

Approximating quadratic attention

  • Key idea: split tokens into local and global
  • Attention is global-local, local-global, and local-local with constraints
  • Important: Local-local attention is not a mask but a “gather”

38 of 78

Approximating quadratic attention

  • Cost: O(N_global² + N_global·N_local + k·N_local) vs O(N²)
  • ETC was very configurable
    • because in 2020 nobody knew how to restrict attention (except locally with ‘sliding windows’)

39 of 78

Approximating quadratic attention

  • Another early paper/system: BigBird (2020)

40 of 78

Approximating quadratic attention

  • Another early paper/system: BigBird (2020)
  • Three kinds of attention: local (aka sliding window); global; random long-range

shortest paths between nodes in a random Erdős–Rényi graph are about log n / log(np)
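
A minimal sketch of a BigBird-style attention mask that combines the three kinds of attention; the window size, number of global tokens, and number of random links per row are made-up hyperparameters.

```python
import numpy as np

def bigbird_mask(n, window=3, n_global=2, n_random=2, seed=0):
    """Boolean (n, n) mask: entry [i, j] is True if query i may attend to key j."""
    rng = np.random.default_rng(seed)
    mask = np.zeros((n, n), dtype=bool)
    idx = np.arange(n)
    # 1. local / sliding-window attention
    mask |= np.abs(idx[:, None] - idx[None, :]) <= window
    # 2. global tokens attend to everything and are attended to by everything
    mask[:n_global, :] = True
    mask[:, :n_global] = True
    # 3. a few random long-range links per query (the Erdos-Renyi-style shortcuts)
    for i in range(n):
        mask[i, rng.choice(n, size=n_random, replace=False)] = True
    return mask

m = bigbird_mask(64)
print(int(m.sum()), "allowed pairs out of", m.size)   # far fewer than the dense 64*64 = 4096
```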

41 of 78

Approximating quadratic attention

  • Another early paper/system: Reformer (2020)

42 of 78

Approximating quadratic attention

  • Reformer used LSH to bucket token representations by similarity and then did attention mostly inside buckets
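
A sketch of the LSH bucketing idea using random-hyperplane hashing; this is a simplified illustration (no multi-round hashing, chunking, or sorting tricks), not the actual Reformer implementation.

```python
import numpy as np

def lsh_buckets(X, n_hyperplanes=4, seed=0):
    """Hash each row of X to a bucket id with random hyperplanes: similar vectors collide often."""
    rng = np.random.default_rng(seed)
    H = rng.normal(size=(X.shape[1], n_hyperplanes))
    bits = (X @ H) > 0                                         # sign pattern = locality-sensitive hash
    return bits.astype(int) @ (1 << np.arange(n_hyperplanes))  # pack the bits into a bucket id

rng = np.random.default_rng(1)
X = rng.normal(size=(16, 32))        # 16 token representations
buckets = lsh_buckets(X)
for b in np.unique(buckets):
    # attention would be computed only among tokens that share a bucket, so the cost is
    # roughly the sum of (bucket size)^2 instead of 16^2
    print("bucket", int(b), "->", np.where(buckets == b)[0].tolist())
```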

43 of 78

Outline and Relation to Prior Lectures

  • Compressing models
    • Distillation
    • Quantization
    • Pruning
    • Approximating Attention by Restricting the Key-Value Cache
      • Recap of Key-Value Cache
        • Prefilling vs Decoding
      • Origin story
        • Long-context Transformers with sparse attention
      • Techniques for pruning a key-value cache
        • Heavy hitters – H2O
        • Attention sinks – StreamingLLM
        • Predictive caching – SnapKV

44 of 78

KEY-VALUE CACHING FOR TRANSFORMERS: TECHNIQUES

45 of 78

Approximating quadratic attention

  • ETC, BigBird, Reformer, … were pretrained with restricted attention
  • In key-value caching we restrict attention after training by evicting from the cache
  • This is a crucial method for long inputs
    • or outputs, e.g. reasoning models (O1,…)

  • Big question: what is the eviction policy?

46 of 78

Finding “Heavy Hitters” in KV Caches

2023

47 of 78

CM Sketch Structure

  • Each string is mapped to one bucket per row
  • Estimate A[j] by taking min_k { CM[k, h_k(j)] }
  • Errors are always over-estimates
  • Analysis: d = log 1/𝛿, w = 2/𝜺 🡺 error is usually less than 𝜺||A||₁ (i.e., with prob > 1-𝛿)

A Quick Intro to Data Stream Algorithmics – CS262

[Figure: CM sketch array with d = log 1/𝛿 rows and w = 2/𝜺 columns; an update ⟨s, +c⟩ adds c to one counter per row, at columns h1(s), …, hd(s)]

from: Minos Garofalakis

RECAP
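
A compact sketch of the structure described above (the per-row hash here is just a salted Python hash, purely for illustration):

```python
import numpy as np

class CountMinSketch:
    def __init__(self, eps=0.01, delta=0.01):
        self.w = int(np.ceil(2 / eps))             # width  w = 2 / epsilon
        self.d = int(np.ceil(np.log(1 / delta)))   # depth  d = log 1 / delta
        self.CM = np.zeros((self.d, self.w), dtype=np.int64)

    def _cols(self, s):
        # one bucket per row: hash the string with a different per-row salt (illustration only)
        return [hash((k, s)) % self.w for k in range(self.d)]

    def update(self, s, c=1):                      # <s, +c>: add c to one counter in every row
        for k, col in enumerate(self._cols(s)):
            self.CM[k, col] += c

    def query(self, s):                            # point estimate: min over rows (an over-estimate)
        return min(self.CM[k, col] for k, col in enumerate(self._cols(s)))

cm = CountMinSketch()
for word in ["the"] * 100 + ["of"] * 40 + ["llama"] * 3:
    cm.update(word)
print(cm.query("the"), cm.query("llama"))          # close to 100 and 3; never under-estimates
```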

48 of 78

CM Sketch Guarantees

  • [Cormode, Muthukrishnan ‘04] CM sketch guarantees approximation error on point queries less than 𝜺||A||1 in space O(1/𝜺 log 1/𝛿)
  • CM sketches are also accurate for skewed values---i.e., only a few entries s with large A[s]

A Quick Intro to Data Stream Algorithmics – CS262

from: Minos Garofalakis

RECAP

49 of 78

CM Sketch Guarantees

  • CM sketches are also accurate for skewed values---i.e., only a few entries s with large A[s]

Application: finding “Heavy Hitters” in streaming data

A plot of the frequency of each word as a function of its frequency rank for two English language texts: Culpeper's Complete Herbal (1652) and H. G. Wells's The War of the Worlds (1898) in a log-log scale.

Top word counts: 62690 the · 36043 of · 27952 and · 25725 to · 22000 a · 19581 in · 10328 that · 9969 is · 9770 was · 8833 for

RECAP

50 of 78

Finding “Heavy Hitters” in KV Caches

  • New observations
    • Attention is naturally sparse
    • I.e., most LLMs give small weight to most of the KV cache

sparsity: fraction with < 1% of max attention score

51 of 78

Finding “Heavy Hitters” in KV Caches

  • New observations
    • Accumulated attention weights are distributed like a power law
    • So: a small number of “heavy hitter” tokens seem to be important

red scatter is total attention summed over all decoding steps

x-axis is “co-occurrence” frequency (same token appears ≥ 2x)

52 of 78

Finding “Heavy Hitters” in KV Caches

  • New observations
    • Masking heavy hitter tokens degrades performance

  • So: make sure you keep them in the KV cache!

53 of 78

Finding “Heavy Hitters” in KV Caches

  • Proposed method
    • Fix the cache size to K
    • Generate tokens i = 1, 2, …
      • Track the score (cumulative attention) of each token in the cache
        • “local H2 statistic” – cumulative attention up to position i
      • When the cache is full, overwrite the token with the lowest local H2 score
        • (Attention to recent-past tokens keeps them from being evicted too quickly)
      • This is the “eviction policy”
    • The local H2 statistic is close to the full oracle H2 statistic that looks at future tokens

    • … and combine this with old-school “sliding window” attention (see the sketch below)
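
A sketch of the eviction policy as described on this slide: keep a fixed budget K, accumulate each cached token's attention mass (the local H2 statistic), protect a small recent window, and evict the lowest-scoring non-recent entry when the cache is full. The per-head details and the random stand-in "attention weights" below are simplifying assumptions.

```python
import numpy as np

def h2o_evict(h2_scores, recent_window=4):
    """Index of the cache entry to evict: lowest cumulative attention among non-recent tokens."""
    candidates = np.arange(len(h2_scores) - recent_window)   # the most recent tokens are protected
    return int(candidates[np.argmin(h2_scores[candidates])])

K = 8                                    # fixed cache budget
positions = []                           # which token positions are currently cached
h2 = np.zeros(0)                         # cumulative attention received by each cached token

rng = np.random.default_rng(0)
for i in range(20):                      # generate tokens i = 0, 1, ...
    if len(positions) == K:              # cache full: apply the eviction policy
        j = h2o_evict(h2)
        positions.pop(j)
        h2 = np.delete(h2, j)
    positions.append(i)
    h2 = np.append(h2, 0.0)
    attn = rng.dirichlet(np.ones(len(positions)))   # stand-in for this step's attention weights
    h2 += attn                           # update the local H2 statistic of every cached token
print(positions)                         # a few "heavy hitters" plus the most recent tokens
```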

54 of 78

Finding “Heavy Hitters” in KV Caches

[Figure: heavy-hitter token scores computed with the streaming (local) statistic vs. with full counts]

55 of 78

Outline and Relation to Prior Lectures

  • Compressing models
    • Distillation
    • Quantization
    • Pruning
    • Approximating Attention by Restricting the Key-Value Cache
      • Recap of Key-Value Cache
        • Prefilling vs Decoding
      • Origin story
        • Long-context Transformers with sparse attention
      • Techniques for pruning a key-value cache
        • Heavy hitters – H2O
        • Attention sinks – StreamingLLM
        • Predictive caching – SnapKV

56 of 78

Other ways of predicting attention

2024

57 of 78

Other ways of predicting attention

58 of 78

Other ways of predicting attention

59 of 78

Other ways of predicting attention

Why are initial positions getting so much attention?

60 of 78

Other ways of predicting attention

Attention sinks

  • The Softmax operation requires attention scores to sum up to one … when the current query does not have a strong match … the model still needs to allocate attention somewhere
  • Initial tokens are visible to almost all subsequent tokens … making them readily trained to serve as attention sinks.

61 of 78

StreamingLLM

  • Proposed method
    • Fix the cache size to K
    • Generate tokens i = 1, 2, …
      • Keep the first 4 tokens and the most recent K-4

Note: I’m skipping some details involving relative positional encoding

red: StreamingLLM

orange: sliding window attention

blue: dense attention
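
The cache policy itself fits in a few lines; a sketch (cache contents only, ignoring the relative-positional-encoding details skipped above):

```python
def streaming_llm_cache(n_generated, K, n_sink=4):
    """Positions kept in the cache after processing n_generated tokens:
    the first n_sink "attention sink" tokens plus the most recent K - n_sink tokens."""
    if n_generated <= K:
        return list(range(n_generated))
    sinks = list(range(n_sink))
    recent = list(range(n_generated - (K - n_sink), n_generated))
    return sinks + recent

print(streaming_llm_cache(n_generated=20, K=8))   # [0, 1, 2, 3, 16, 17, 18, 19]
```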

62 of 78

StreamingLLM vs H2O

  • Proposed method
    • Fix the cache size to K
    • Generate tokens i = 1, 2, …
      • Keep the first 4 tokens and the most recent K-4

63 of 78

Outline and Relation to Prior Lectures

  • Compressing models
    • Distillation
    • Quantization
    • Pruning
    • Approximating Attention by Restricting the Key-Value Cache
      • Recap of Key-Value Cache
        • Prefilling vs Decoding
      • Origin story
        • Long-context Transformers with sparse attention
      • Techniques for pruning a key-value cache
        • Heavy hitters – H2O
        • Attention sinks – StreamingLLM
        • Predictive caching – SnapKV

64 of 78

Predictive KV Caching: SnapKV

2024

65 of 78

Predictive KV Caching: SnapKV

  • Data: UltraChat
    • 1.4M sequences; selected dialog turns with prompts > 3k tokens and responses > 512 tokens
  • New observation
    • The attention pattern for prompt tokens is consistent throughout decoding
    • The attention pattern for prompt tokens is predictable from LLM behavior before decoding starts

66 of 78

Predictive KV Caching: SnapKV

  • Example: Attention overlap between decoding and 128-token windows from the prompt
  • Each line is a layer
  • There is high overlap with the last window in the prompt

67 of 78

Predictive KV Caching: SnapKV

  • Example: Attention overlap between four 128-token windows in the decoded response and the last 128-token window in the prompt

68 of 78

SnapKV: Details

  • Observation window: the n-th window in the prompt
  • Prefix: first n-1 windows in the prompt
  • SumAttn[head h, prefix position j]: attention weight
    • for head h to position j
    • summed over all positions in the observation window
      • (W_obs in the paper)
  • p: compression ratio (e.g., 0.9)
  • Caching strategy: evict the KV cache entries for the bottom p·|Prefix| positions (for each head)
    • Refinement: also keep positions near a top (1-p) position (i.e., max-pool SumAttn locally)
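
A sketch of the selection step for one head, following the definitions above: sum the observation window's attention over prefix positions, max-pool locally, and keep the top (1 − p) fraction of prefix positions. The window length, pooling width, and random stand-in attention weights are illustrative assumptions.

```python
import numpy as np

def snapkv_keep(attn_obs_to_prefix, p=0.9, pool=3):
    """attn_obs_to_prefix: (|W_obs|, |prefix|) attention weights for one head.
    Returns the prefix positions whose KV entries are kept (the rest are evicted)."""
    sum_attn = attn_obs_to_prefix.sum(axis=0)            # SumAttn[j] over the observation window
    # refinement: local max-pool, so neighbours of a heavy position also score highly
    padded = np.pad(sum_attn, pool // 2, constant_values=-np.inf)
    pooled = np.array([padded[j:j + pool].max() for j in range(len(sum_attn))])
    n_keep = int(np.ceil((1 - p) * len(sum_attn)))       # keep the top (1 - p) fraction
    return np.sort(np.argsort(pooled)[-n_keep:])

rng = np.random.default_rng(0)
prefix_len, obs_len = 200, 32
attn = rng.dirichlet(np.ones(prefix_len), size=obs_len)  # stand-in attention weights
print(snapkv_keep(attn))                                 # ~20 surviving prefix positions
```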

69 of 78

SnapKV: Experiment

  • “Needle-in-a-haystack” problem:
    • a long document with a sentence inserted at a random position (e.g., a recommendation for a coffeehouse in Pittsburgh inserted in the middle of The Lord of the Rings)
    • a question answerable by the inserted sentence

70 of 78

Needle in a Haystack: Background

71 of 78

Needle in a Haystack: SnapKV results

72 of 78

Outline

  • Compressing models
    • Distillation
    • Quantization
    • Pruning
      • Unstructured
      • Semi-Structured (e.g. 2:4 Pruning)
      • Structured
  • Key-value caching
    • Motivation and importance
    • Origin story
      • Long-context Transformers
    • Techniques
      • Heavy hitters – H2O
      • Attention sinks – StreamingLLM
      • Predictive caching – SnapKV, FastGen

73 of 78

Predictive KV Caching: FastGen

2024

74 of 78

Predictive KV Caching: FastGen

  • Caching policies explored in past work
    • Locality (C_local)
      • hyperparam: r_local = (#local tokens) / (cache size)
    • Heavy Hitter (C_frequent)
      • hyperparam: r_frequent = (#HH tokens) / (cache size)
    • Special tokens (C_special)
      • Examples: start-of-sequence <s>, …
      • hyperparam: none
    • Punctuation (C_punct) [new]
      • Examples: period (.), colon (:), question mark (?), …
      • hyperparam: none
    • No eviction (C_full)
    • Hybrid strategies: e.g., C_special + C_local
      • union of the two caches

Prior analysis of BERT showed it often attends to separators

Prior work (SepLLM) showed separators + initial tokens are useful for KV-caching

75 of 78

Predictive KV Caching: FastGen

Which works best? It depends….

Different heads at layer 20

76 of 78

Predictive KV Caching: FastGen

Which works best? It depends….

Same head, different layers

77 of 78

Predictive KV Caching: FastGen

  • Which caching strategy to use?
    • Pick a policy for each head/layer based on the prefilling
    • Use that policy to initialize the KV cache and in decoding (see the sketch below)
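
A sketch of that idea under simplifying assumptions: each candidate policy maps a head's prefill attention profile to a set of kept positions, and for each head we pick the cheapest policy that still recovers most of the attention mass. The policy ordering, recovery threshold, and synthetic attention profiles below are made up for illustration.

```python
import numpy as np

def kept_positions(policy, attn, special, punct, r_local=0.3, r_freq=0.3):
    """attn: (T,) attention mass received by each prompt position during prefilling (one head)."""
    T = len(attn)
    if policy == "special":  return set(special)
    if policy == "punct":    return set(special) | set(punct)
    if policy == "local":    return set(range(T - int(r_local * T), T))
    if policy == "frequent": return set(np.argsort(attn)[-int(r_freq * T):].tolist())
    if policy == "hybrid":   return (kept_positions("special", attn, special, punct) |
                                     kept_positions("local", attn, special, punct))
    return set(range(T))                                    # "full": no eviction

def pick_policy(attn, special, punct, recovery=0.95):
    # policies ordered cheapest-first: take the first that keeps >= `recovery` of the attention mass
    for policy in ["special", "punct", "local", "hybrid", "frequent", "full"]:
        kept = kept_positions(policy, attn, special, punct)
        if attn[list(kept)].sum() >= recovery * attn.sum():
            return policy
    return "full"

rng = np.random.default_rng(0)
T = 64
special, punct = [0], [15, 31, 47]           # e.g. <s> and a few sentence-final periods
for head in range(4):                        # different heads end up with different policies
    attn = rng.dirichlet(np.ones(T) * 0.1)   # a skewed attention profile for this head
    attn[special] += rng.uniform(0, 2)       # some heads dump most attention on <s>
    print("head", head, "->", pick_policy(attn, special, punct))
```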

78 of 78

Predictive KV Caching: FastGen Results