1 of 74

10-605 / 10-805

Machine Learning from Large Datasets

2 of 74

Announcements

  • Happening today!
    • Structured miniproject details are out
    • Unstructured miniproject also kicks off
      • Finalize teams
    • HW5 is out
      • HW4: collect your own corpus
      • HW5: pretrain an LLM with it
      • Why?
        • It’s harder for us to evaluate
        • It’s more interesting
          • data curation for LLMs is important
          • you’re making some choices
          • we end up with lots of different GPTs
  • Happening later
    • Writing session for HW4: Nov 14
    • Writing session for HW5: Nov 21

Why? stay tuned!

3 of 74

Outline

  • Review of extremely parallel training for ML
  • Some very parallel training methods for LLMs
    • Branch-Train-Merge (BTM)
    • Branch-Train-Mix (BTX)
    • FlexOLMO
  • Some very parallel approaches to multi-task learning
    • Task vectors

4 of 74

Distributed ML

  • Data parallelism
    • workers get parts (usually disjoint partitions) of the data and copies of the current model
    • workers optimize independently for a while
    • workers communicate and synchronize models
  • Decisions to consider (see the sketch below):
    • when do you synchronize?
      • when is it ok to optimize a “stale” set of weights?
    • how do you synchronize?
      • how do you send information?
      • how do you compress or sparsify information?
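
A minimal sketch of this setup (not the lecture's code; plain NumPy, with workers optimizing independently on disjoint shards and synchronizing by averaging their weights every round):

```python
# Data-parallel SGD on P disjoint shards with periodic synchronization
# by parameter averaging ("iterative parameter mixing").
import numpy as np

def local_sgd(w, X, y, lr=0.1, steps=50):
    """A few SGD steps of logistic regression on one worker's shard."""
    for _ in range(steps):
        i = np.random.randint(len(y))
        p = 1.0 / (1.0 + np.exp(-X[i] @ w))
        w = w - lr * (p - y[i]) * X[i]
    return w

def data_parallel_train(X, y, P=4, rounds=10):
    shards = np.array_split(np.arange(len(y)), P)   # disjoint partitions of the data
    w = np.zeros(X.shape[1])                        # shared "current model"
    for _ in range(rounds):
        # each worker copies the current model and optimizes independently for a while
        local = [local_sgd(w.copy(), X[idx], y[idx]) for idx in shards]
        # synchronize: average the workers' weight vectors
        w = np.mean(local, axis=0)
    return w
```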


RECAP

5 of 74

Data Parallel ML: Minimal synchronization 1

RECAP

ACL 2010

6 of 74

Data Parallel ML: Minimal synchronization 1

  • ML models they compared:
    • Distributed batch gradient
    • Two minimal-synchronization methods
      • Learn P different perceptrons from P different shards, with batch gradient descent on each shard
        • Combine predictions by majority vote
        • Combine weight vectors by averaging
  • All methods have comparable CPU time and accuracy, but the minimal-synchronization methods use about 1000x less network bandwidth (see the sketch below)
  • Averaging provably gives lower variance than a single shard’s model, but not lower bias
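
A sketch of the two combination rules above, assuming each shard's perceptron is just a weight vector:

```python
# Combining P independently-trained shard perceptrons: majority vote vs. weight averaging.
import numpy as np

def predict(w, x):
    return 1 if w @ x >= 0 else -1

def combine_by_vote(models, x):
    # each shard's perceptron votes; the prediction is the majority label
    votes = sum(predict(w, x) for w in models)
    return 1 if votes >= 0 else -1

def combine_by_averaging(models):
    # a single model whose weight vector is the average of the shard models
    return np.mean(models, axis=0)
```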


RECAP

7 of 74

Comparing synchronization approaches for perceptron training

Figure legend: one model on 1/p of the data; model averaging with no synchronization; iterative parameter mixing

RECAP

8 of 74

Data Parallel ML: Minimal synchronization 2


NeurIPS 2010

RECAP

9 of 74

Data Parallel ML: Minimal synchronization 2

  • Comparisons
    • The minimal-synchronization method
      • Learn P different logistic regression classifiers, each trained with stochastic gradient descent on a shard of T examples
      • Combine weight vectors by averaging
  • Can make formal bounds on the variance and expectation of the averaged parameter vector
    • relative to a single SGD run over T examples


RECAP

10 of 74

Comparing synchronization approaches for SGD and logistic regression

Figure legend: 100-model average, no synchronization; 10-model average, no synchronization; 1 model

Note: These are convex optimization problems!

RECAP

11 of 74


talk pilfered from 🡪 …..

KDD 2011

RECAP

12 of 74

Figure legend: iterative SGD, no mixing; limited-memory quasi-Newton; parameter mixing; alternating least squares; iterative parameter mixing (IPM)

RECAP

13 of 74

Recap of the recap

  • Sometimes it works reasonably well to
    • train P systems completely independently
    • average (”mix”) the parameters once at the end
  • This isn’t optimal but has advantages
    • You can remove (“forget”) data by removing a model from the mix
    • You can add more data easily by training and mixing one more model

14 of 74

BRANCH-TRAIN-MERGE

15 of 74

2022

16 of 74

BTM: Key ideas

ELM = expert LM

ELMForest = group of ELMs

𝜃1, 𝜃2, …, 𝜃k

small GPT 𝜃

merge?

17 of 74

BTM: Key ideas

  • Merge option 1: at inference time
    • D: a domain for inference—assume it’s an existing domain 1..n
    • x<t : a text prefix—eval is mostly on perplexity/LM
    • Xt : probability distribution over next token

Domain posterior

Using just top k=3 domains is ok

Preferred approach: exponential moving average of the domain posterior, update every 1000 tokens

combine logits from k LLMs
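
One way to write the ensembling rule sketched above (a reconstruction consistent with this slide, not copied from the paper): the experts' next-token distributions are mixed under the domain posterior, which is itself estimated from the prefix.

```latex
% Inference-time merging of k expert LMs (theta_j = ELM for domain j):
\[
  P(X_t \mid x_{<t}) \;=\; \sum_{j=1}^{k}
    \underbrace{P(X_t \mid x_{<t}, \theta_j)}_{\text{expert } j}\,
    \underbrace{P(D=j \mid x_{<t})}_{\text{domain posterior}},
  \qquad
  P(D=j \mid x_{<t}) \;\propto\; P(x_{<t} \mid \theta_j)\, P(D=j)
\]
```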

18 of 74

BTM: Key ideas

  • Merge option 2: parameter averaging
    • preferred approach: weight by domain posterior

Domain posterior

using a single merged LLM

*parameters of 𝜃j weighted by P(D=j)

combine weights from k LLMs
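
A sketch of merge option 2 on parameter dictionaries (illustrative, assuming `experts[j]` maps parameter names to arrays and `posterior[j]` = P(D=j)):

```python
# Weighted parameter averaging: one merged model whose parameters are the
# domain-posterior-weighted average of the expert LMs' parameters.
def merge_by_parameter_averaging(experts, posterior):
    merged = {}
    for name in experts[0]:
        merged[name] = sum(p * expert[name] for p, expert in zip(posterior, experts))
    return merged
```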

19 of 74

BTM: Key ideas

ELM = expert LM

ELMForest = group of ELMs

𝜃1, 𝜃2, …, 𝜃k

small GPT 𝜃

merged with parameter averaging

20 of 74

BTM: Experiments with 8 domains

compute-matched perplexity with 8 train/8 eval domains

50% of compute for “seed” model 𝜃

inference-time model merging

21 of 74

BTM: Experiments with 8 domains

compute-matched perplexity with 8 train/8 eval domains

50% of compute for “seed” model 𝜃

model merging by parameter mixing

22 of 74

BTM: Training costs

23 of 74

2024

24 of 74

Background: Mixture of Experts (MoE)

combine the expert outputs E1(x), …, En(x) as

y = Σi G(x)i · Ei(x)

with a learned gate G(x), where in training G(x) = SoftMax(x · Wg)

and at inference time only the top-k experts (largest gate values) are evaluated

25 of 74

Background: Expert Parallelism

This can be parallelized!

  • Each processor hosts a subset of the experts
  • All-to-all communication before and after the expert layer

Often need to train to “balance” the experts with an extra loss term
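
A schematic top-k MoE FFN layer in PyTorch (a sketch of the idea, not any specific paper's implementation; the auxiliary loss is a simplified stand-in for the balancing terms used in practice):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoEFFN(nn.Module):
    def __init__(self, d_model, d_ff, n_experts, k=2):
        super().__init__()
        self.k = k
        self.gate = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        ])

    def forward(self, x):                                  # x: (n_tokens, d_model)
        probs = F.softmax(self.gate(x), dim=-1)            # G(x): (n_tokens, n_experts)
        topv, topi = probs.topk(self.k, dim=-1)            # each token's top-k experts
        out = torch.zeros_like(x)
        for slot in range(self.k):
            for e, expert in enumerate(self.experts):
                mask = topi[:, slot] == e                  # tokens routed to expert e in this slot
                if mask.any():
                    out[mask] += topv[mask, slot].unsqueeze(1) * expert(x[mask])
        # simple auxiliary loss that discourages always using the same experts
        balance_loss = (probs.mean(dim=0) ** 2).sum() * probs.shape[-1]
        return out, balance_loss
```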

26 of 74

Background: Expert Parallelism

In Transformers, MoE is used for the FFN layers, so individual tokens are routed to the experts

Examples:

  • Switch Transformers
    • 1T param model from 2022
  • OG ChatGPT (according to rumors)
  • Mixtral models (early 2024)
  • Llama 4 series
  • Kimi K2 - 1T parameter model
  • ..

27 of 74

2024

Branch and Train as in Branch-Train-Merge

Mix the expert LMs by

  • Making each FFN an expert for the corresponding MoE FFN layer
  • Parameter averaging all other parameters
  • Finetune everything (note that the gates are new parameters); see the sketch below
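
A sketch of the "mix" step on k expert state dicts (illustrative only; "ffn" here is a placeholder for however FFN parameters are actually named, and the new gates are created separately):

```python
# FFN parameters become one copy per expert (the MoE experts);
# all other parameters are averaged across the expert LMs.
def assemble_btx(expert_state_dicts):
    k = len(expert_state_dicts)
    merged = {}
    for name in expert_state_dicts[0]:
        if "ffn" in name:                        # becomes expert j of the MoE FFN layer
            for j, sd in enumerate(expert_state_dicts):
                merged[f"{name}.expert{j}"] = sd[name]
        else:                                    # attention, embeddings, norms: average
            merged[name] = sum(sd[name] for sd in expert_state_dicts) / k
    return merged   # the MoE routers/gates are new parameters, initialized separately
```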

28 of 74

BTX details

  • Llama-2 7B
  • Train experts on three domains
    • Math (201B tokens)
    • Code (210B tokens)
    • Wikipedia (42B tokens)
    • and also keep the old model as a fourth expert
  • Finetune another 80B tokens
  • Evaluate on benchmark tasks (not perplexity)

The experts are different

29 of 74

BTX results

Figure annotations: math specialist; “data matched”; * = BTX with no parallel training stage

30 of 74

BTX results

31 of 74

BTX results

32 of 74

FlexOLMO

August 2025

33 of 74

FlexOLMO

Multiple local datasets Di; one public dataset Dpub; one model Mpub trained on Dpub

For each Di

  • Freeze Mpub and copy FFNs twice to create Mi
  • Train Mi on Di, freezing everything except one copy of the FFNs and the router

34 of 74

FlexOLMO

Multiple local datasets Di; one public dataset Dpub; one model Mpub trained on Dpub

For each Di

  • Freeze Mpub and copy FFNs twice to create Mi
  • Train Mi on Di, freezing everything except one copy of the FFNs and the router

Wr holds the “router embeddings” ri; each ri is initialized with off-the-shelf embeddings from Di and then trained pairwise with Mpub

Mixture output

Optionally: create local datasets D’i where each is extracted from Mpub and similar to Di

(These can be small)

Sample uniformly from the “proxy datasets” and Dpub to fine-tune Wr

35 of 74

FlexOLMO

36 of 74

FlexOLMO

37 of 74

FlexOLMO

38 of 74

TASK VECTORS

39 of 74

Recap: skip-gram embeddings (word2vec)

Training data:

positive examples are pairs of words w(t), w(t+j) that co-occur

Training data:

negative examples are samples of pairs of words w(t), w(t+j) that don’t co-occur

You want to train over a very large corpus (100M+ words) with embeddings of hundreds of dimensions or more
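
For reference, the standard skip-gram-with-negative-sampling objective for one positive pair:

```latex
% One co-occurring pair (w, c) plus k "negative" contexts c' drawn from a noise distribution P_n:
\[
  \ell(w,c) \;=\; \log \sigma(\vec{w}\cdot\vec{c})
  \;+\; \sum_{i=1}^{k} \mathbb{E}_{c'_i \sim P_n}\big[\log \sigma(-\vec{w}\cdot\vec{c}'_i)\big]
\]
```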

40 of 74

Recap: Results from word2vec

A number of properties of word2vec were surprising and mysterious, until they were explained by Omer Levy and Yoav Goldberg a couple of years later.

41 of 74

2014

42 of 74

2014

Notation for analogies: a : a* :: b : b*

e.g., man : woman :: king : queen

Word2vec method is “3CosAdd”:

b* = argmax b′ cos(b′, b − a + a*)

   = argmax b′ [ cos(b′, b) − cos(b′, a) + cos(b′, a*) ]

i.e., b* (queen) should be:

  • similar to b (king)
  • similar to a* (woman)
  • different from a (man)
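
A small sketch of 3CosAdd over a row-normalized embedding matrix (illustrative; `E` is V x d, `words` is the vocabulary list):

```python
# 3CosAdd: answer "a : a* :: b : ?" with cosine similarities.
import numpy as np

def cos_add(E, words, a, a_star, b):
    idx = {w: i for i, w in enumerate(words)}
    E = E / np.linalg.norm(E, axis=1, keepdims=True)        # rows unit-length: dot = cosine
    scores = E @ (E[idx[b]] - E[idx[a]] + E[idx[a_star]])   # cos(b',b) - cos(b',a) + cos(b',a*)
    for w in (a, a_star, b):                                # don't return a query word
        scores[idx[w]] = -np.inf
    return words[int(np.argmax(scores))]

# e.g. cos_add(E, words, "man", "woman", "king")  ->  ideally "queen"
```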

43 of 74

Recap: distributional clustering (with LSH)

  • Common task: distributional clustering
    • for a word w, v(w) is sparse vector of words that co-occur with w
    • cluster the v(w)’s


…guards at Pentonville prison in North London discovered that an escape attempt…

An American Werewolf in London is to be remade by the son of the original director…

…UK pop up shop on Monmouth Street in London today and on Friday the brand…

v(London): Pentonville, prison, in, North, …. and, on Friday

44 of 74

2014

Key idea: these two vectors are closely related

  • word2vec embedding of w=“London”
  • distributional representation of w=“London”
    • with context words c weighted by “positive pointwise mutual information”

word2vec embeddings are matrix factorization of the PPMI matrix
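
The Levy & Goldberg result behind this claim, for reference (skip-gram with k negative samples implicitly factorizes a shifted PMI matrix; PPMI is its sparse, thresholded cousin):

```latex
\[
  \vec{w}\cdot\vec{c} \;\approx\; \mathrm{PMI}(w,c) - \log k ,
  \qquad
  \mathrm{PMI}(w,c) = \log\frac{P(w,c)}{P(w)\,P(c)} ,
  \qquad
  \mathrm{PPMI}(w,c) = \max\!\big(\mathrm{PMI}(w,c),\,0\big)
\]
```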

Background

Q: what if you use the original sparse (unfactored) PPMI word vectors for analogies?

45 of 74

2014

Notation for analogies: a : a* :: b : b*

e.g., man : woman :: king : queen

Word2vec method is “3CosAdd”:

b* = argmax b′ cos(b′, b − a + a*)

   = argmax b′ [ cos(b′, b) − cos(b′, a) + cos(b′, a*) ]

i.e., b* (queen) should be:

  • similar to b (king)
  • similar to a* (woman)
  • different from a (man)

46 of 74

2014

sparse PPMI vectors 🡪

word2vec 🡪

Observation: old-school sparse vectors also work for analogies

but not as well as word2vec, when both use the 3CosAdd rule

i.e., b* (queen) should be:

  • similar to b (king)
  • similar to a* (woman)
  • different from a (man)

3CosMul*:

b* = argmax b′ [ cos(b′, b) · cos(b′, a*) ] / [ cos(b′, a) + ε ]

*scale sims to [0,1]

47 of 74

2014

sparse PPMI vectors 🡪

word2vec 🡪

sparse PPMI vectors 🡪

word2vec 🡪

So: old-school sparse vectors also work for analogies.

48 of 74

Discussion: Analogies

It would be great if vector-based representations did have more modularity like this: e.g., adding “past tense” to a word’s representation

49 of 74

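The slides that follow use task vectors; for reference, the definition and the arithmetic they rely on (reconstructed from the task-arithmetic line of work these experiments come from):

```latex
% Task vector for task t: the parameter delta produced by fine-tuning.
\[
  \tau_t \;=\; \theta^{\,t}_{\text{finetuned}} \;-\; \theta_{\text{pretrained}}
\]
% Edits to the pretrained model are then linear combinations of task vectors:
%   negation ("forget"):         \theta - \lambda \tau_t
%   addition (multi-task):       \theta + \lambda \sum_t \tau_t
%   analogy (A : B :: C : D):    \hat{\tau}_D \approx \tau_C + \tau_B - \tau_A
```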

50 of 74

Experiment with task negation

  • 𝜃 = pre-trained GPT2
  • Fine-tune GPT2 on subsets of the CivilComments dataset to get task vectors
    • 𝜏nice: comments with toxicity < 0.2
    • 𝜏toxic: comments with toxicity > 0.8
  • Assess toxicity of 1000 outputs with a classifier

Compared models: 𝜃, 𝜃 + 𝜏toxic, 𝜃 − 𝜏toxic, 𝜃 + 𝜏nice (same magnitude)

Subtracting the task vector is a way of “forgetting” how to perform that task!
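
A sketch of the negation edit on parameter dicts (illustrative; 𝜃 here is the pretrained GPT-2 weights):

```python
# Task negation: theta_edited = theta - lam * tau_toxic,
# where tau_toxic = theta_finetuned_on_toxic - theta.
def task_vector(theta_finetuned, theta_pretrained):
    return {n: theta_finetuned[n] - theta_pretrained[n] for n in theta_pretrained}

def apply_task_vector(theta, tau, lam):
    return {n: theta[n] + lam * tau[n] for n in theta}

# forgetting toxicity: apply_task_vector(theta, tau_toxic, lam=-1.0)
```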

51 of 74

Experiment with scaled task negation

  • 𝜃 = pre-trained GPT2
  • Fine-tune GPT2 on subsets of the CivilComments dataset to get task vectors
    • 𝜏nice: comments with toxicity < 0.2
    • 𝜏toxic: comments with toxicity > 0.8
  • Assess toxicity of 1000 outputs with a classifier

Reference points in the figure: 𝜃, 𝜃 + 𝜏toxic, 𝜃 + 𝜏nice (same magnitude)

52 of 74

Experiment with scaled task negation

  • 𝜃 = pre-trained CLIP
  • Fine-tune on 8 different image tasks
  • See if the resulting model forgets that task while maintaining performance on a control task (ImageNet classification)
  • Report average over all tasks

Figure: results for 𝜃 and 𝜃 + 𝜏, shown for a smaller and a larger model

53 of 74

Experiments with task analogies

  • Objective: sentiment analysis on reviews from Yelp
  • Data available
    • unlabeled reviews from Yelp
    • sentiment data from Amazon reviews
    • unlabeled Amazon reviews
  • Approach: build task vectors for
    • 𝜏YelpLM: language modeling on Yelp
    • 𝜏AmazonLM: language modeling on Amazon
    • 𝜏AmazonSentiment: sentiment prediction on Amazon

    • Finally use

λ1𝜏AmazonSentiment + (λ2𝜏YelpLM − λ3𝜏AmazonLM)

54 of 74

Experiments with task analogies (the same recipe with the roles of Yelp and Amazon swapped)

  • Objective: sentiment analysis on reviews from Amazon
  • Data available
    • unlabeled reviews from Amazon
    • sentiment data from Yelp reviews
    • unlabeled Yelp reviews
  • Approach: build task vectors for
    • 𝜏YelpLM: language modeling on Yelp
    • 𝜏AmazonLM: language modeling on Amazon
    • 𝜏YelpSentiment: sentiment prediction on Yelp

    • Finally use

λ1𝜏YelpSentiment + (λ2𝜏AmazonLM − λ3𝜏YelpLM)

55 of 74

Experiments with task analogies

  • Objective: improve performance on rare slices of data

Image classification tasks with CLIP, slices are based on an image style and class.

Also using λ’s picked by validation sets

56 of 74

Experiments with task analogies

Collect 200 images of kings, queens, women, men.

Fine-tune CLIP models on each category, using 1000 ImageNet classes as negative examples

(class name ”something”)

57 of 74

Experiment: adding many task vectors


𝜃

𝜃 + 𝜏*

Note the training for the subtasks is embarrassingly parallel

58 of 74

Experiment: adding different task vectors

  • weighted sum of task vectors for 2 vision tasks
  • show accuracy normalized by FT accuracy
      • and compared to zero-shot accuracy

Note the training for the subtasks is embarrassingly parallel

59 of 74

Experiment: adding different task vectors


Note the training for the subtasks is embarrassingly parallel

60 of 74

 

Why does this work?

Task vectors are mostly orthogonal except when tasks are related

MNIST / SVHN are both digit datasets; GTSRB is reading traffic signs

EuroSAT / RESISC45 are satellite image datasets

61 of 74

 

Why does this work?

“Weight disentanglement”

62 of 74

2023

Disentanglement error:

Disentanglement error of ⍺1𝜏1 and ⍺2𝜏2, where dist(f, g) = 1 if the two models’ predicted labels differ (see the reconstruction below)
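
A reconstruction of the disentanglement-error definition referenced above (hedged; μi is task i's input distribution, 𝜃0 the pretrained weights, and dist the 0/1 disagreement between predicted labels):

```latex
\[
  \xi(\alpha_1,\alpha_2) \;=\; \sum_{i=1}^{2}\;
  \mathbb{E}_{x \sim \mu_i}\!\Big[
    \mathrm{dist}\big(
      f(x;\,\theta_0 + \alpha_i\tau_i),\;
      f(x;\,\theta_0 + \alpha_1\tau_1 + \alpha_2\tau_2)
    \big)
  \Big]
\]
```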

63 of 74

2023

Theory:

  • close enough to the initial 𝜃, and for a large enough network, weight changes have an approximately linear effect on the model’s outputs
  • task vectors that stay in this linear regime should combine linearly
  • however, fine-tuning can take 𝜏 outside the linear regime
  • it is possible to measure when this has happened
  • it’s also possible to fine-tune inside the linear regime by using a kernel (linearized) version of the model
  • this helps make task arithmetic a little better for larger models

64 of 74

2023

Focus on merging models for different tasks

Question: how do you merge models when there is some “entanglement”? Can you improve on task arithmetic?

65 of 74

Question: how do you merge models when there is some “entanglement”?

  1. only keep the top k% of each task vector’s weights

average acc on 11 task vectors

66 of 74

Question: how do you merge models when there is some “entanglement”?

 

TIES-Merge – TrIm, Elect Sign, and Merge
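
A sketch of the three TIES steps on flattened task vectors (illustrative NumPy, using the k = 20% and λ = 1 defaults mentioned on the next slide):

```python
import numpy as np

def ties_merge(task_vectors, k=0.20, lam=1.0):
    tvs = []
    for tau in task_vectors:
        # TrIm: keep only the top k% of entries by magnitude, zero the rest
        thresh = np.quantile(np.abs(tau), 1.0 - k)
        tvs.append(np.where(np.abs(tau) >= thresh, tau, 0.0))
    tvs = np.stack(tvs)                               # (n_tasks, n_params)
    # Elect Sign: per parameter, the sign with the larger total magnitude
    elected = np.sign(tvs.sum(axis=0))
    # Merge: average only the entries whose sign agrees with the elected sign
    agree = (np.sign(tvs) == elected) & (tvs != 0)
    counts = np.maximum(agree.sum(axis=0), 1)
    merged = (tvs * agree).sum(axis=0) / counts
    return lam * merged                               # added to the pretrained weights
```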

67 of 74

 

TIES-Merge – TrIm, Elect Sign, and Merge

11 NLP / 8 vision datasets; k=20% and λ=1 w/o validation set

68 of 74

69 of 74

accuracy / FT accuracy merging multiple tasks

70 of 74

71 of 74

LoraHub: Method

  • LoRA-train FLAN-T5-Large on 20 randomly selected tasks from FLAN training set
    • The LoRA matrices for task t are At Bt
    • Combine: form a single merged LoRA module as a weighted combination of the task modules (see the sketch below)

    • Optimize the weights w with an L1-regularized perplexity loss on 5-shot training data, using black-box optimization (like Vizier)

Other results:

  • 20 LoRA tasks with lowest avg loss on the BBH examples: 35.4
  • FLAN-T5-XL: 36.5
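
A schematic of the LoraHub-style composition and search objective (simplified and assumed, not the paper's exact code; `few_shot_loss` is a hypothetical callable that evaluates a candidate merged update on the 5-shot data):

```python
import numpy as np

def combine_lora(As, Bs, w):
    # weighted combination of the task LoRA factors A_t (d x r) and B_t (r x d) -- assumed shapes
    A_hat = sum(w_t * A_t for w_t, A_t in zip(w, As))
    B_hat = sum(w_t * B_t for w_t, B_t in zip(w, Bs))
    return A_hat @ B_hat                              # merged low-rank weight update

def objective(w, As, Bs, few_shot_loss, l1=0.05):
    # L1-regularized loss on the few-shot examples; minimized by a black-box optimizer
    delta = combine_lora(As, Bs, w)
    return few_shot_loss(delta) + l1 * np.abs(np.asarray(w)).sum()
```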

72 of 74

2024

Post-Hoc Adaptive Tokenwise Gating Over an Ocean of Specialized Adapters (PHATGOOSE)

73 of 74

PHATGOOSE methods

  • LoRA-train FLAN-T5-Large on 166 tasks from FLAN training set
  • Train a sigmoid gate for each LoRA:
    • on the same data for 100 steps
  • At inference time, route tokens to the top 3 LoRA modules (softmax-weighted); see the sketch below
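
A schematic of the tokenwise routing step (a sketch of the idea, not the exact PHATGOOSE implementation; `gates[m]` is the learned gate vector for LoRA module m and `loras[m]` applies that module):

```python
import torch
import torch.nn.functional as F

def route_tokens(h, gates, loras, top_k=3):
    # h: (n_tokens, d); gates[m]: (d,) gate vector; loras[m]: callable on hidden states
    scores = torch.stack([h @ g for g in gates], dim=-1)   # (n_tokens, n_modules)
    topv, topi = scores.topk(top_k, dim=-1)
    weights = F.softmax(topv, dim=-1)                      # softmax over the chosen modules
    out = torch.zeros_like(h)
    for slot in range(top_k):
        for m, lora in enumerate(loras):
            mask = topi[:, slot] == m                      # tokens whose slot-th pick is module m
            if mask.any():
                out[mask] += weights[mask, slot].unsqueeze(1) * lora(h[mask])
    return out
```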

74 of 74

PHATGOOSE results

  • LoRA-train FLAN-T5-Large on
    • 166 tasks from FLAN training set
    • The training tasks from T0
  • Evaluate zero-shot on novel tasks