New Pre-training Paradigms from an Inference-First Perspective
Jiaming Song
Luma AI
Video Models’ 1 year birthday
Native Multi-Modal Generation: The current “hype”
From:
To:
A lot of papers in the past year
Text: Discrete AR
Image: Discrete AR
Text: Discrete AR
Image: Discrete Diffusion
Text: Discrete Diffusion
Image: Discrete Diffusion
Text: Discrete AR
Image: Continuous Diffusion
All based on combinations of Discrete AR / Discrete Diffusion / Continuous Diffusion, etc.
Why not stick to next-token prediction?
Discrete tokens have a quality issue
Original
Reconstructed
It looks quite different up close!
*You cannot even use the reconstruction to tell who they are, even for understanding purposes
Discrete tokens have a quality issue
Discrete tokens have much worse reconstruction than continuous ones
https://arxiv.org/abs/2408.06072
https://arxiv.org/abs/2409.18869
Discrete
Continuous
Fundamental flaw of discrete tokens
Discrete tokens have to compress a lot more for the same sequence length
Continuous tokens have much higher quality at the same sequence length!
Bit compression (discrete) = (4 × 8 × 8) × 3 × 8 / log2(32768) = 6144 / 15 = 409.6
(sequence compression × channels × 8-bit color, divided by the 15 bits of one codebook index)
Bit compression (continuous) = (4 × 8 × 8) × 3 × 8 / (16 × 8) = 6144 / 128 = 48
(same numerator, divided by bfloat16 bits × 8 latent channels)
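The compression arithmetic above can be checked with a short script; the tokenizer settings (4×8×8 sequence compression, a 32768-entry codebook, 8 latent channels in bfloat16) are the ones assumed on this slide:

```python
import math

# Bits of raw RGB video covered by one token:
# 4x8x8 pixels (sequence compression) x 3 channels x 8-bit color.
raw_bits = (4 * 8 * 8) * 3 * 8  # 6144

# Discrete token: a single index into a 32768-entry codebook.
discrete_bits = math.log2(32768)   # 15 bits
print(raw_bits / discrete_bits)    # 409.6x bit compression

# Continuous token: 8 latent channels stored in bfloat16 (16 bits each).
continuous_bits = 16 * 8           # 128 bits
print(raw_bits / continuous_bits)  # 48.0x bit compression
```

The ~8.5× gap in bit compression is why continuous tokens retain much more detail at the same sequence length.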
Continuous tokens have a speed issue
Diffusion requires many timesteps to converge
BAGEL: MoT with discrete + continuous tokens
Continuous tokens have a speed issue
[Diagram: interleaved sequence of discrete and continuous chunks — D1, C1, D2, C2]
Discrete tokens require only 1 pass of the transformer.
Continuous tokens require many passes of the transformer (C1, C1, C1, C1, …).
Continuous signal: can be image / video / sound / actions, etc.
While the sequence looks like this
The compute on the hardware is really like this!
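A minimal way to see the compute gap is to count transformer forward passes for a toy interleaved sequence; the 4-chunk layout and the 50-step diffusion budget below are illustrative assumptions, not BAGEL's actual configuration:

```python
# Hypothetical interleaved sequence of discrete (D) and continuous (C) chunks.
sequence = ["D1", "C1", "D2", "C2"]
diffusion_steps = 50  # assumed refinement steps per continuous chunk

passes = 0
for chunk in sequence:
    if chunk.startswith("D"):
        passes += 1                # discrete AR: one forward pass per chunk
    else:
        passes += diffusion_steps  # continuous diffusion: one pass per timestep

print(passes)  # 102 transformer passes, dominated by the continuous chunks
```

Even though the sequence looks balanced between D and C, the hardware spends almost all of its time re-processing the continuous chunks.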
The algorithms are dominated by AR and diffusion…
But none are perfect!
The trilemma of continuous generative models
Training stability
High quality samples
Efficient inference
GANs, Diffusion Distillation
Diffusion Models
VAEs, Normalizing Flows
Need something here!
Is there anything that would break the ceiling of the two?
The answer is Yes!
Outline
How can we scale at inference-time?
Inference-Time Scaling in Sequence Length
Increases the number of tokens
Inference-time Scaling in Refinement Steps
Does not increase the number of tokens
“puppy in space”
Categorizing existing algorithms
A lot of algorithms that scale in both axes
Scaling efficiency in inference algorithm
Of course, just being able to scale up is not enough!
We also have to scale efficiently!
Infinite monkeys “can” type Shakespeare
AlphaGo enabled by how to search more efficiently
Three positions
1. The right inference algorithm should scale in both axes.
2. Assuming that the model has enough capacity (under universal approximation theorem), it should use as few steps as possible.
3. Analyze the inference algorithm before the training algorithm!
(Applies to continuous and discrete cases, but will focus on continuous today)
Application to Continuous Diffusion
2. Assuming that the model has enough capacity (under universal approximation theorem), it should use as few steps as possible. (✗)
Application to Continuous Diffusion
What do we want from the ”right” inference algorithm?
There exists a solution to the model such that both holds:
Unfortunately, DDIM is NOT the “right” inference algorithm!
DDIM and the Inference Capacity Issue
The Fix
Diffusion Models and Flow Matching
Application to Continuous Diffusion
DDIM is NOT the “right” inference algorithm because the model only takes a single timestep as input!
We can fix it by asking the model to take 2 timesteps!
Analyze inference before training
Once the inference algorithm is decided, it can be trained with many different approaches!
Inductive Moment Matching
Intuition: “consistency” in distributions
For timesteps s < r < t, the two distributions should be close:
Intuition: “consistency” in distributions
We can simply use Maximum Mean Discrepancy (MMD):
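MMD compares two sets of samples purely through a kernel, with no discriminator or likelihood needed. A minimal estimator sketch (RBF kernel, biased V-statistic form; the bandwidth choice is an assumption):

```python
import numpy as np

def mmd2(x, y, bandwidth=1.0):
    """Biased estimate of squared MMD between sample sets x, y of shape (n, d)."""
    def kernel(a, b):
        # RBF kernel matrix between all pairs of rows of a and b.
        d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
        return np.exp(-d2 / (2 * bandwidth**2))
    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()

rng = np.random.default_rng(0)
same = mmd2(rng.normal(size=(500, 2)), rng.normal(size=(500, 2)))
diff = mmd2(rng.normal(size=(500, 2)), rng.normal(3.0, 1.0, size=(500, 2)))
print(same, diff)  # near 0 for matching distributions, large otherwise
```

In IMM the two sample sets would come from the two ways of reaching timestep s (directly from t, or via the intermediate r), and minimizing the MMD pushes those distributions together.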
Advantages of IMM
Stable Training
Image Generation
Scaling Property
Advancing Efficiency / Quality Frontier
The trilemma of continuous generative models
Training stability
High quality samples
Efficient inference
GANs, Diffusion Distillation
Diffusion Models
VAEs, Normalizing Flows
IMM
(and possibly other flow map methods)
Applications to Discrete Diffusion
Consider Masked Diffusion, a performant variant of discrete diffusion
Shi et al., Simplified and Generalized Masked Diffusion for Discrete Data
Applications to Discrete Diffusion
In masked diffusion, a value changes only when the input is a [MASK] token.
Suppose seqlen = N, and we want to sample in L << N steps:
Shi et al., Simplified and Generalized Masked Diffusion for Discrete Data
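A hedged sketch of that sampling loop; the uniform unmasking order, the N/L-per-step schedule, and the toy uniform denoiser are all assumptions for illustration, not the paper's exact procedure:

```python
import numpy as np

def masked_diffusion_sample(denoiser, N, L, vocab_size, rng):
    """Sample an N-token sequence in L << N steps of masked diffusion.

    denoiser(tokens) returns per-position categorical probs (N, vocab_size);
    already-unmasked tokens never change, matching the absorbing process.
    """
    MASK = vocab_size  # assumed extra id for the [MASK] token
    tokens = np.full(N, MASK)
    order = rng.permutation(N)  # assumed uniform unmasking order
    for step in range(L):
        probs = denoiser(tokens)               # one model pass per step
        reveal = order[step * N // L : (step + 1) * N // L]
        for i in reveal:                       # unmask N/L positions at once,
            tokens[i] = rng.choice(vocab_size, p=probs[i])  # independently!
    return tokens

# Toy denoiser: uniform over a 4-word vocabulary, ignoring its input.
rng = np.random.default_rng(0)
uniform = lambda toks: np.full((8, 4), 0.25)
out = masked_diffusion_sample(uniform, N=8, L=2, vocab_size=4, rng=rng)
print(out)  # 8 tokens produced with only 2 denoiser passes
```

The crucial detail is the inner loop: all positions revealed within one step are sampled independently from the model's per-position marginals, which is exactly where the capacity question below comes from.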
Applications to Discrete Diffusion
Does the BERT-style model have “enough capacity”?
Suppose we try to predict:
The list of poker hands that consist of two English words are: [MASK] [MASK]
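The capacity issue is easy to reproduce: if the two [MASK] positions are filled independently from their (even perfectly correct!) marginals, the joint sample can be a hand that does not exist. A toy illustration with an assumed two-element answer set:

```python
import itertools

# Assumed ground-truth answer set for the two [MASK] positions.
answers = {("FULL", "HOUSE"), ("TWO", "PAIR")}

# A BERT-style model predicts each position independently, so even with
# perfect per-position marginals it can only represent the product
# distribution over the two positions.
first_marginal = {"FULL", "TWO"}
second_marginal = {"HOUSE", "PAIR"}
joint_support = set(itertools.product(first_marginal, second_marginal))

invalid = joint_support - answers
print(sorted(invalid))  # ('FULL', 'PAIR') and ('TWO', 'HOUSE'): never valid
```

Half of the one-step joint support is invalid here; only unmasking the two positions in separate steps (so the second prediction can condition on the first) removes the error.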
Applications to Discrete Diffusion
From the inference-first perspective:
Masked discrete diffusion might have capacity issues when sampling in L << N steps with a BERT-style model, regardless of how it is trained!
Takeaway
Analyze the inference algorithm before the training algorithm!
https://lumalabs.ai/join
Happy hour @ Barstool
https://lu.ma/5s0o2hlh
Join us
Learning to Take Large Strides in Time
Generalized Interpolant
Model and Sampling
Want: sample follows
Naïve objective:
2-step sample
Inductive Moment Matching
Inductive Learning Algorithm
2 particles