1 of 72

10-605 / 10-805

Machine Learning from Large Datasets

2 of 72

A few reminders

  • Cheat sheet must be handwritten
    • not written on an iPad and printed
    • not photographically reduced
  • Wednesday's exam is not cumulative

3 of 72

DEEP LEARNING

4 of 72

The Evolution of Transformers

5 of 72

Topics

  • Why are deep MLPs interesting?
    • They are more expressive
  • Why are they hard to learn?
    • Vanishing gradients: gradients can get exponentially smaller with depth
    • Activation values need to stay in a narrow range or else sigmoid units get “stuck”

6 of 72

AI Stats 2010

Histogram of gradients in a 5-layer network for an artificial image recognition task (layers shown from input to output)

7 of 72

AI Stats 2010

(figure: layers shown from input to output)

8 of 72

Topics

  • What tricks make learning easier/better ?
    • CNNs

    • CNNs using multiple parallel convolutions are more expressive

9 of 72

Topics

  • What tricks make learning easier/better ?
    • softmax loss
    • ReLU and variants instead of sigmoids
    • better heuristics for weight initialization
    • residual/skip connections
    • “highway” networks like LSTMs
      • limited number of “deep” paths which pass through a limited number of kinds of nodes

10 of 72

Topics

  • How can we learn language?
    • LSTMs, GRUs, etc
    • word embeddings
    • using attention to
      • “classify” parts a1, …, aN of a network (e.g. embeddings of words, or words-in-context) as relevant to some second representation B
      • “import” the parts of a1, …, aN considered relevant into B (attention)

11 of 72

Topics

  • Utility of sequence-to-sequence learning
    • and encoder-decoder models

(figure: sequence classification, translation, named entity recognition, image captioning; seq2seq with an Encoder and a Decoder)

12 of 72

Transformers

13 of 72

BERT – Encoder-Only Transformers

2019


14 of 72

Decoder-Only Transformers

GPT-1

15 of 72

Tokenization in BERT / GPT-2

  • Idea: reduce the number of tokens by finding meaningful subwords
  • Simplest version of this is Byte Pair Encoding (BPE)
  • BERT used a variant called wordpieces

16 of 72

Tokenization for Vision Transformers

2D positional encoding is not better than 1D positional encoding

17 of 72

LoRA (low rank adaptation) finetuning

Key idea: don’t learn a full d x d matrix of weights, instead learn a low-rank approximation to this matrix!

(figure: full finetuning vs. LoRA finetuning)
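Below is a minimal sketch of the LoRA idea in PyTorch (my own illustration, not the paper's code): the base weight is frozen and only two small matrices A (r x d_in) and B (d_out x r) are trained, so the update W0 + (alpha/r)·BA stays low-rank.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Sketch of a LoRA-adapted linear layer: y = W0 x + (alpha/r) * B A x,
    where W0 is frozen and only A (r x d_in) and B (d_out x r) are trained."""
    def __init__(self, d_in, d_out, r=8, alpha=16):
        super().__init__()
        self.base = nn.Linear(d_in, d_out, bias=False)
        self.base.weight.requires_grad_(False)        # freeze the (pretrained) base weights
        self.A = nn.Parameter(torch.randn(r, d_in) * 0.01)
        self.B = nn.Parameter(torch.zeros(d_out, r))  # zero init: update starts at zero
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

# Trainable parameters drop from d_in*d_out to r*(d_in + d_out).
layer = LoRALinear(768, 768, r=8)
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # 12288 vs 589824
```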

18 of 72

HYPERPARAMETER TUNING

19 of 72

Hyperband and early stopping

20 of 72

Hyperparameter Selection

Adaptive selection loop:

  • Build a model of f from past experiments
    • “Surrogate model”
  • Pick the next experiment to do based on the surrogate model
    • Experiment = learning task
  • Repeat…

A simple case:

  • a single discrete hyperparameter with k possible values
  • but f is random (e.g., there’s a seed for the ML algorithm)

Goal: find the best of the k values quickly

Explore the space intelligently

21 of 72

Theory: Bandit Problems in ML

  • Some common sampling strategies
    • 𝜺-greedy: pick a random arm with probability 𝜺 and the best arm with probability 1 − 𝜺
    • Upper confidence bound (UCB): compute confidence intervals on the expected loss of each arm, and pick the arm with the best UCB (both rules are sketched below)
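Here is a minimal sketch of the two rules (illustrative only; it assumes higher empirical mean = better arm and uses the standard sqrt(2 ln t / n) exploration bonus for UCB):

```python
import math, random

def eps_greedy_pick(means, eps=0.1):
    """Pick a random arm with probability eps, otherwise the arm with the best empirical mean."""
    if random.random() < eps:
        return random.randrange(len(means))
    return max(range(len(means)), key=lambda a: means[a])

def ucb_pick(means, counts, t):
    """Pick the arm with the highest upper confidence bound after t total pulls."""
    def ucb(a):
        if counts[a] == 0:                       # force one pull of every arm first
            return float("inf")
        return means[a] + math.sqrt(2.0 * math.log(t) / counts[a])
    return max(range(len(means)), key=ucb)
```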

22 of 72

Gaussian Process Regression

Mapping a hyperparameter configuration x to learner performance f(x) is a regression task

Challenges:

  • Labels are expensive
  • For UCB we’d like to measure uncertainty in the regression function

Gaussian Process Regression is a good match for this

23 of 72

Hyperparameter Selection: Early Stopping

What can we prove about this?

What do we need to assume?

In general we don’t know where a learning curve will end up…

In practice there are often big jumps in discrete eval metrics like error rate, win ratios, …

The losses you train and test on are often fairly smooth … and sometimes there are models for them

24 of 72

Successive Halving Method

If the number of configurations n is large, we may get very noisy results early on and miss the best

If n is too small then we will waste resources on bad arms.

25 of 72

Hyperband

Hyperband runs several rounds of SuccessiveHalving with very different sizes n
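A minimal sketch of one SuccessiveHalving bracket, assuming a hypothetical train_and_eval(config, budget) that returns a validation loss (lower is better); Hyperband would run several such brackets with very different initial n and budget and keep the overall best:

```python
def successive_halving(configs, train_and_eval, min_budget=1, eta=3):
    """Repeatedly train all surviving configs for a growing budget,
    then keep only the top 1/eta of them."""
    budget = min_budget
    survivors = list(configs)
    while len(survivors) > 1:
        scored = [(train_and_eval(c, budget), c) for c in survivors]
        scored.sort(key=lambda x: x[0])                  # lower loss = better
        survivors = [c for _, c in scored[: max(1, len(scored) // eta)]]
        budget *= eta                                    # give survivors more resources
    return survivors[0]
```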

26 of 72

EMBARRASSINGLY PARALLEL LEARNING METHODS

interleaved with model compression / distillation

27 of 72

Outline

  • Review of extremely parallel training for ML
    • These are special cases of distributed optimization where communication & synchronization are minimal
  • Some very parallel training methods for LLMs
    • Branch-Train-Merge (BTM)
    • Branch-Train-Mix (BTX)
    • FlexOLMO
  • Some very parallel approaches to multi-task learning
    • Task vectors

28 of 72

BTM: Key ideas

ELM = expert LM; ELMForest = group of ELMs 𝜃1, 𝜃2, …, 𝜃k

Each ELM is branched from a small GPT 𝜃, and the ELMs are later merged with parameter averaging
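The merge step is just element-wise parameter averaging; a minimal sketch, assuming all the ELMs share the seed model's architecture (and therefore the same state-dict keys):

```python
def average_state_dicts(state_dicts, weights=None):
    """Element-wise (weighted) average of k expert LM state dicts
    that all share the same architecture/keys."""
    k = len(state_dicts)
    weights = weights or [1.0 / k] * k
    merged = {}
    for key in state_dicts[0]:
        merged[key] = sum(w * sd[key].float() for w, sd in zip(weights, state_dicts))
    return merged

# merged = average_state_dicts([elm1.state_dict(), elm2.state_dict(), ...])
# seed_model.load_state_dict(merged)
```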

29 of 72

Background: Expert Parallelism

This can be parallelized!

  • Each processor hosts a subset of the experts
  • All-to-all communication before and after the expert layer

Often need to train to “balance” the experts with an extra loss term

30 of 72

2024

Branch and Train as in Branch-Train-Merge

Mix the expert LMs by

  • making each FFN an expert in the corresponding MoE FFN layer
  • parameter-averaging all other parameters
  • finetuning everything (note that the gates have new parameters)

31 of 72

BTX results

32 of 72

FlexOLMO

Multiple local datasets Di; one public dataset Dpub; one model Mpub trained on Dpub

For each Di

  • Freeze Mpub and copy FFNs twice to create Mi
  • Train Mi on Di, freezing everything except one copy of the FFNs and the router

33 of 72

FlexOLMO

Multiple local datasets Di; one public dataset Dpub; one model Mpub trained on Dpub

For each Di

  • Freeze Mpub and copy FFNs twice to create Mi
  • Train Mi on Di, freezing everything except one copy of the FFNs and the router

Wr is the matrix of “router embeddings”; ri is initialized with off-the-shelf embeddings from Di and then trained pairwise with Mpub

(figure: mixture output)

Optionally: create local datasets D’i where each is extracted from Mpub and similar to Di

(These can be small)

Sample uniformly from the “proxy datasets” and Dpub to fine-tune Wr

34 of 72

FlexOLMO

35 of 72

 

36 of 72

Experiments with task analogies

  • Objective: sentiment analysis on reviews from Yelp
  • Data available
    • unlabeled reviews from Yelp
    • sentiment data from Amazon reviews
    • unlabeled Amazon reviews
  • Approach: build task vectors for
    • 𝜏YelpLM: language modeling on Yelp
    • 𝜏AmazonLM: language modeling on Amazon
    • 𝜏AmazonSentiment: sentiment prediction on Amazon

    • Finally use

λ1·𝜏AmazonSentiment + (λ2·𝜏YelpLM − λ3·𝜏AmazonLM)
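A minimal sketch of the task-vector arithmetic (illustrative, not the paper's code): a task vector is the difference between fine-tuned and pretrained weights, and the analogy above is just a weighted sum of such vectors added back to the pretrained model.

```python
def task_vector(finetuned_sd, pretrained_sd):
    """tau = theta_finetuned - theta_pretrained, per parameter tensor (PyTorch state dicts)."""
    return {k: finetuned_sd[k] - pretrained_sd[k] for k in pretrained_sd}

def apply_task_vectors(pretrained_sd, taus_and_lambdas):
    """theta_new = theta_pretrained + sum_i lambda_i * tau_i."""
    new_sd = {k: v.clone() for k, v in pretrained_sd.items()}
    for tau, lam in taus_and_lambdas:
        for k in new_sd:
            new_sd[k] += lam * tau[k]
    return new_sd

# Task analogy for Yelp sentiment (lambdas tuned on a dev set):
# theta = theta_pre + l1*tau_AmazonSentiment + l2*tau_YelpLM - l3*tau_AmazonLM
```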

37 of 72

Experiments with task analogies

Collect 200 images of kings, queens, women, men.

Fine-tune CLIP models on each category, using 1000 ImageNet classes as negative examples

(class name “something”)

38 of 72

Experiment: adding different task vectors

  • weighted sum of task vectors for 2 vision tasks
  • show accuracy normalized by FT accuracy
      • and compared to zero-shot accuracy

Note the training for the subtasks is embarrassingly parallel

39 of 72

MAKING TRANSFORMER MODELS MORE EFFICIENT

40 of 72

Model compression

  • Distillation
  • Quantization
  • Pruning
  • Key-Value Cache Management
    • …which is also pruning with eviction

41 of 72

Distillation

Typically we distill a model on lots and lots of data labeled only by the teacher (the transfer set)

The simple case: teacher predictions are “hard” labels (e.g., training a generative LLM).

  1. Easy to add cross-entropy loss to a set of gold labels.
  2. Moderate temperatures seem to require less transfer data for same accuracy.
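For reference, a minimal sketch of a generic distillation loss with temperature, optionally mixed with cross-entropy on gold labels; this is the standard recipe, not necessarily the exact loss from any one paper in the lecture:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, gold_labels=None,
                      T=2.0, alpha=0.5):
    """KL divergence between temperature-softened teacher and student distributions,
    optionally mixed with cross-entropy against gold labels."""
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)                     # rescale so gradients match the hard-label loss
    if gold_labels is None:
        return soft
    hard = F.cross_entropy(student_logits, gold_labels)
    return alpha * soft + (1 - alpha) * hard
```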

42 of 72

Recap: making a large model smaller

  • Reviewed a lot of this on Nov 10
  • Quantization
  • Pruning
  • Key-Value Cache Management
    • …which is also pruning

43 of 72

Recap: making a large model smaller

  • Quantization
  • Pruning
  • Key-Value Cache Management
    • …which is also pruning

A spectrum from simple quantization rules to complex rules with learning:

  • “32 bits everywhere”
  • “16 bits except for bias”
  • int8() except for outliers, defined as ….
  • jointly optimize codebook/codes and where to quantize while doing gradient updates on weights
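The simplest end of that spectrum can be sketched in a few lines: symmetric absmax int8 quantization of a weight tensor (outlier handling, codebooks, and quantization-aware training are deliberately left out):

```python
import numpy as np

def quantize_int8(w):
    """Symmetric absmax quantization: map floats in [-max|w|, max|w|] to int8."""
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

w = np.random.randn(4, 4).astype(np.float32)
q, s = quantize_int8(w)
print(np.abs(w - dequantize(q, s)).max())   # small quantization error
```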

44 of 72

Recap: making a large model smaller

  • Quantization
  • Pruning
  • Key-Value Cache Management
    • …which is also pruning

A spectrum from simple to complex optimization:

  • Wanda: greedy weight pruning
  • Shortened Llama: greedy layer pruning + LoRA
  • Magnitude: greedy weight pruning alternating with gradients
  • Sheared Llama: jointly optimize weights and a continuous approximation to the mask
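A minimal sketch contrasting the two greedy scoring rules above (my own simplification): magnitude pruning scores a weight by |W|, while a Wanda-style score multiplies |W| by the norm of the corresponding input activation.

```python
import numpy as np

def prune_mask(W, sparsity=0.5, act_norms=None):
    """Return a 0/1 mask keeping the highest-scoring weights.
    Score = |W| (magnitude) or |W| * ||x_j|| per input column (Wanda-style)."""
    score = np.abs(W)
    if act_norms is not None:                   # act_norms: shape (in_features,)
        score = score * act_norms[None, :]
    k = int(sparsity * W.size)
    thresh = np.partition(score.ravel(), k)[k]  # k-th smallest score
    return (score >= thresh).astype(W.dtype)

W = np.random.randn(8, 16)
mask = prune_mask(W, sparsity=0.5, act_norms=np.abs(np.random.randn(16)))
W_pruned = W * mask
```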

45 of 72

Recap: making a large model smaller

  • Quantization
  • Pruning
  • Key-Value Cache Management

A spectrum from simple to complex optimization:

  • Score by recency (FIFO), optionally always keeping the start token and separator tokens
  • H2O: score by running accumulated attention
  • SnapKV: adaptively pick specific KV entries based on an observation window during prefilling
  • FastGen: pick a strategy per head and layer based on prefilling

Lots of detail, but I tried to pick out the “key ideas” of the papers
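A minimal sketch of an H2O-style eviction policy (my own simplification): keep the most recent positions plus the “heavy hitters” with the largest accumulated attention; start and separator tokens could be whitelisted in the same way.

```python
import numpy as np

def keep_indices(accum_attention, recent=32, heavy=32):
    """accum_attention[i] = total attention mass token i has received so far.
    Keep the last `recent` positions plus the `heavy` highest-scoring older ones."""
    n = len(accum_attention)
    keep = set(range(max(0, n - recent), n))             # recency (FIFO) part
    older = [i for i in range(n) if i not in keep]
    older.sort(key=lambda i: accum_attention[i], reverse=True)
    keep.update(older[:heavy])                           # heavy-hitter part
    return sorted(keep)

scores = np.random.rand(200)
print(len(keep_indices(scores)))   # cache shrinks from 200 to <= 64 entries
```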

46 of 72

RAG AND CONTRASTIVE LEARNING

47 of 72

Outline

  • Wrap-up / review on Nov 19
  • How to learn retrievers
    • Contrastive learning
    • DPR and Contriever
  • Using decoder-only LLMs for retrieval
    • query expansion
    • generating encodings
  • Retrieval-augmented generation (RAG)
    • The original RAG paper
    • The fusion-in-decoder (FiD) paper
  • A case study: optimizing inference for FiD
    • Systems: FiDO, LUMEN, GLIMMER
  • Incorporating principles from RAG, FiD in decoder-only LLMS

48 of 72

  • Access new information (not available during model pretraining)
  • Better access “long tail” information that is mentioned infrequently in pre-training documents
  • Needs strong retrieval

49 of 72

Recap: The Dense Passage Retriever (DPR)

  • Process several QA datasets to get triples:
    • (question, gold passage id, gold answer span)
  • Triples give positive question/passage examples (qi, pi) encoded in d=768 dimensions by pre-trained BERT
  • Q is a mini-batch of B questions, P is B passages, S = QPᵀ is the (B x B) matrix of similarity scores
    • the pair (qi, pj) is labeled positive iff i=j
      • the “in-batch negatives” trick (loss sketch below)
    • also add additional hard negatives:
      • 1-2 passages with a high TFIDF score for q which do not contain the answer span
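The in-batch-negatives loss mentioned above is just cross-entropy over S = QPᵀ with “diagonal” labels; a minimal sketch (extra hard negatives would simply be appended as additional rows of P, i.e. additional columns of S):

```python
import torch
import torch.nn.functional as F

def in_batch_negative_loss(Q, P):
    """Q: (B, d) question embeddings, P: (B, d) passage embeddings.
    Row i of S = Q @ P.T should put its highest score on column i."""
    S = Q @ P.T                                   # (B, B) similarity scores
    labels = torch.arange(Q.size(0), device=Q.device)
    return F.cross_entropy(S, labels)

Q = torch.randn(16, 768)
P = torch.randn(16, 768)
print(in_batch_negative_loss(Q, P))
```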

50 of 72

Recap: Contriever vs DPR

  • Same NT-Xent contrastive loss
  • One BERT encoder instead of two
  • Data augmentation tricks used
    • Independent cropping
    • Random token deletion
  • Negatives were from
    • in-batch negatives
    • momentum contrast (MoCo)
      • MoCo: a dynamic dictionary queue of negatives, encoded by a momentum-averaged copy of the encoder

51 of 72

Recap: The OG RAG paper

returns top K docs

52 of 72

Recap: Discussion of RAG

  • RAG is now generic term for using retrieval as preprocessing step for LLMs
    • Simplest case is fixed retriever, and appending all docs to the question to form LLM input
    • No custom learner / model / model-combination!
    • Simple case improves immediately whenever
      • LLM improves
      • retriever improves

    • Appending docs scales poorly when many documents are used
      • but there are some tricks to improve that 🡺 FiD

53 of 72

Encoder-only vs. Decoder-only Transformers for Retrieval

  • DPR, Contriever, … are encoder-only models
    • based on BERT
  • Encoding documents for retrieval with decoder-only models is trickier

54 of 72

Hypothetical Document Embedding (HyDE)

  • Details
    • prompt a model (InstructGPT) to convert the question q to a query document d
    • sample N documents d1, …, dN by generation with temperature T
    • the query vector for Contriever is the average of the embeddings of q and d1, …, dN (sketched below)
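A minimal sketch of that recipe, where generate stands in for prompting an instruction-tuned LLM and embed stands in for the Contriever encoder (both are hypothetical placeholders):

```python
import numpy as np

def hyde_query_vector(question, generate, embed, n_docs=4, temperature=0.7):
    """Sketch of HyDE: generate N hypothetical answer documents, embed them,
    and average their embeddings with the embedding of the question itself."""
    prompt = f"Write a passage that answers the question: {question}"
    docs = [generate(prompt, temperature=temperature) for _ in range(n_docs)]
    vecs = [embed(question)] + [embed(d) for d in docs]
    return np.mean(np.stack(vecs), axis=0)
```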

55 of 72

ExpandR

  • Start out like HYDE
    • Prompted query expansion
    • Followed by dense retrieval
  • Then learn to improve both modules

    • The retriever: fix expander + contrastive learning!

loosely interpreted

56 of 72

ExpandR

  • Start out like HYDE
    • Prompted query expansion
    • Followed by dense retrieval
  • Then learn to improve both modules

    • The expander: an RL method called direct preference optimization (DPO)

57 of 72

Decoder-only models as encoders?

2023: PromptEOL

2025: Echo embeddings

58 of 72

Recap: Fusion in Decoder (FiD)

  • In open-book QA, often
    • Questions q and answers a are short: k = O(10)
    • Passages are longer: m = O(100)
  • Retrieving and appending N passages:
    • Encoder = O((Nm + k)²) ≈ O(N²m²) ignoring k
    • Decoder = O(k · (Nm + k)) ≈ O(Nm)
  • Trick:
    • cross-encode q with each passage separately: O(N(k + m)²) ≈ O(Nm²)
    • decode, attending to all: O(k · N(m + k)) ≈ O(Nm)

Encoder cost goes from quadratic in N 🡪 linear in N, so we can afford to retrieve more docs, but we lose cross-attention between tokens in different passages
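To make the savings concrete, a tiny back-of-the-envelope calculation with illustrative numbers (k = 10, m = 100, N = 50), counting attention cost as (sequence length)² per encoder pass:

```python
k, m, N = 10, 100, 50          # question/answer length, passage length, #passages

# Standard encoder over the full concatenation vs. FiD's per-passage encoding
full_encoder = (N * m + k) ** 2            # ~ N^2 m^2
fid_encoder  = N * (k + m) ** 2            # ~ N m^2
print(full_encoder / fid_encoder)          # ~ N-fold saving (about 41x here)

# Decoder cross-attention cost is roughly linear in N either way
decoder = k * N * (m + k)
print(decoder)
```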

59 of 72

Recap: FiD Optimized (FiDO)

2023

60 of 72

Recap: Fusion in Decoder (FiD)

FLOP analysis: Encoder is 6x as expensive as decoder!

…but at inference time decoding is slowest. Why?

What’s the fix?

Predicted by counting FLOPs for all the matmuls and assuming nt << ns and nt << d

  • nt = tokens in the decoder output
  • ns = tokens in all encoder inputs

Predicted by memory/FLOPs ratios of the decoder components: multilayer perceptron, self-attention, cross-attention

61 of 72

Recap: Fusion in Decoder Optimized (FiDO)

  • Speeding up inference:
    • Layer sparse attention (LSA): perform cross attention only on every K-th layer (K=6)
      • Load fewer key-value entries from encoder on avg
    • Multi-query attention (MQ)
      • Creates fewer key-value entries
  • Compensate for performance loss:
    • Scale the decoder up!
      • Use T5-XL decoder with T5-Base encoder, etc


62 of 72

Multi-query and Grouped-Query Attention

Grouped query attention: don’t require that every attention head have its own keys, values, and queries – instead re-use keys and values in “groups”
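A minimal sketch of the grouped-KV idea (illustrative shapes only, not any particular library's implementation): each group of query heads re-uses one key/value head, and multi-query attention is the special case of a single KV head.

```python
import torch

def grouped_query_attention(q, k, v):
    """q: (B, Hq, T, d); k, v: (B, Hkv, T, d) where Hq is a multiple of Hkv.
    Every group of Hq/Hkv query heads re-uses the same key/value head."""
    rep = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(rep, dim=1)     # share each KV head across its group
    v = v.repeat_interleave(rep, dim=1)
    att = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return att @ v

q = torch.randn(2, 8, 16, 64)       # 8 query heads
k = torch.randn(2, 2, 16, 64)       # 2 KV heads -> groups of 4 query heads each
v = torch.randn(2, 2, 16, 64)
print(grouped_query_attention(q, k, v).shape)   # torch.Size([2, 8, 16, 64])
```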

63 of 72

Recap: Fusion in Decoder Optimized (FiDO)

64 of 72

Recap: LUMEN: FiD with caching

(figure: the first N−K layers of the FiD encoder, the last K layers of the FiD encoder, the FiD decoder)

Passages are encoded (through the first N−K encoder layers) and stored off-line for every document

65 of 72

Main ideas in FiD and extensions

  • In encoder
    • cross-encoder q+p1, q+p2, … separately
      • restricting cross-attention
  • In decoder
    • attend to everything as you generate
  • Extensions
    • FiDO: optimize performance to avoid decoder bottlenecks
    • LUMEN: pre-compute encoder outputs
    • GLIMMER: add reranking

FlashAttention (2023) and FlexAttention (2024) also improve decoder bottlenecks

Analogs for decoder-only LLMs:

  • Parallel Context Windows (PCW) - 2023
  • TurboRAG, Blockwise Sparse Attention - 2024
  • Dynamic Blockwise Sparse Attention - 2025

66 of 72

2023

  • Context tokens: retrieved document(s) for RAG, in-context examples, …
  • Task tokens: question for RAG, test input for ICL, …
    • I believe output tokens are also task tokens

Key idea: cross-attend within a context window, and cross-attend between task tokens and all context windows.

Very similar to FiD

  • if question/answer are task tokens
  • no post-training used
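A minimal sketch of the attention pattern this implies (my own construction, not the paper's code): context tokens attend only within their own window, while task tokens attend to all context and to themselves; the usual causal masking inside blocks is omitted for brevity.

```python
import numpy as np

def block_attention_mask(window_lengths, n_task_tokens):
    """Boolean mask M[i, j] = True if token i may attend to token j.
    Context windows are independent blocks; task tokens see all context + themselves."""
    n_ctx = sum(window_lengths)
    n = n_ctx + n_task_tokens
    M = np.zeros((n, n), dtype=bool)
    start = 0
    for L in window_lengths:                  # block-diagonal part for context windows
        M[start:start + L, start:start + L] = True
        start += L
    M[n_ctx:, :] = True                       # task tokens attend everywhere
    return M

print(block_attention_mask([3, 3], 2).astype(int))
```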

67 of 72

2025

Key idea: same as Block-Attention except

  • experiments are on ICL instead of RAG
  • cache the independently-produced KV pairs of the documents
  • evaluate addition of KV pairs computed for retrieved documents
    • retrieve ICL examples with BM25
  • avoid fine-tuning
    • clever use of StreamingLLM trick of “attention sink”
    • don’t mess with position encodings

68 of 72

KV Retrieval vs KV Cache Eviction

  • KV Cache Eviction:
    • get the best (most useful, smallest) KV cache by starting with a big one and making it smaller with evictions
  • PCW, DBSA, TurboRAG:
    • get the best (most useful, smallest) KV cache by starting with a small one and making it larger with retrievals
    • some cross-attention is never computed

69 of 72

PPI: Statistically Unbiased AI Judges

  • One solution: build an AI system to perform evaluation!
  • … but is this a good idea?

Arguments for:

  • much more data can be collected faster
  • my project is late / over budget
  • human labels are noisy
  • LLM-as-judge biases seem to be minor

  • A small amount of human-labeled data can be used to correct for biases in a judge/autorater
  • statistical inference 🡺 prediction-powered inference
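A minimal sketch of the basic prediction-powered mean estimate (the “difference estimate” form on the next slide): score many examples with the autorater, then use a small human-labeled subset to estimate and subtract the autorater's bias. The numbers below are synthetic.

```python
import numpy as np

def ppi_mean(ar_unlabeled, ar_labeled, human_labeled):
    """theta_PPI = mean(autorater on big unlabeled set)
                 + mean(human label - autorater label on the small labeled set)."""
    rectifier = np.mean(np.asarray(human_labeled) - np.asarray(ar_labeled))
    return np.mean(ar_unlabeled) + rectifier

# Example: autorater scores on 3000 unlabeled items, plus 200 items with both
# an autorater score and a human label.
rng = np.random.default_rng(0)
ar_big = rng.random(3000)
ar_small = rng.random(200)
human = ar_small + rng.normal(0.05, 0.1, 200)   # autorater is biased low by ~0.05
print(ppi_mean(ar_big, ar_small, human))
```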

70 of 72

Bayesian PPI

(figure: the difference estimate (old) vs. the Bayesian version)

71 of 72

Bayesian PPI: QA evals

(figure: CI width as a function of the number of human labels, with 3k AR labels)

72 of 72

Final messages

  • Exam 2 is similar to exam 1
    • Different content (not cumulative)
    • Similar mix of multiple-choice, true/false, and short-answer questions