1 of 19

Why won’t Llama 13B fit on my 4090?

Mark Saroufim

2 of 19

Sizes on disk

Model sizes

  • Llama 7B: 13GB
  • Llama 13B: 25GB
  • Llama 70B: 130GB

GPU VRAM

  • 3090/4090: 24GB
  • A6000: 48GB
  • A10G: 24GB
  • A100: 40GB or 80GB

3 of 19

Back of the envelope

Mark has a 4090 with ~23GB of usable VRAM

Llama 7B works

Does Llama 13B load for inference only?

  • nvidia-smi says ~700MB is taken by the display drivers
  • Load in 4-bit, so 25/4 ≈ 6.25GB should fit (see the sketch below)!
  • OOM
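A sketch of that arithmetic, in Python (parameter counts are approximate):

    # Naive estimate: weight memory only, ignoring everything else
    def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
        return num_params * bits_per_param / 8 / 1e9

    for bits in (16, 8, 4):
        print(f"Llama 13B @ {bits}-bit: {weight_memory_gb(13e9, bits):.1f} GB")
    # 26.0, 13.0, 6.5 GB: the 4-bit number "should" fit in 23GB, and yet it OOMs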

Ok what’s going on?

4 of 19

Do I need more GPUs? How many? No idea

Most outlets in the US can only handle ~1.8kW, so running 4 GPUs in an apartment is already tricky

5 of 19

360 no-scoping a GitHub issue

6 of 19

Goal is to understand this picture and more!

https://arxiv.org/pdf/1904.10631.pdf

7 of 19

4 sources of memory

Inference and Training

  • Model memory: model weights
  • Activation memory: intermediate values during forward pass

Training only

  • Gradient memory: gradients for each weight
  • Optimizer memory: additional state required by optimizer

8 of 19

Model memory
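Model memory is just parameter count times bytes per parameter; a minimal sketch assuming fp16 weights and approximate parameter counts, which roughly reproduces the sizes on disk from slide 2:

    # Model memory = number of parameters * bytes per parameter
    BYTES_FP16 = 2
    for name, params in [("Llama 7B", 7e9), ("Llama 13B", 13e9), ("Llama 70B", 70e9)]:
        print(f"{name}: ~{params * BYTES_FP16 / 1e9:.0f} GB in fp16")
    # ~14, ~26, ~140 GB: close to the 13/25/130 GB checkpoint sizes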

9 of 19

Gradient memory

With full finetuning, gradient memory = model memory (one gradient per weight, in the same dtype)

10 of 19

Optimizer Memory

  • SGD: no extra memory
  • SGD with momentum: extra memory = model memory (one momentum buffer per weight)
  • Adam: extra memory = 2 * model memory (first and second moments per weight)

For the GPU poor: use SGD for finetuning (or don’t; not sure why this isn’t more popular)

For the GPU rich: you can shard the optimizer state, gradients and parameters across devices
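A rough sketch of the full-finetuning bill using the ratios above (fp16 weights and gradients assumed; activations are left out since they depend on batch size and seqlen):

    # Model + gradient + optimizer memory for full finetuning, in GB (rough)
    def finetune_memory_gb(num_params: float, optimizer: str = "adam") -> float:
        model = num_params * 2 / 1e9   # fp16 weights
        grads = model                  # one gradient per weight, same dtype
        extra = {"sgd": 0, "sgd_momentum": 1, "adam": 2}[optimizer]  # ratios from the bullets above
        return model + grads + extra * model

    for opt in ("sgd", "sgd_momentum", "adam"):
        print(f"Llama 13B full finetune with {opt}: ~{finetune_memory_gb(13e9, opt):.0f} GB + activations")
    # sgd ~52 GB, sgd_momentum ~78 GB, adam ~104 GB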

11 of 19

Activation memory

Intermediate tensors from forward calls

For large input sequence length and large batch sizes this is the bottleneck

Some ways to resolve this, trading speed for memory:

  • Activation checkpointing: don’t save all intermediate activations, recompute them during the backward pass
  • Microbatching (gradient accumulation): instead of bs=10, do 2 microbatches of bs=5 and update the weights once (see the sketch below)
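A minimal PyTorch sketch of both tricks on a toy model (not the demo's exact setup; use_reentrant=False needs a reasonably recent PyTorch):

    import torch
    from torch.utils.checkpoint import checkpoint

    model = torch.nn.Sequential(*[torch.nn.Linear(4096, 4096) for _ in range(4)])
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)
    data = torch.randn(10, 4096)  # "bs=10"

    opt.zero_grad()
    # Microbatching / gradient accumulation: 2 microbatches of bs=5, one optimizer step
    for microbatch in data.chunk(2):
        # Activation checkpointing: don't keep intermediate activations,
        # recompute them during the backward pass instead
        out = checkpoint(model, microbatch, use_reentrant=False)
        loss = out.pow(2).mean() / 2   # toy loss, averaged over the 2 microbatches
        loss.backward()                # gradients accumulate in .grad
    opt.step()                         # a single weight update, as if bs=10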

12 of 19

Demo time

13 of 19

KV cache (That’s a lot of memory)

A speed vs memory tradeoff: cache K and V so they aren’t recomputed for every new token

The biggest factors are seqlen (e.g. 512 or more) and d_model (e.g. 4096)

Can be mitigated by MQA or GQA (self-plug: see my other tutorial), which reduce the number of K/V heads at some cost in accuracy
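A sketch of the KV cache arithmetic, assuming Llama-7B-ish shapes (32 layers, 32 KV heads, head_dim 128, fp16):

    # K and V tensors per layer, each of shape [batch, n_kv_heads, seqlen, head_dim]
    def kv_cache_gb(batch, seqlen, n_layers=32, n_kv_heads=32, head_dim=128, bytes_per_elem=2):
        return 2 * n_layers * n_kv_heads * head_dim * seqlen * batch * bytes_per_elem / 1e9

    print(kv_cache_gb(batch=1, seqlen=512))    # ~0.27 GB
    print(kv_cache_gb(batch=32, seqlen=2048))  # ~34 GB, "that's a lot of memory"
    # MQA/GQA shrink n_kv_heads (e.g. 32 -> 8), cutting the cache 4x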

14 of 19

So why does this OOM? load_in_4bit=True

Weight-only quantization

  • Saves disk space
  • Saves memory bandwidth, especially important at bs=1
  • Computation still happens in fp16, so it doesn’t save activation memory (and the KV cache stays in fp16 too)
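For reference, the kind of call the slide title refers to; a sketch with transformers + bitsandbytes (the model id and dtype choices are illustrative):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Weight-only 4-bit (NF4): weights are stored in 4 bits, but matmuls still
    # run in fp16, so activations and the KV cache don't shrink
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-13b-hf",   # illustrative model id
        quantization_config=quant_config,
        device_map="auto",
    )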

Static quantization, e.g. SmoothQuant

  • Quantizes activations too

Model quantization

  • Needs both hardware and framework support

Gradient quantization

  • Float8

Optimizer quantization

  • Research

15 of 19

What about QLoRA?

LoRA slightly increases model memory because it adds adapters

QLoRA decreases model memory because the base model is stored in NF4 while the adapters stay in fp32

Activation and optimizer memory also decrease because training only covers the small adapter subset and the base model is quantized
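A minimal QLoRA-style sketch with peft (model id and adapter hyperparameters are illustrative, not the exact QLoRA recipe):

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model

    # Base model weights stored in NF4...
    base = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-13b-hf",   # illustrative model id
        quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4"),
        device_map="auto",
    )
    # ...while only the small adapters are trainable
    lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
    model = get_peft_model(base, lora_config)
    model.print_trainable_parameters()  # gradient and optimizer state only cover this tiny subset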

16 of 19

Sparsity

Reduces model, optimizer and activation memory

How much VRAM exactly?

17 of 19

Sparsity back of the envelope

In CSR format, for each non-zero element you store a column index and a value (32 bits each), plus a row-pointer array with one 32-bit entry per row (M + 1 entries for M rows)

So with 0% sparsity: 64N + 32M bits, where N is the number of weights, which is worse than the 32N bits of a dense fp32 matrix

So with 90% sparsity:

  • Row pointers don't depend on the sparsity amount, only on the row count M: 32M bits
  • Values and column indices: 64 * 0.1N = 6.4N bits

Memory (dense) - Memory (CSR) = 32N - (6.4N + 32M)

N = M^2 for a square M x M weight matrix, so the 32M row-pointer term is negligible

So this simplifies to 32N - 6.4N = 25.6N -> ~80% memory savings

At 50% sparsity: 32N - (32N + 32M) = -32M, so memory actually increases XD, but 2:4 semi-structured (50%) sparsity can still be accelerated on NVIDIA GPUs
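The same arithmetic as a sketch, assuming a square M x M fp32 weight matrix and 32-bit indices:

    # CSR vs dense memory for an M x M fp32 weight matrix (N = M*M elements)
    def csr_bits(M: int, sparsity: float) -> float:
        nnz = (1 - sparsity) * M * M
        return 64 * nnz + 32 * (M + 1)   # 32-bit value + 32-bit column index per non-zero, plus row pointers

    def dense_bits(M: int) -> int:
        return 32 * M * M

    M = 4096
    for sparsity in (0.5, 0.9):
        saving = 1 - csr_bits(M, sparsity) / dense_bits(M)
        print(f"{sparsity:.0%} sparsity: {saving:+.1%} memory saving")
    # 50%: slightly negative, 90%: ~+80%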

18 of 19

TL;DR

Think about the 4 kinds of memory and where your technique helps

19 of 19

References

https://tinkerd.net/blog/machine-learning/distributed-training/

https://arxiv.org/pdf/1904.10631.pdf

Thanks to Christian and Vasily for answering my questions on sparsity and quantization respectively

Thanks to Anton for recommending the tinkerd blog