Why won’t Llama 13B fit on my 4090?
Mark Saroufim
Sizes on disk
Model sizes
GPU VRAM
Back of the envelope
Mark has an RTX 4090 with ~23GB of usable VRAM (24GB nominal)
Llama 7B works
Does Llama 13B even load for inference only?
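A minimal sketch of the arithmetic, assuming fp16/bf16 weights (2 bytes per parameter, the only assumption here): the weights of Llama 13B alone already exceed the 4090's VRAM, before counting activations or the KV cache.

```python
# Back-of-the-envelope: can the weights alone fit in VRAM?
# Assumption: weights stored in fp16/bf16, i.e. 2 bytes per parameter.

def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory needed just to hold the model weights, in GB."""
    return num_params * bytes_per_param / 1e9

print(f"Llama 7B:  {weight_memory_gb(7e9):.1f} GB")   # ~14 GB -> fits in ~23 GB
print(f"Llama 13B: {weight_memory_gb(13e9):.1f} GB")  # ~26 GB -> does not fit
```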
Ok what’s going on?
Do I need more GPUs? How many? No idea
Most US outlets can only handle ~1.8 kW, so 4 GPUs in an apartment is already tricky
360 no-scoping a GitHub issue
Goal is to understand this picture and more!
https://arxiv.org/pdf/1904.10631.pdf
4 sources of memory
Inference and training: model memory, activation memory
Training only: gradient memory, optimizer memory
Model memory
Gradient memory
With full finetuning, gradient memory = model memory (one gradient per parameter)
Optimizer memory
For the GPU poor: use SGD for finetuning (or don’t, not sure why this isn’t more popular); plain SGD keeps no extra state per parameter, vs. two states per parameter for Adam
For the GPU rich: shard the optimizer state, gradients, and parameters across devices (ZeRO/FSDP-style); the per-parameter accounting is sketched below
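A rough per-parameter accounting sketch. The byte counts are assumptions (fp16 weights and gradients, fp32 Adam state plus an fp32 master copy of the weights); exact numbers depend on the trainer and mixed-precision setup.

```python
# Rough per-parameter training memory accounting (a sketch; assumes fp16
# weights/grads with fp32 Adam state, numbers vary by implementation).

def training_memory_gb(num_params: float, optimizer: str = "adam") -> float:
    model = 2 * num_params          # fp16 weights
    grads = 2 * num_params          # fp16 gradients, same size as the model
    if optimizer == "adam":
        # Adam keeps two fp32 states (momentum + variance) per parameter,
        # often plus an fp32 master copy of the weights.
        opt = (4 + 4 + 4) * num_params
    elif optimizer == "sgd":
        opt = 0                     # plain SGD keeps no extra state
    else:
        raise ValueError(optimizer)
    return (model + grads + opt) / 1e9

print(f"7B + Adam: {training_memory_gb(7e9, 'adam'):.0f} GB")  # ~112 GB before activations
print(f"7B + SGD:  {training_memory_gb(7e9, 'sgd'):.0f} GB")   # ~28 GB before activations
```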
Activation memory
Intermediate tensors from forward calls, saved for the backward pass
For large input sequence length and large batch sizes this is the bottleneck
Some ways to resolve this trade speed for memory, e.g. activation (gradient) checkpointing; see the sketch below
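A minimal sketch of one such trade, activation (gradient) checkpointing in PyTorch; the block and tensor shapes are made up for illustration.

```python
# A sketch of trading compute for activation memory with gradient
# checkpointing: activations inside the checkpointed block are not stored,
# they are recomputed during backward.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))
x = torch.randn(8, 512, 4096, requires_grad=True)

# Without checkpointing: all intermediate activations are kept for backward.
y = block(x)

# With checkpointing: only the block's input is kept; intermediates are
# recomputed on the backward pass, cutting activation memory at the cost of
# an extra forward through the block.
y_ckpt = checkpoint(block, x, use_reentrant=False)
y_ckpt.sum().backward()
```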
Demo time
KV cache (That’s a lot of memory)
The KV cache trades memory for speed: cache K and V for past tokens so they aren’t recomputed at every decoding step
The biggest factors are seqlen (e.g. 512 or more) and d_model (e.g. 4096)
Can be mitigated by MQA or GQA: learn more in this other self-plug tutorial here, where we reduce the number of KV heads at some cost to accuracy; a worked size estimate is sketched below
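A back-of-the-envelope KV-cache estimate, assuming an fp16 cache and a Llama-7B-like config (32 layers, d_model 4096, 32 heads); the batch size and sequence length are made-up inputs.

```python
# KV-cache size estimate. Assumptions: fp16 cache (2 bytes/element), both K
# and V cached per layer, and K/V width = d_model * n_kv_heads / n_heads.
def kv_cache_gb(batch, seqlen, n_layers, d_model, n_heads, n_kv_heads,
                bytes_per_elem=2):
    per_token = 2 * n_layers * d_model * (n_kv_heads / n_heads) * bytes_per_elem
    return batch * seqlen * per_token / 1e9

# Full multi-head attention cache (n_kv_heads == n_heads): ~8.6 GB
print(kv_cache_gb(batch=8, seqlen=2048, n_layers=32, d_model=4096,
                  n_heads=32, n_kv_heads=32))
# GQA with 8 KV heads: cache is 4x smaller, ~2.1 GB
print(kv_cache_gb(batch=8, seqlen=2048, n_layers=32, d_model=4096,
                  n_heads=32, n_kv_heads=8))
```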
So why does this OOM? load_in_4bit=True
Weight-only quantization (see the sketch after this list)
Static quantization, e.g. SmoothQuant
Model quantization
Gradient quantization
Optimizer quantization
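A sketch of weight-only 4-bit loading with Hugging Face Transformers + bitsandbytes (BitsAndBytesConfig is the current form of the load_in_4bit=True flag from the slide); the checkpoint name is just an illustrative choice.

```python
# A sketch of weight-only 4-bit loading. Requires bitsandbytes + accelerate.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit (NF4)
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls still run in bf16
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-hf",            # illustrative checkpoint choice
    quantization_config=quant_config,
    device_map="auto",
)
# Weights now take roughly 13e9 * 0.5 bytes ~= 6.5 GB, leaving room for
# activations and the KV cache on a ~23 GB card.
```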
What about QLoRA?
LoRA slightly increases model memory because it adds adapter weights on top of the base model
QLoRA decreases model memory because the base model is stored in NF4 while the adapters stay in higher precision (fp32)
Gradient and optimizer memory also shrink because only the small adapter subset is trained (sketch below)
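A QLoRA-style sketch with peft on top of a 4-bit base model like the one loaded above; the rank, alpha, and target modules are illustrative assumptions.

```python
# A sketch of QLoRA-style finetuning with peft on a 4-bit base model.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)  # enables checkpointing + casts

lora_config = LoraConfig(
    r=16,                                   # adapter rank (assumed)
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],    # which linears get adapters (assumed)
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Only a tiny fraction of parameters is trainable, so gradients and Adam
# state now exist only for that small adapter subset.
```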
Sparsity
Reduces model, optimizer, and activation memory
How much VRAM exactly?
Sparsity back of the envelope
For each nonzero element, CSR stores a 32-bit column index and a 32-bit value; on top of that there is one 32-bit row-pointer entry per row (M rows)
So at 0% sparsity: (32 * 2) * N + 32M bits vs. 32N bits dense, where N is the number of weights
So with 90% sparsity:
Only 0.1N nonzeros remain, so CSR needs 0.1N * 64 = 6.4N bits for values + column indices, plus 32M bits of row pointers: Memory usage (dense) - Memory usage (CSR) = 32N - (6.4N + 32M)
N = M^2 for a square M x M weight matrix, so the 32M term is negligible
So this simplifies to 32N - 6.4N = 25.6N -> 80% memory savings
At 50% sparsity: CSR needs 0.5N * 64 + 32M = 32N + 32M bits, so memory actually goes up by 32M XD. But 2:4 (50%) semi-structured sparsity can still be accelerated on NVIDIA sparse tensor cores; a small calculator is sketched below
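A small calculator for the CSR-vs-dense estimate above, under the same assumptions (32-bit values, column indices, and row pointers; square M x M weight matrix so N = M^2).

```python
# CSR vs dense memory as a function of sparsity.
# Assumptions: 32-bit values, 32-bit column indices, 32-bit row pointers,
# square M x M weight matrix (N = M**2).

def dense_bits(n_weights: int) -> int:
    return 32 * n_weights

def csr_bits(n_weights: int, n_rows: int, sparsity: float) -> int:
    nnz = int(n_weights * (1 - sparsity))
    # value + column index per nonzero, one row-pointer entry per row
    return 64 * nnz + 32 * (n_rows + 1)

M = 4096
N = M * M
for sparsity in (0.5, 0.9):
    saving = 1 - csr_bits(N, M, sparsity) / dense_bits(N)
    print(f"sparsity {sparsity:.0%}: memory saving {saving:+.0%}")
# sparsity 50%: memory saving -0%   (slightly *more* memory than dense)
# sparsity 90%: memory saving +80%
```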
TL;DR
Think about the 4 kinds of memory and where your technique helps
References
https://tinkerd.net/blog/machine-learning/distributed-training/
https://arxiv.org/pdf/1904.10631.pdf
Thanks to Christian and Vasily for answering my questions on sparsity and quantization respectively
Thanks to Anton for recommending the tinkerd blog