Inference Considerations
for LLMs
May 2024
Overview
Topics we’ll cover
Background:
The Transformer Architecture
Self-attention - the key to it all
Self-attention
Self-attention provides:
Consider the word “bank” in these two sentences:
The Transformer
Feed-forward network
The Transformer
Translation Example
Encoder input / output: (3 × 512)
Decoder input at step 1: [<s>] → (1 × 512)
Decoder input at step 2: [<s>, 297] → (2 × 512)
Let’s recap
Encoder
Decoder
Different models use different parts
A rapidly evolving ecosystem
Decoder Models
Decoders
Prompt: $ recite the first law
The LLM generates one token per step ("A", then "robot", then "may", then "not", …); each generated token is appended to the input before the next forward pass.
Inspiration: The Illustrated GPT-2
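A minimal sketch of that loop with a Hugging Face causal LM (model name and prompt are illustrative; each step's most likely token is appended to the input before the next forward pass):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # illustrative model choice
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("recite the first law", return_tensors="pt").input_ids
for _ in range(10):                                  # generate 10 tokens
    logits = model(input_ids).logits                 # (1, seq_len, vocab_size)
    next_token = logits[:, -1, :].argmax(dim=-1)     # most likely next token
    input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)  # append it
print(tokenizer.decode(input_ids[0]))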
Decoding Methods: How LLMs Generate Text
Recall the decoder’s final output layer
… a probability distribution over all tokens in the vocabulary
The dog [MASK]
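In code, that distribution is just a softmax over the logits at the final position (model and tokenizer as loaded in the earlier sketch; prompt illustrative):

import torch

logits = model(**tokenizer("The dog", return_tensors="pt")).logits[:, -1, :]
probs = torch.softmax(logits, dim=-1)          # distribution over the full vocabulary
top = torch.topk(probs, k=5)
for p, i in zip(top.values[0], top.indices[0]):
    print(f"{tokenizer.decode(i.item())!r}: {p.item():.3f}")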
Greedy Search
Simplest decoding approach: select the word with the highest probability as the next word.
The → dog (0.4), nice (0.5), car (0.1)
The nice → woman (0.4), house (0.3), guy (0.3)
The dog → and (0.05), runs (0.05), has (0.9)
The car → is (0.3), drives (0.5), turns (0.2)
Greedy picks "The" → "nice" (0.5) → "woman" (0.4), for an overall probability of 0.5 × 0.4 = 0.2.
Misses high-probability sequences hidden behind low-probability words, e.g. "The dog has" (0.4 × 0.9 = 0.36).
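With the transformers generate() API, greedy search is what you get when sampling is disabled (a minimal sketch, reusing the model and tokenizer loaded in the earlier snippet; max_new_tokens is arbitrary):

output = model.generate(
    **tokenizer("The", return_tensors="pt"),
    max_new_tokens=5,
    do_sample=False,   # greedy: always take the single most likely token
)
print(tokenizer.decode(output[0]))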
Beam Search
Keep the num_beams most likely hypotheses at each time step, then choose the hypothesis with the highest overall probability.
The → dog (0.4), nice (0.5), car (0.1)
The nice → woman (0.4), house (0.3), guy (0.3)
The dog → and (0.05), runs (0.05), has (0.9)
The car → is (0.3), drives (0.5), turns (0.2)
num_beams = 2
@ t == 1:
The dog (0.4)
The nice (0.5)
@ t == 2:
The dog has (0.36)
The nice woman (0.2)
Suffers from repetitive generation, and human language doesn't strictly follow a distribution of high-probability next words, so the output tends to sound boring and predictable.
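The equivalent generate() call for beam search (num_beams matches the example above; no_repeat_ngram_size is a common, optional mitigation for the repetition issue):

output = model.generate(
    **tokenizer("The", return_tensors="pt"),
    max_new_tokens=5,
    num_beams=2,               # keep the 2 most likely hypotheses at each step
    no_repeat_ngram_size=2,    # optional: forbid repeating any 2-gram
    early_stopping=True,       # stop as soon as num_beams complete candidates exist
)
print(tokenizer.decode(output[0]))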
Sampling
Randomly pick the next word according to its conditional probability distribution
Reduces the likelihood that words are repeated, but increases the chance that the model is "too creative", i.e. it wanders off to words or topics that don't make sense.
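A minimal sampling sketch (probs is the next-token distribution from the earlier softmax snippet; the generate() variant uses top_k=0 to deactivate top-k filtering and sample from the full vocabulary):

# by hand: draw one token according to its probability
next_token = torch.multinomial(probs, num_samples=1)
print(tokenizer.decode(next_token[0]))

# via generate(): ancestral sampling over the full distribution
output = model.generate(
    **tokenizer("The", return_tensors="pt"),
    max_new_tokens=5,
    do_sample=True,
    top_k=0,            # disable top-k filtering
)
print(tokenizer.decode(output[0]))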
Inference Parameters
Top-K Sampling
Keep only the k most likely tokens, redistribute the probability mass among them, and sample the next word from that reduced set.
Some words are sampled from a very sharp distribution, while others come from a much flatter one, so a single fixed k can be too restrictive in one case and too permissive in the other.
Top-P (nucleus) Sampling
Choose from the smallest possible set of words whose cumulative probability exceeds the probability p
The size of the candidate set grows and shrinks dynamically according to the next word's probability distribution.
p=0.92
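A manual sketch of both filters applied to the probs tensor from the earlier snippet (k and p are example values; in practice the generate() arguments top_k and top_p do this for you, as shown in the later generate() example):

import torch

k, p = 50, 0.92

# Top-k: keep the k most likely tokens, renormalize, then sample.
topk_probs, topk_idx = torch.topk(probs, k)
topk_probs = topk_probs / topk_probs.sum()
next_token_topk = topk_idx[0, torch.multinomial(topk_probs, num_samples=1)[0]]

# Top-p (nucleus): keep the smallest set whose cumulative probability exceeds p.
sorted_probs, sorted_idx = torch.sort(probs[0], descending=True)
cutoff = int((torch.cumsum(sorted_probs, dim=0) < p).sum()) + 1
nucleus_probs = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
next_token_topp = sorted_idx[torch.multinomial(nucleus_probs, num_samples=1)]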
Temperature
A scaling factor applied to the logits that changes the shape of the next-token probability distribution.
Low temperature is less random, high temperature is more random
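Concretely, the logits are divided by T before the softmax; T < 1 sharpens the distribution toward the most likely token, T > 1 flattens it (logits as in the earlier snippet, T illustrative):

T = 0.7                                    # temperature; 1.0 leaves the distribution unchanged
probs = torch.softmax(logits / T, dim=-1)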
What generation params should I use??
For more decoding strategies, including newer ones, check out this guide
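Putting the knobs together in one generate() call (values shown are common starting points, not recommendations; tune them for your model and task):

output = model.generate(
    **tokenizer("The", return_tensors="pt"),
    max_new_tokens=50,
    do_sample=True,     # sample instead of greedy / beam search
    temperature=0.7,    # rescale logits before the softmax
    top_k=50,           # keep only the 50 most likely tokens...
    top_p=0.92,         # ...and the smallest nucleus exceeding 92% probability
)
print(tokenizer.decode(output[0], skip_special_tokens=True))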
Efficient Inference Techniques
Why the need?
Challenges:
Loading the weights of a model with X billion parameters requires roughly 2 * X GB of VRAM in bfloat16/float16 precision (2 bytes per parameter):
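For example, applying the rule: a 7B-parameter model needs roughly 7 * 2 = 14 GB of VRAM, and a 70B model roughly 140 GB, before accounting for activations and the KV-cache.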
For reference, common GPU hardware:
GPU | VRAM | ~ $/hr ($/month)
NVIDIA Tesla T4 | 16 GB | $0.50 ($336)
NVIDIA A10 | 24 GB | $1.00 ($672)
NVIDIA A100 | 40 GB / 80 GB | $6.00 ($4,032)
1. Lower Precision
Lower Precision via Quantization
Operating at reduced numerical precision, namely 8-bit and 4-bit, can achieve computational advantages without a considerable decline in model performance.
The key is to reduce precision without compromising model expressivity / accuracy
One Transformer layer, with weights stored in 8-bit and compute in 16-bit:
1. Load model weights (16-bit)
2. Quantize model weights (to 8-bit)
3. Dequantize weights (back to 16-bit) and compute
4. Quantize weights again (to 8-bit)
Weights are de-quantized on the fly to perform the matrix multiplications in 16-bit, and then re-quantized.
Inference time is not reduced (often increases), but memory overhead is.
Memory savings: roughly 2x for 8-bit weights and 4x for 4-bit weights, relative to 16-bit.
Popular Quantization Schemes
Two integrations are natively supported in transformers: bitsandbytes and auto-gptq.
bitsandbytes
auto-gptq
Today we also have: AWQ, EETQ, HQQ, AQLM, Quanto, gguf
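A minimal 4-bit loading sketch with the bitsandbytes integration (model id is illustrative; requires the bitsandbytes package and a CUDA GPU):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit
    bnb_4bit_compute_dtype=torch.bfloat16,  # de-quantize to bf16 for the matmuls
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model id
    quantization_config=bnb_config,
    device_map="auto",                      # place layers on the available devices
)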
2. Flash Attention
Regular Self-Attention
The traditional self-attention mechanism scales quadratically in memory and compute with respect to sequence length.
The Q and K matrices each consist of N vectors, so QKᵀ has N² entries.
Assuming the LLM has 40 attention heads and runs in bf16, the memory required to store QKᵀ is 40 × 2 × N² bytes.
So, for example:
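(sequence lengths below chosen for illustration; the values follow directly from 40 × 2 × N² bytes)
N = 1,000 → ≈ 0.08 GB
N = 16,000 → ≈ 20 GB
N = 100,000 → ≈ 800 GB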
Output O of a single self attention layer for input X of length N:
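One standard way to write it (scaled dot-product attention, with learned projections W_Q, W_K, W_V and the usual 1/√d_head scaling):

O = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_{\text{head}}}}\right) V, \qquad Q = XW_Q,\quad K = XW_K,\quad V = XW_V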
Self-attention is prohibitively memory expensive for long input contexts.
How can we reduce this memory requirement?
Flash Attention
A variant of the attention algorithm that is both more memory-efficient and faster, thanks to better use of the GPU memory hierarchy.
LLM inference is memory-IO bound, i.e. it takes longer to move 1 MB of data to the GPU compute cores than it does to perform the actual computation on it.
Standard attention requires transferring intermediate values back and forth between HBM and SRAM multiple times during the computation.
Flash attention loads all of the data just once by fusing kernel operations and tiling to partition inputs for parallel processing.
VRAM scales linearly with seq. length + up to 3x faster computations
SRAM (static random access memory) - on chip
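In transformers, Flash Attention 2 can be requested at load time (a sketch; requires the flash-attn package and a supported GPU, model id illustrative):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",    # illustrative model id
    torch_dtype=torch.bfloat16,               # flash-attn kernels require fp16/bf16
    attn_implementation="flash_attention_2",  # fused, tiled attention kernels
    device_map="auto",
)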
3. KV-cache
KV-Cache
Autoregressive generation works by iteratively generating tokens and adding them to the input.
KV-Cache
KV-cache saves compute resources by reusing previously calculated self-attention key-value pairs, instead of recalculating them for each generated token.
Without kv-cache
Generation steps for input of length = 20 tokens:
shape of input_ids torch.Size([1, 21])
shape of input_ids torch.Size([1, 22])
shape of input_ids torch.Size([1, 23])
shape of input_ids torch.Size([1, 24])
shape of input_ids torch.Size([1, 25])
[' Here is a Python function']
With kv-cache
Generation steps for input of length = 20 tokens:
shape of input_ids torch.Size([1, 1])
length of key-value cache 20
shape of input_ids torch.Size([1, 1])
length of key-value cache 21
shape of input_ids torch.Size([1, 1])
length of key-value cache 22
shape of input_ids torch.Size([1, 1])
length of key-value cache 23
shape of input_ids torch.Size([1, 1])
length of key-value cache 24
[' Here', ' is', ' a', ' Python', ' function']
KV-Cache
Using the key-value cache has two main advantages: it avoids recomputing the keys and values of all previous tokens at every generation step, and it is extremely useful for chat and RAG applications as we don't need to re-encode the entire history of tokens on each forward pass.
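A sketch of the loop behind those traces, reusing past_key_values so only the newest token is fed back in (model, tokenizer, and prompt are illustrative; generate() does this automatically when use_cache=True, which is the default):

import torch

input_ids = tokenizer("Write a Python function", return_tensors="pt").input_ids
past_key_values = None                        # the KV-cache, empty at first

for _ in range(5):
    out = model(input_ids, past_key_values=past_key_values, use_cache=True)
    past_key_values = out.past_key_values     # cached K/V for all tokens seen so far
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    input_ids = next_token                    # only the new token goes in next time
    print(tokenizer.decode(next_token[0]))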
4. Dynamic Batching
Naive/Static Batching
Batch size remains constant until inference is complete for each sequence in the batch
Dynamic/Continuous Batching
Also called “iteration-level scheduling” where sequences in a batch are swapped in and out per iteration to make best use of GPU memory
5. Tensor Parallelism
Tensor Parallelism
What happens when you have Llama3-70B, which requires 140 GB of VRAM, but the largest GPU card only has 80 GB? Tensor parallelism splits each weight matrix across multiple GPUs, so every GPU holds and computes only a slice of each layer.
Text Generation Inference (TGI)
Serving framework is important for performance
TGI: A Rust, Python and gRPC server for text generation inference. Used in production at Hugging Face to power LLMs
Others include: vLLM, Triton, Seldon, etc.
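Once a TGI server is up (its launcher exposes a --num-shard option that splits the model across GPUs with tensor parallelism, which is how a 70B model fits on multiple 80 GB cards), it can be queried from Python with huggingface_hub's InferenceClient (endpoint URL and parameters illustrative):

from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:8080")   # local TGI endpoint (illustrative)
print(client.text_generation(
    "Recite the first law of robotics.",
    max_new_tokens=64,
    temperature=0.7,
))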
Thank you!