1 of 138

Lecture 8: Language Model

Berkeley CS294-158 Deep Unsupervised Learning

Spring 2024

Hao Liu

2 of 138

Outline

1

  • Methodologies behind large language models
    • Basics
    • Bottlenecks and solutions
    • Scaling laws
    • Capabilities
  • How to train your large language models
    • Sharding

3 of 138

Successes of machine and deep learning

2

  • Language models are everywhere
  • Secret sauce: enormous compute power

Chatbot

Copilot

Sora

TPU datacenter

4 of 138

Language model

3

  • A language model is a probability distribution over sequences: p_θ(x_1, …, x_n)
    • Likelihood distribution: p_θ
    • Sequence of tokens: x = (x_1, …, x_n)
    • Parameters: θ

5 of 138

Language model

4

  • Typical language models are autoregressive:

    p_θ(x) = ∏_{t=1}^{n} p_θ(x_t | x_1, …, x_{t-1})

  • The model p_θ is parameterized by a transformer

6 of 138

Autoregressive

5

  • Autoregressive factorization
    • Next token prediction
    • Predicting future
  • Every bit becomes supervision
  • Natural for conversational AI
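
A minimal NumPy sketch of this next-token objective (illustrative helper, not from the slides; assumes the model's logits for a sequence are already computed):

    import numpy as np

    def next_token_nll(logits, tokens):
        """Average negative log-likelihood of a sequence under the model:
        logits[t] predicts tokens[t+1]. logits: (T, V), tokens: (T,) integer ids."""
        z = logits - logits.max(axis=-1, keepdims=True)           # for numerical stability
        logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))  # log-softmax over the vocab
        # position t is supervised by the observed token at position t + 1
        return -logp[np.arange(len(tokens) - 1), tokens[1:]].mean()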

7 of 138

Modeling likelihood

6

  • Maximum likelihood: make observed data likely under the model

    max_θ Σ_{x ∈ D} log p_θ(x)

  • D: ~1.5 trillion tokens in Llama
  • θ: 7 to 70 billion parameters

8 of 138

Learning

7

  • Pretraining: large-scale unsupervised learning
  • Finetuning: specializing the model for specific tasks
  • Learning from feedback: improving the model with feedback on its outputs, such as human feedback and debugging messages

9 of 138

Why unsupervised learning?

8

  • Much more data available
  • Much better generalization

10 of 138

Scaling compute

9

  • We want to model the data distribution better by adding more compute.
  • “The biggest lesson that can be read from 70 years of AI research is that general methods that leverage computation are ultimately the most effective, and by a large margin.” (The Bitter Lesson, Richard Sutton, 2019)

11 of 138

Compute

10

  • Compute is the forward and backward passes of the model over token sequences.
  • Mostly matrix multiplications (matmuls)
    • Unit: floating point operations (FLOPs)
  • Compute cost: C = 6ND(1 + s/6d)
    • N: parameters, D: data size, s: context size, d: hidden dimension
  • Add compute by using more tokens, a larger context, or a larger model
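
A back-of-envelope sketch of this cost formula (the helper name and example numbers are illustrative; the 7B-parameter, 2T-token example matches a later slide):

    def training_flops(n_params, n_tokens, ctx=None, d_model=None):
        """Approximate training compute C = 6*N*D*(1 + s/(6*d)).
        The context term is often negligible when s << 6*d."""
        c = 6 * n_params * n_tokens
        if ctx is not None and d_model is not None:
            c *= 1 + ctx / (6 * d_model)
        return c

    # e.g. 6 * 7e9 * 2e12 ≈ 8.4e22 FLOPs
    print(f"{training_flops(7e9, 2e12):.1e}")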

12 of 138

Token

11

  • Byte-based: most general, but sequences become too long
  • Character-based: each word requires too many tokens
  • Word-based: `dog’ and `dogs’ should share meaning
  • Subword-based:
    • Byte-pair encoding: replace the most frequent adjacent pair with a new token
    • Repeat until reaching a given vocabulary size
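
A minimal byte-pair-encoding training sketch (hypothetical helper names; real tokenizers add byte fallback, pre-tokenization, and efficiency tricks):

    from collections import Counter

    def bpe_train(corpus, vocab_size):
        """Learn BPE merges: repeatedly replace the most frequent adjacent pair
        with a new token until the vocabulary reaches vocab_size.
        corpus: list of token lists (e.g. characters as strings); returns the merge list."""
        vocab = {tok for seq in corpus for tok in seq}
        merges = []
        while len(vocab) < vocab_size:
            pairs = Counter(p for seq in corpus for p in zip(seq, seq[1:]))
            if not pairs:
                break
            (a, b), _ = pairs.most_common(1)[0]
            merges.append((a, b))
            vocab.add(a + b)
            corpus = [_merge(seq, a, b) for seq in corpus]  # apply the new merge everywhere
        return merges

    def _merge(seq, a, b):
        out, i = [], 0
        while i < len(seq):
            if i + 1 < len(seq) and seq[i] == a and seq[i + 1] == b:
                out.append(a + b); i += 2
            else:
                out.append(seq[i]); i += 1
        return out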

13 of 138

Train LSTM with more compute

12

  • LSTM learns a sentiment neuron after training to predict the next word on a large amount of Amazon reviews.

Radford, Alec, Rafal Jozefowicz, and Ilya Sutskever. "Learning to generate reviews and discovering sentiment.” arXiv 2017.

14 of 138

Train LSTM with more compute

13

  • Visualizing the value of the sentiment cell as it processes six randomly selected high contrast IMDB reviews. Red indicates negative sentiment while green indicates positive sentiment. Best seen in color.

Radford, Alec, Rafal Jozefowicz, and Ilya Sutskever. "Learning to generate reviews and discovering sentiment.” arXiv 2017.

15 of 138

Train LSTM with more compute

14

  • LSTM learns a sentiment neuron after training to predict the next word on a large amount of Amazon reviews.

Radford, Alec, Rafal Jozefowicz, and Ilya Sutskever. "Learning to generate reviews and discovering sentiment.” arXiv 2017.

16 of 138

Scaling LSTM is difficult

15

  • Attention allows models to improve scalably with compute. Transformers asymptotically outperform LSTMs due to improved use of context.

Scaling laws for neural language models. Kaplan, J., McCandlish, S., Henighan, T., Brown, T. B., Chess, B., Child, R., ... & Amodei, D. (2020)

(Figure callouts: “Better usage of context”, “Better model performance”)

17 of 138

Transformer architecture

16

  • Attention allows attending to past tokens without forgetting.
  • Big FFNs, highly scalable

Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).

18 of 138

Pretraining objective

17

  • Full autoregressive: GPT, Llama
  • Prefix autoregressive: T5
  • Masked: BERT, ELMO

Wang, Thomas, et al. "What language model architecture and pretraining objective works best for zero-shot generalization?.” (2022)

19 of 138

Pretraining objective

18

  • T5, BERT, GPT. All with different masking
  • Causal decoder (GPT) used most frequently

Wang, Thomas, et al. "What language model architecture and pretraining objective works best for zero-shot generalization?.” (2022)

20 of 138

Model architecture

19

  • Upstream Negative Log-Perplexity of vanilla Transformer compared to other models.
  • Transformer outperforms other models.

Tay, Y., Dehghani, M., Abnar, S., Chung, H. W., Fedus, W., Rao, J., ... & Metzler, D. Scaling laws vs model architectures: How does inductive bias influence scaling? (2022).

21 of 138

Model architecture

20

  • Downstream accuracy of vanilla Transformer compared to other models.
  • Transformer outperforms other models.

Tay, Y., Dehghani, M., Abnar, S., Chung, H. W., Fedus, W., Rao, J., ... & Metzler, D. Scaling laws vs model architectures: How does inductive bias influence scaling? (2022).

22 of 138

Compute cost

21

  • Parameter count is N, token count is D: what is the compute cost?
    • C = 6ND(1 + s/6d): compute increases with more data and a larger model
    • LLaMA (s << 6d): C ≈ 6ND = 6 × 7 billion × 2 trillion ≈ 8.4 × 10^22 FLOPs

  • Shall I allocate more compute to data or model?

"Scaling laws for neural language models.” (2020)

23 of 138

Optimal token and parameter

22

  • Empirical performance has a power-law relationship with each individual factor.
    • Train models of different sizes on different numbers of tokens (light blue lines)
    • Pick the minimal loss (black line)
    • Run linear regression in log-log space
  • Follows a power law: L(N) = (N_c / N)^{α_N}, L(D) = (D_c / D)^{α_D}
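
A sketch of the fitting step (assumed form L(x) = (x_c / x)^α as above; np.polyfit does the log-log linear regression):

    import numpy as np

    def fit_power_law(x, loss):
        """Fit loss ≈ (x0 / x)^alpha via linear regression in log-log space:
        log loss = -alpha * log x + alpha * log x0."""
        slope, intercept = np.polyfit(np.log(x), np.log(loss), 1)
        alpha = -slope
        x0 = np.exp(intercept / alpha)
        return alpha, x0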

"Scaling laws for neural language models.” (2020)

24 of 138

Allocate compute

23

  • Empirical performance: N_opt ∝ C^a, D_opt ∝ C^b
  • We have a + b = 1 because C = 6ND

Allocate more compute to parameters (a) or tokens (b)?

25 of 138

Allocate compute

24

  • Coefficients of model scaling and data scaling vary with training data distribution

"DeepSeek LLM Scaling Open-Source Language Models with Longtermism.” (2024)

  • OpenAI (2020) gives more compute to parameters
  • DeepMind (2022) gives more compute to tokens
  • Best practice: train small models at several sizes and data scales, and fit the constants yourself.

26 of 138

Chinchilla scaling

25

  • Chinchilla scaling law: a = 0.49, b=0.51
  • Implication: more compute efficient than other models

"Training Compute-Optimal Large Language Models." (2022).

27 of 138

Chinchilla scaling

26

  • Chinchilla outperforms Gopher significantly by allocating compute better

"Training Compute-Optimal Large Language Models." (2022).

28 of 138

Determine tokens

27

  • Estimated optimal number of training tokens: about 20 times the number of parameters
  • It’s best to double the estimate: tokens ≈ 40x parameters
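
A rough worked example combining C = 6ND with the ~20-tokens-per-parameter rule (illustrative arithmetic only, not from the slides):

    import math

    def chinchilla_optimal(compute_flops, tokens_per_param=20):
        """Solve C = 6*N*D with D = tokens_per_param * N for the compute-optimal
        parameter count N and token count D (back-of-envelope)."""
        n = math.sqrt(compute_flops / (6 * tokens_per_param))
        return n, tokens_per_param * n

    # e.g. a budget of ~5.8e23 FLOPs gives roughly 70B params and 1.4T tokens (Chinchilla-scale)
    n, d = chinchilla_optimal(5.8e23)
    print(f"N ≈ {n:.2e}, D ≈ {d:.2e}")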

"Training Compute-Optimal Large Language Models." (2022).

29 of 138

Inference optimal

28

  • Train a smaller model on more tokens to achieve cheaper, better inference
    • E.g. Llama 7B is trained on about 7 times the Chinchilla-optimal number of tokens

Touvron, Hugo, et al. "Llama: Open and efficient foundation language models." (2023).

30 of 138

Loss predicts performance

29

  • Loss determines downstream performance: large and small models reach the same downstream performance when they reach the same loss.

Touvron, Hugo, et al. "Llama: Open and efficient foundation language models." (2023).

31 of 138

Scaling law prediction

30

  • Scaling law can predict performance at larger scale
    • Predicting performance of larger models and larger datasets

"DeepSeek LLM Scaling Open-Source Language Models with Longtermism.” (2024)

32 of 138

Large context

31

  • Major bottleneck: complex, long sequences cannot even fit into Transformers. Long context is needed for agents, world modeling, codebases, and the hyperlinked web.

Genome

Agent

Codebase

Hyperlinked web

World

33 of 138

Blockwise parallel transformers

32

  • Reorganize the computation of attention and feedforward.
  • Exact attention and feedforward.

Liu, H., Abbeel, P. “Blockwise Parallel Transformer for Large Context Models”. NeurIPS 2023 Spotlight.

(Diagram: N layers; outer loop over query blocks, inner loop over key-value blocks)
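
A NumPy sketch of the blockwise (online-softmax) attention the diagram describes; single head, no causal mask, illustrative only (the actual BPT implementation also computes the FFN per query block and runs on accelerators):

    import numpy as np

    def blockwise_attention(q, k, v, block):
        """Exact attention computed block by block: outer loop over query blocks,
        inner loop over key-value blocks, keeping running softmax statistics so
        the full s x s score matrix is never materialized. q, k, v: (T, d)."""
        T, d = q.shape
        out = np.zeros_like(q)
        for qs in range(0, T, block):
            qb = q[qs:qs + block]
            m = np.full((qb.shape[0], 1), -np.inf)   # running row-wise max
            l = np.zeros((qb.shape[0], 1))           # running softmax normalizer
            acc = np.zeros((qb.shape[0], d))         # unnormalized weighted values
            for ks in range(0, T, block):
                kb, vb = k[ks:ks + block], v[ks:ks + block]
                s = qb @ kb.T / np.sqrt(d)
                m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
                p = np.exp(s - m_new)
                corr = np.exp(m - m_new)             # rescale old statistics
                acc = acc * corr + p @ vb
                l = l * corr + p.sum(axis=-1, keepdims=True)
                m = m_new
            out[qs:qs + block] = acc / l
        return out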

34 of 138

Analysis of memory cost

33

  • Notation: b = batch size, s = sequence length, c = block size, h = hidden dimension
  • Standard attention + standard FFN
    • Peak of attention: O(s^2)
    • Peak of FFN: 8bsh
  • Memory-efficient attention + standard FFN
    • Peak of attention: max(4bch, 2bsh) = 2bsh
    • Peak of FFN: 8bsh
  • BPT
    • Peak of attention: 2bsh
    • Peak of FFN: max(8bch, 2bsh) = 2bsh because s >> c

4x smaller peak activation memory (2bsh for BPT vs. 8bsh for the standard FFN; standard attention peaks at O(s^2))

Liu, H., Abbeel, P. “Blockwise Parallel Transformer for Large Context Models”. NeurIPS 2023 Spotlight.

35 of 138

Evaluation

34

  • Four times memory saving thanks to blockwise computation
  • Faster speed thanks to fusion opportunities of attention and FFN

Liu, H., Abbeel, P. “Blockwise Parallel Transformer for Large Context Models”. NeurIPS 2023 Spotlight.

Four times longer context than FlashAttention

36 of 138

Generally applicable

35

Blockwise Transformers allow 16x memory saving without overhead

Gemma: Open Models Based on Gemini Research and Technology

https://blog.google/technology/developers/gemma-open-models/

16x expanded MLP hidden dimension

Liu, H., Abbeel, P. “Blockwise Parallel Transformer for Large Context Models”. NeurIPS 2023 Spotlight.

37 of 138

Still cannot do million-length sequence

36

  • Memory cost (2bsh) scales linearly with sequence length s.
  • Chip memory cannot scale arbitrarily; we are already pushing against physical limits.
  • Naively using multiple GPUs doesn’t help because attention requires pairwise interactions across the whole sequence.

38 of 138

Extension to RingAttention

37

  • Spread the Blockwise Transformers computation graph across a ring of devices.
  • Circulate key-value blocks around the ring while computing attention and FFN for the local query block

(Diagram: the query loop is distributed across devices; the key-value loop overlaps communication with computation)

Liu, H., Abbeel, P. “Blockwise Parallel Transformer for Large Context Models”. NeurIPS 2023 Spotlight.

Liu, H., Zaharia, M., Abbeel, P. “Ring Attention with Blockwise Transformers for Near-Infinite Context”. ICLR 2024.
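
A simulated-ring sketch of the idea (each list entry plays the role of one device's shard; a real implementation overlaps the key-value rotation with compute via collective permutes instead of this serial Python loop):

    import numpy as np

    def ring_attention_sim(q_shards, k_shards, v_shards):
        """Each 'device' holds one query/key/value shard. Key-value shards rotate
        around the ring; every device accumulates online-softmax statistics for
        its local query block, so exact attention over the full sequence emerges."""
        n_dev, d = len(q_shards), q_shards[0].shape[-1]
        outs = []
        for dev in range(n_dev):
            q = q_shards[dev]
            m = np.full((q.shape[0], 1), -np.inf)
            l = np.zeros((q.shape[0], 1))
            acc = np.zeros_like(q)
            for step in range(n_dev):            # one full trip around the ring
                src = (dev + step) % n_dev       # which key-value shard is local now
                k, v = k_shards[src], v_shards[src]
                s = q @ k.T / np.sqrt(d)
                m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
                p = np.exp(s - m_new)
                corr = np.exp(m - m_new)
                acc = acc * corr + p @ v
                l = l * corr + p.sum(axis=-1, keepdims=True)
                m = m_new
            outs.append(acc / l)
        return outs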

39 of 138

Analysis of arithmetic intensity

38

  • Requires a block size large enough that compute_time(blockwise attention + blockwise FFN) >= communication_time(one key-value block)
  • Memory per device: 2x for key and value, and 3x for buffers (sending the next block, receiving the previous block, computing on the current block)

Liu, H., Abbeel, P. “Blockwise Parallel Transformer for Large Context Models”. NeurIPS 2023 Spotlight.

Liu, H., Zaharia, M., Abbeel, P. “Ring Attention with Blockwise Transformers for Near-Infinite Context”. ICLR 2024.
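
A back-of-envelope check of this overlap condition (simplified and assumed: it counts only the two attention matmuls per block and one key-value block of bf16 traffic; the blockwise FFN adds more compute, which only makes overlap easier):

    def ring_overlap_ok(c, d, flops_per_s, ici_bytes_per_s, dtype_bytes=2):
        """True when per-step compute time covers the time to send and receive
        one (key, value) block of shape (c, d)."""
        compute_flops = 2 * 2 * c * c * d            # QK^T and PV matmuls on a c x c block
        comm_bytes = 2 * 2 * c * d * dtype_bytes     # send + receive the K and V blocks
        return compute_flops / flops_per_s >= comm_bytes / ici_bytes_per_s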

40 of 138

Evaluation of max sequence length

39

  • RingAttention matches the theoretical maximum sequence length across model and hardware configurations.

Liu, H., Abbeel, P. “Blockwise Parallel Transformer for Large Context Models”. NeurIPS 2023 Spotlight.

Liu, H., Zaharia, M., Abbeel, P. “Ring Attention with Blockwise Transformers for Near-Infinite Context”. ICLR 2024.

41 of 138

Evaluation of max sequence length

40

  • Blockwise Transformer + RingAttention allows arbitrarily large context

Liu, H., Abbeel, P. “Blockwise Parallel Transformer for Large Context Models”. NeurIPS 2023 Spotlight.

Liu, H., Zaharia, M., Abbeel, P. “Ring Attention with Blockwise Transformers for Near-Infinite Context”. ICLR 2024.

512 times longer context than Blockwise Transformers; 2048 times longer than FlashAttention

42 of 138

Large World Model

41

  • Modeling million-length text and video

“World Model on Million-Length Video and Language with RingAttention”. (2024) largeworldmode.github.io

43 of 138

1M effective context

42

  • LWM gets near perfect accuracy on the popular “needle in a haystack” task.

LWM achieves highly effective context over 1M tokens. No “lost in the middle” observed.

“World Model on Million-Length Video and Language with RingAttention”. (2024) largeworldmode.github.io

44 of 138

1M effective context

43

  • Near perfect accuracy on the popular “needle in a haystack” task.
    • Outperforms Gemini 1.0 Pro (max 32K)
    • Outperforms GPT-4 (max 128K)

“World Model on Million-Length Video and Language with RingAttention”. (2024) largeworldmode.github.io

45 of 138

Large World Model: Video Generation

44

  • Large world model can do image / video / text understanding and generation

“World Model on Million-Length Video and Language with RingAttention”. (2024) largeworldmode.github.io

46 of 138

Large World Model: Video Generation

45

  • Large world model can do image / video / text understanding and generation

“World Model on Million-Length Video and Language with RingAttention”. (2024) largeworldmode.github.io

47 of 138

Large World Model: Hour-Long Video Chat

46

  • Large world model can do any-to-any image / video / text understanding and generation

“World Model on Million-Length Video and Language with RingAttention”. (2024) largeworldmode.github.io

48 of 138

Scaling of long context

47

Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context. 2024

  • Loss goes down with longer context

49 of 138

Understand code repo

48

  • Large context allows understanding code repositories.

Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context. 2024

50 of 138

Large context applications

49

  • Large context allows understanding complex documents

Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context. 2024

51 of 138

Large context applications

50

  • Large context allows understanding complex documents

Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context. 2024

52 of 138

Large context applications

51

  • Gemini 1.5 outperforms Whisper + GPT4

Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context. 2024

53 of 138

Large context applications

52

  • Translates the low-resource Kalamang language based on a grammar book

Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context. 2024

54 of 138

Reduce data movement

53

  • FlashAttention. Idea: minimize data movement between HBM and SRAM
  • Implementation of memory-efficient attention in CUDA

"Flashattention: Fast and memory-efficient exact attention with io-awareness.” (2022)

55 of 138

Optimization

54

  • The query loop is parallelizable, so it becomes the outer loop.
  • The key-value loop becomes the inner loop.

"Flashattention-2: Faster attention with better parallelism and work partitioning.” (2023)

56 of 138

Higher throughput

55

  • FlashAttention outperforms standard attention computation
  • FlashAttention-2 further outperforms FlashAttention significantly

"Flashattention: Fast and memory-efficient exact attention with io-awareness.” (2022)

"Flashattention-2: Faster attention with better parallelism and work partitioning.” (2023)

57 of 138

Tool use and retrieval

56

  • Up-to-date information
    • The LM needs access to search to answer questions about news
  • More factual knowledge
    • Some questions require factual knowledge that may not be encoded in the weights
  • External tools
    • Access to a calculator, calendar, and other tools can make the LM more capable

58 of 138

Tool use and retrieval

57

  • Improving language models by retrieving relevant tokens from a large text database (RETRO)

Borgeaud, Sebastian, et al. "Improving language models by retrieving from trillions of tokens." PMLR, 2022.

59 of 138

Tool use and retrieval

58

  • Format of the retrieval neighbors: [N, F], where N is used as the key and F is the continuation of N.
  • Metric: d(C, N) = ||BERT(C) - BERT(N)||. Retrieved set: RET(C) = ([N^1, F^1], …, [N^k, F^k]).

Borgeaud, Sebastian, et al. "Improving language models by retrieving from trillions of tokens." PMLR, 2022.

60 of 138

Tool use and retrieval

59

  • Input chunks: Divide input of length 2048 into chunks of length 64. N, F in the retrieval database are also of length 64.

Borgeaud, Sebastian, et al. "Improving language models by retrieving from trillions of tokens." PMLR, 2022.
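
A sketch of the retrieval step under these definitions (assumed interface: chunk embeddings and database key embeddings come from a frozen BERT; names are illustrative):

    import numpy as np

    def retro_neighbors(chunk_embs, db_key_embs, db_entries, k=2):
        """For each 64-token input chunk embedding, return the k nearest database
        entries [N, F] by L2 distance between embeddings, i.e. RET(C)."""
        results = []
        for c in chunk_embs:                                  # one embedded chunk C
            dist = np.linalg.norm(db_key_embs - c, axis=-1)   # d(C, N) for every key N
            nearest = np.argsort(dist)[:k]
            results.append([db_entries[i] for i in nearest])
        return results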

61 of 138

Tool use and retrieval

60

  • RETRO leads to better performance

Borgeaud, Sebastian, et al. "Improving language models by retrieving from trillions of tokens." PMLR, 2022.

62 of 138

Tool use and retrieval

61

  • Biggest gains come from GitHub and PG19 (books)

63 of 138

Alternative: sliding window attention

62

Child, Rewon, et al. "Generating long sequences with sparse transformers." (2019).

Jiang, A. Q., Sablayrolles, A., Mensch, A., Bamford, C., Chaplot, D. S., Casas, D. D. L., ... & Sayed, W. E. (2023). Mistral 7B. (2023)

Beltagy, I., Peters, M. E., & Cohan, A. Longformer: The long-document transformer. (2020)

  • Attend to a fixed window of past tokens
    • The effective receptive field expands to further tokens in deeper layers
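
A sketch of the corresponding attention mask (hypothetical helper; window is the fixed number of past tokens each position may attend to):

    import numpy as np

    def sliding_window_mask(seq_len, window):
        """Causal sliding-window mask: position i may attend to positions j with
        i - window < j <= i. Stacking layers grows the effective receptive field."""
        i = np.arange(seq_len)[:, None]
        j = np.arange(seq_len)[None, :]
        return (j <= i) & (j > i - window)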

64 of 138

Alternative: state space model

63

Gu et al. “Efficiently Modeling Long Sequences with Structured State Spaces”. (2022)

  • Recurrent architecture
    • Promising idea with faster inference than transformers and faster training than RNNs
    • Unclear how it scales compared with transformers

65 of 138

Alternative: state space model

64

Gu, A., & Dao, T. “Mamba: Linear-time sequence modeling with selective state spaces”. (2023)

  • Perplexity vs. FLOPs comparison between SSMs and Transformers

66 of 138

Alternative: attention + SSM

65

“Simple linear attention language models balance the recall-throughput tradeoff” (2024)

  • Combine compression and attention
    • Sliding window attention to model long range dependency
    • State space model for inference efficiency

67 of 138

Based

66

  • Fixed hidden size SSM underperforms Transformer
  • SSM plus attention can match Transformer on some tasks

“Simple linear attention language models balance the recall-throughput tradeoff” (2024)

68 of 138

Griffin

67

"Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models." (2024).

  • Fixed hidden size SSM underperforms Transformer
  • SSM plus attention can match Transformer on some tasks

69 of 138

Griffin

68

"Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models." (2024).

  • Pure SSM underperforms Transformer
  • Adding attention alleviates the performance gap

70 of 138

Griffin

69

"Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models." (2024).

  • Faster inference thanks to a small KV cache that does not grow linearly with sequence length

71 of 138

KV cache

70

  • Input prompt = “what components does transformer have?”
  • Autoregressive output:
    • “It has FFN and attention”
    • Each generated token (query) only needs to attend to the cached keys and values of the prompt and previously generated tokens
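
A single-head decoding-step sketch showing what the cache buys us (illustrative names; only the new token's key and value are computed, everything else is reused):

    import numpy as np

    def decode_step(x_new, W_q, W_k, W_v, cache):
        """One autoregressive step with a KV cache. x_new: (1, d_model);
        cache holds the keys/values of all previous positions."""
        q = x_new @ W_q                                          # only the new query
        cache["k"] = np.concatenate([cache["k"], x_new @ W_k])   # append the new key
        cache["v"] = np.concatenate([cache["v"], x_new @ W_v])   # append the new value
        s = q @ cache["k"].T / np.sqrt(q.shape[-1])
        p = np.exp(s - s.max())
        p /= p.sum()
        return p @ cache["v"], cache                             # attention output, updated cache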

72 of 138

KV cache

71

  • Reduce compute cost

73 of 138

KV cache

72

  • Reduce compute cost

74 of 138

Compute and memory bandwidth

73

  • The gap between compute and memory bandwidth is increasing
  • Inference is dominated by loading KV cache

75 of 138

KV cache compression

74

  • Key-value cache dominates memory cost

Hooper, Coleman, et al. "KVQuant: Towards 10 Million Context Length LLM Inference with KV Cache Quantization.” (2024)

76 of 138

KV cache compression

75

Hooper, Coleman, et al. "KVQuant: Towards 10 Million Context Length LLM Inference with KV Cache Quantization.” (2024)

  • Compressing KV cache without performance degradation

77 of 138

Multi-query attention

76

  • In MQA, all query heads share a single key-value head.

78 of 138

Group query attention

77

  • In GQA, each group of query heads shares one key-value head
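
A sketch contrasting the variants via the number of key-value heads (no causal mask, illustrative only): one KV head per query head is MHA, one KV head per group is GQA, and a single KV head is MQA.

    import numpy as np

    def grouped_query_attention(q, k, v):
        """q: (H, T, d) query heads; k, v: (G, T, d) key-value heads with G dividing H.
        G = H gives MHA, 1 < G < H gives GQA, G = 1 gives MQA."""
        H, T, d = q.shape
        G = k.shape[0]
        per_group = H // G
        out = np.empty_like(q)
        for h in range(H):
            g = h // per_group                     # which KV head this query head uses
            s = q[h] @ k[g].T / np.sqrt(d)
            s -= s.max(axis=-1, keepdims=True)
            p = np.exp(s)
            p /= p.sum(axis=-1, keepdims=True)
            out[h] = p @ v[g]
        return out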

79 of 138

Trade off

78

  • GQA trades off between inference speed and model quality

80 of 138

Group query attention

79

  • Continue training (uptraining) to convert MHA to GQA or MQA

81 of 138

DRAM stacking on GPU

80

  • 3D-stacks HBM memory directly on top of the processing cores
  • Both compute and memory bandwidth scale with GPU die area

82 of 138

Mixture of experts

81

Shazeer, Noam, et al. "Outrageously large neural networks: The sparsely-gated mixture-of-experts layer." (2017).

  • Select a subset of FFNs to execute.
    • Decouple compute from parameters.
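
A top-k routing sketch of this idea (illustrative; real MoE layers add load-balancing losses and capacity limits, and Mixtral routes each token to the top 2 of 8 experts):

    import numpy as np

    def moe_ffn(x, experts, W_gate, top_k=2):
        """Sparse MoE layer: route each token to its top_k experts and mix their
        outputs with renormalized gate weights. x: (T, d); experts: list of
        callables mapping (d,) -> (d,); W_gate: (d, n_experts)."""
        logits = x @ W_gate
        out = np.zeros_like(x)
        for t in range(x.shape[0]):
            chosen = np.argsort(logits[t])[-top_k:]           # indices of the top_k experts
            w = np.exp(logits[t, chosen] - logits[t, chosen].max())
            w /= w.sum()
            for weight, e in zip(w, chosen):
                out[t] += weight * experts[e](x[t])           # only top_k experts run
        return out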

83 of 138

Mixtral 8x7B

82

  • Better performance

Mistral AI. "Mixtral of experts." (2024).

84 of 138

Mixtral 8x7B

83

Mistral AI. "Mixtral of experts." (2024).

  • Same active params, better performance

85 of 138

Mixtral 8x7B

84

  • Expert selection may not be interpretable

Mistral AI. "Mixtral of experts." (2024).

86 of 138

Reasoning

85

Cobbe, K., Kosaraju, V., Bavarian, M., Chen, M., Jun, H., Kaiser, L., ... & Schulman, J. Training verifiers to solve math word problems. (2021)

  • Simple arithmetic problems

87 of 138

Finetuning requires lots of data

86

Cobbe, K., Kosaraju, V., Bavarian, M., Chen, M., Jun, H., Kaiser, L., ... & Schulman, J. Training verifiers to solve math word problems. (2021)

  • Not scalable to rely on humans to curate reasoning data
  • To achieve > 80% accuracy, a 175B model would need about 100 times more fine-tuning data

88 of 138

Scratchpad

87

Nye, Maxwell, et al. "Show your work: Scratchpads for intermediate computation with language models.” (2021)

  • Scratchpad for finetuning on step by step reasoning

89 of 138

Scratchpad

88

Nye, Maxwell, et al. "Show your work: Scratchpads for intermediate computation with language models.” (2021)

  • Scratchpad for finetuning on step by step reasoning
    • Improved reasoning

90 of 138

CoT

89

Wei, Jason, et al. "Chain-of-thought prompting elicits reasoning in large language models." (2022)

  • Chain-of-thought prompting

91 of 138

CoT

90

Wei, Jason, et al. "Chain-of-thought prompting elicits reasoning in large language models." (2022)

  • Chain-of-thought prompting
    • Improved reasoning

92 of 138

Zero-shot CoT

91

Kojima, T., Gu, S. S., Reid, M., Matsuo, Y., & Iwasawa, Y. Large language models are zero-shot reasoners. (2022).

  • Chain-of-thought prompting without manual examples

93 of 138

Zero-shot CoT

92

Kojima, T., Gu, S. S., Reid, M., Matsuo, Y., & Iwasawa, Y. Large language models are zero-shot reasoners. (2022).

  • Chain-of-thought prompting without manual examples

94 of 138

Zero-shot CoT

93

Kojima, T., Gu, S. S., Reid, M., Matsuo, Y., & Iwasawa, Y. Large language models are zero-shot reasoners. (2022).

  • Chain-of-thought prompting without manual examples

95 of 138

Zero-shot CoT

94

Kojima, T., Gu, S. S., Reid, M., Matsuo, Y., & Iwasawa, Y. Large language models are zero-shot reasoners. (2022).

  • Chain-of-thought prompting without manual examples

96 of 138

Process feedback

95

Lightman, H., Kosaraju, V., Burda, Y., Edwards, H., Baker, B., Lee, T., ... & Cobbe, K. Let's Verify Step by Step. (2023).

  • Let’s verify step by step

97 of 138

Process feedback

96

Lightman, H., Kosaraju, V., Burda, Y., Edwards, H., Baker, B., Lee, T., ... & Cobbe, K. Let's Verify Step by Step. (2023).

  • Process supervision does not saturate early
  • However, process supervision is very expensive

98 of 138

RLHF

97

  • Collect demonstrations for supervised finetuning
  • Train a reward model for reinforcement learning

Ouyang, Long, et al. "Training language models to follow instructions with human feedback." (2022)

99 of 138

RLHF

98

  • RLHF outperforms SFT and scales well to larger model sizes

Ouyang, Long, et al. "Training language models to follow instructions with human feedback." (2022)

100 of 138

Code generation loss scaling

99

  • Scaling works for code synthesis too
  • Test loss shows power-law w.r.t model parameters

Chen, Mark, et al. "Evaluating large language models trained on code." (2021).

101 of 138

Code generation accuracy scaling

100

  • Pass@1 and pass@100 both increase with larger model

Chen, Mark, et al. "Evaluating large language models trained on code." (2021).

102 of 138

Large context code loss

101

The Claude 3 Model Family: Opus, Sonnet, Haiku. 2024

  • Power-law on the context size axis too
  • Loss on code is generally lower than loss on text

103 of 138

AlphaCode

102

  • Inference optimization by large scale sampling and filtering

“Competition-Level Code Generation with AlphaCode” (2022)

“AlphaCode 2 Technical Report” (2023)

104 of 138

AlphaCode

103

  • Inference optimization by large scale sampling and filtering

“Competition-Level Code Generation with AlphaCode” (2022)

“AlphaCode 2 Technical Report” (2023)

105 of 138

AlphaCode

104

  • Inference optimization by large scale sampling and filtering

“Competition-Level Code Generation with AlphaCode” (2022)

“AlphaCode 2 Technical Report” (2023)

106 of 138

AlphaCode2

105

  • AlphaCode 2 replaces the base model with Gemini Pro
  • A good pretrained model plus inference-time sampling

“Competition-Level Code Generation with AlphaCode” (2022)

“AlphaCode 2 Technical Report” (2023)

107 of 138

Pretraining data

106

  • Most data is from Common Crawl

Gao, L., Biderman, S., Black, S., Golding, L., Hoppe, T., Foster, C., ... & Leahy, C. The pile: An 800gb dataset of diverse text for language modeling.  (2020).

108 of 138

Filtering data

107

  • Llama’s dataset is calibrated toward Wikipedia
  • Outperforms GPT-3 and other models

“LLaMA: Open and Efficient Foundation Language Models” (2023).

109 of 138

RedPajama

108

  • RedPajama: an open-source dataset

110 of 138

OpenLLaMA

109

  • OpenLLaMA matches LLaMA performance

“OpenLLaMA: An Open Reproduction of LLaMA” (2024)

111 of 138

TPU / GPU

110

  • FLOPs, HBM size and bandwidth, interconnect

112 of 138

GEMM

111

  • High computational throughput on neural network calculations
  • GPUs have many smaller compute units; TPUs have a few large compute units

TPU, systolic array 8x128x128

GPU, many 8x4x8 ALU

113 of 138

Matmul sharding

112

  • Simple case: sharded matmul Y = XA
  • Row sharding on A, column sharding on X

Shoeybi M, Patwary M, Puri R, LeGresley P, Casper J, Catanzaro B. Megatron-lm: Training multi-billion parameter language models using model parallelism. (2019)
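
A NumPy simulation of this sharding (the `sum` stands in for the all-reduce across devices; illustrative only, and it assumes the contraction dimension divides evenly):

    import numpy as np

    def sharded_matmul(X, A, n_shards):
        """Y = X A with X split column-wise and A split row-wise along the
        contraction dimension: each shard produces a full-shape partial product,
        and the partials are summed (the all-reduce)."""
        X_shards = np.split(X, n_shards, axis=1)
        A_shards = np.split(A, n_shards, axis=0)
        return sum(x_i @ a_i for x_i, a_i in zip(X_shards, A_shards))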

114 of 138

Matmul sharding

113

  • Row sharding on A, column sharding on X

115 of 138

Matmul sharding

114

  • Row sharding on A, column sharding on X

116 of 138

Matmul sharding

115

  • Row sharding on A, column sharding on X

117 of 138

Matmul sharding

116

  • Row sharding on A, column sharding on X

118 of 138

Matmul sharding

117

  • Column sharding on A
  • Replicate X

119 of 138

Matmul sharding

118

  • Column sharding on A
  • Replicate X

120 of 138

Matmul sharding

119

121 of 138

MLP sharding

120

  • Y = GELU(XA)B
  • Approach 1: split X column-wise and A row-wise
    • GeLU of sums != sum of GeLUs
    • Requires synchronization before GeLU
  • Approach 2: split A column-wise
    • No sharding on contraction dimension of X

    • No synchronization necessary
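
A NumPy simulation of Approach 2 combined with a row-sharded second matmul, i.e. the Megatron-style MLP: GeLU is applied per shard with no synchronization, and the single all-reduce (the `sum` below) happens after the second matmul. Names and shapes are illustrative.

    import numpy as np

    def megatron_mlp(X, A, B, n_shards):
        """Y = GeLU(X A) B with A split column-wise and B split row-wise."""
        gelu = lambda z: 0.5 * z * (1 + np.tanh(np.sqrt(2 / np.pi) * (z + 0.044715 * z**3)))
        A_cols = np.split(A, n_shards, axis=1)     # each device holds some columns of A
        B_rows = np.split(B, n_shards, axis=0)     # ... and the matching rows of B
        partials = [gelu(X @ A_i) @ B_i for A_i, B_i in zip(A_cols, B_rows)]
        return sum(partials)                       # the all-reduce (operator g)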

122 of 138

MLP sharding

121

  • f and g are conjugate: f is the identity operator in the forward pass and all-reduce in the backward pass, while g is all-reduce in the forward pass and the identity in the backward pass.

123 of 138

Attention sharding

122

  • Attention contains a nonlinearity (softmax), so the QKV projections are column-sharded (sharding across attention heads).

124 of 138

Attention sharding

123

  • f and g are conjugate: f is the identity operator in the forward pass and all-reduce in the backward pass, while g is all-reduce in the forward pass and the identity in the backward pass.

125 of 138

Attention sharding

124

  • f and g are conjugate: f is the identity operator in the forward pass and all-reduce in the backward pass, while g is all-reduce in the forward pass and the identity in the backward pass.

126 of 138

Data parallelism

125

  • Activations need to be sharded: they’re much bigger than the weights. We can shard them either by rows or by columns.
  • Data parallelism keeps working with a linear speedup as long as you scale the batch size linearly with the number of TPUs

Rafi W. “Sharding Techniques Single Slice Sharding For Dense LLMs”

127 of 138

Data parallelism

126

  • At 4 bytes per parameter we need 700GB for 175B-parameter GPT-3
  • But an A100 has only 80GB of HBM.

Rafi W. “Sharding Techniques Single Slice Sharding For Dense LLMs”

128 of 138

Fully sharded data parallelism

127

  • All-gather the next layer’s weights while doing the arithmetic for the current layer
  • Requires storing only one layer’s full weights at a time (next slide)

Rafi W. “Sharding Techniques Single Slice Sharding For Dense LLMs”

129 of 138

Fully sharded data parallelism

128

  • All-gather the next layer’s weights while doing the arithmetic for the current layer
  • Requires storing only one layer’s full weights at a time

Rafi W. “Sharding Techniques Single Slice Sharding For Dense LLMs”

130 of 138

Fully sharded data parallelism

129

  • The matrix multiply takes: FLOPs = 2 * BATCH_PER_CHIP * E^2
  • The all-gather requires reading ~2 * E^2 bytes (assuming bfloat16)
  • Arithmetic intensity (FLOPs/byte) = 2 * BATCH_PER_CHIP * E^2 / (2 * E^2 bytes) = BATCH_PER_CHIP FLOPs/byte
  • On TPU, ICI arithmetic intensity is ~1018 FLOPs/byte, so BATCH_PER_CHIP needs to be comfortably larger than 1018.

Rafi W. “Sharding Techniques Single Slice Sharding For Dense LLMs”

131 of 138

Tensor parallelism

130

  • Shard the activation dimension

Rafi W. “Sharding Techniques Single Slice Sharding For Dense LLMs”

132 of 138

Tensor parallelism

131

  • Shard the activation dimension

Rafi W. “Sharding Techniques Single Slice Sharding For Dense LLMs”

133 of 138

Tensor parallelism

132

  • The matrix multiply takes (per chip): FLOPs = 2 * BATCH * E^2 / NUM_TP_SHARDS
  • The reduce-scatter requires ~2 * BATCH * E bytes (assuming bfloat16)
  • Arithmetic intensity (FLOPs/byte) = E / NUM_TP_SHARDS
  • TPU ICI arithmetic intensity is ~1018 FLOPs/byte, so E needs to be comfortably larger than 1018 * NUM_TP_SHARDS.

Rafi W. “Sharding Techniques Single Slice Sharding For Dense LLMs”

134 of 138

FSDP with TP

133

  • While processing layer i, all-gather the weights for layer i+1 (FSDP)
  • Within each layer, it is then just TP

135 of 138

FSDP with TP

134

  • While processing layer i, all-gather the weights for layer i+1 (FSDP)
  • Within each layer, it is then just TP

136 of 138

FSDP with TP

135

  • Given NUM_TP_SHARDS and NUM_FSDP_SHARDS
  • Works when both:
    • E > 2 * 1018 * NUM_TP_SHARDS
    • NUM_TP_SHARDS * BATCH_PER_CHIP > 2 * 1018
  • 256 TPUs: 16-way TP and 16-way FSDP
    • E > ~32k
    • BATCH_PER_CHIP > 125
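
These two conditions can be checked back-of-envelope; a sketch that just encodes the inequalities above (1018 FLOPs/byte is the ICI arithmetic intensity quoted on the earlier slides):

    def fsdp_tp_ok(E, batch_per_chip, num_tp_shards, ici_intensity=1018):
        """True when both the TP condition (E large enough) and the FSDP condition
        (batch per TP group large enough) hold."""
        tp_ok = E > 2 * ici_intensity * num_tp_shards
        fsdp_ok = num_tp_shards * batch_per_chip > 2 * ici_intensity
        return tp_ok and fsdp_ok

    # 256 TPUs split as 16-way TP x 16-way FSDP: needs E > ~32k and BATCH_PER_CHIP > ~125
    print(fsdp_tp_ok(E=32768, batch_per_chip=128, num_tp_shards=16))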

137 of 138

Practical guideline

136

  • Scaling model size and data size
    • Use FSDP as much as possible
    • Use TP/DP/SP to reduce batch size
  • Scaling context size
    • Use Blockwise Transformers to reduce memory cost
    • Use RingAttention to increase context arbitrarily

138 of 138

Open problems

137

  • Large language model → large world model
    • Predict the world realistically, including text, vision, and action.
    • Video conversation
  • Consistent reasoning
    • Models have a million-length “tape” but stop step-by-step thinking after about a hundred steps.
  • Training and inference efficiency
    • Exponentially more efficient architectures or training paradigms