1 of 40

Multi-student Diffusion Distillation

Yanke Song, Jonathan Lorraine, Weili Nie, Karsten Kreis, James Lucas

12/14/2024

NVIDIA Confidential

2 of 40

The next step for Gen AI: speed

  • Target applications:
    • Real-time video generation
      • Holodeck, DLSS, generative games
    • Interactive (multimodal) LLMs
  • Conflicting factors:
    • High quality
    • Low latency

3 of 40

The Problem

  • Diffusion models are the backbone of SOTA generative models
  • However, they often require:
    • A lengthy iterative denoising process
    • Large models, with a slow, memory-intensive forward pass

4 of 40

Diffusion Acceleration

  • Model pruning: Diff-Pruning
  • Engineering efforts: StreamDiffusion
  • Better numerical methods: DPM-Solver, GENIE, ParaDiGMS
  • Knowledge Distillation
    • Earlier attempts (PD, DSNO, RectifiedFlow)
    • Consistency models (CM, iCM, LCM, CTM)
    • Distribution matching (DMD, DMD2, EMD)
    • Adversarial loss (ADD, LADD, DMD2)

5 of 40

How fast is fast enough?

  • Goal: real-time video generation
  • A 5-second, high-quality, 720p video currently takes tens of minutes to generate
  • Need single-step distillation
  • Need smaller architecture
  • Capacity --, performance --
  • How to maintain capacity?

6 of 40

Is 1-step fast enough?

  • Simply put: No.
    • Distillation typically keeps the same architecture for student and teacher
    • A single step can still take tens of seconds.
  • Need smaller architecture
    • Less capacity -> worse performance
    • Can we maintain large capacity + lower latency?

7 of 40

Our Method

8 of 40

Distilling many small 1-step generators

  • We distill into multiple single-step students:
    • Each student is trained on its own data subset
    • Effective capacity ++; Latency --

9 of 40

Distilling many small 1-step generators

  • Each student focuses on its own data subset
  • User request -> routing -> specialized student
  • Capacity ++, latency --

10 of 40

How to choose the right student?

  • Strategy 1: manual partitioning

[Figure: data manually partitioned among Student 1, Student 2, and Student 3]

11 of 40

How to choose the right student?

  • Strategy 2: embedding space clustering
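
A minimal sketch of this routing strategy, assuming prompt embeddings are precomputed offline (the embedding model, e.g. a CLIP text encoder, the number of students, and every function name below are assumptions, not details from the slide):

    # Hypothetical sketch: cluster training-prompt embeddings with k-means
    # (one cluster per student), then send each new request to the student
    # owning the nearest centroid. The same assignment can also split the
    # training set into per-student subsets (cf. the training slide below).
    import numpy as np
    from sklearn.cluster import KMeans

    def fit_router(prompt_embeddings: np.ndarray, num_students: int, seed: int = 0) -> KMeans:
        """Cluster training-prompt embeddings; one cluster per student."""
        return KMeans(n_clusters=num_students, random_state=seed, n_init=10).fit(prompt_embeddings)

    def route(router: KMeans, query_embedding: np.ndarray) -> int:
        """Return the index of the specialized student for a new request."""
        return int(router.predict(query_embedding[None, :])[0])

    if __name__ == "__main__":
        train_emb = np.random.randn(10_000, 512).astype(np.float32)  # stand-in embeddings
        router = fit_router(train_emb, num_students=4)
        student_id = route(router, np.random.randn(512).astype(np.float32))
        print(f"request routed to student {student_id}")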

12 of 40

How to choose the right student?

  • AV applications: different cities/regions
  • DLSS: different games

13 of 40

Training many small 1-step generators

  • For training
    • Apply the same routing strategy to partition the (paired) training data
    • Train students independently, in parallel

14 of 40

Relationship to mixture-of-experts

  • Distillation + MoE have been combined since the start (e.g., Hinton et al., 2015)
  • MoE for discriminative model:
    • Simple outputs
    • Weighted combination
  • Independent experts for generative model:
    • Complex outputs
    • Hard to combine

15 of 40

Some Technical Details for Our Method

  • Distribution matching distillation (DMD, DMD2)
    • Student (fake) distribution -> Teacher (real) distribution

16 of 40

Some Technical Details for Our Method

  • Distribution matching distillation (DMD, DMD2)
    • Distance for distribution: KL divergence
    • Gradient: score functions

17 of 40

Some Technical Details for Our Method

  • Distribution matching distillation (DMD, DMD2)
    • Real score: teacher diffusion model
    • Fake score: additional “fakescore” model
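
Putting the last three slides together, a hedged sketch of the resulting gradient in my own notation (G_\theta is the one-step student, z its input noise, x_t a diffused version of the student's output, w_t a time-dependent weight; this paraphrases DMD rather than transcribing it):

    \nabla_\theta \, D_{\mathrm{KL}}\!\left( p_{\mathrm{fake}} \,\|\, p_{\mathrm{real}} \right)
      \;\approx\; \mathbb{E}_{z,\, t,\, \epsilon}\!\left[ w_t \left( s_{\mathrm{fake}}(x_t, t) - s_{\mathrm{real}}(x_t, t) \right) \frac{\partial G_\theta(z)}{\partial \theta} \right],
      \qquad x_t = \alpha_t\, G_\theta(z) + \sigma_t\, \epsilon,

where s_real is queried from the frozen teacher and s_fake from the auxiliary "fake score" model, itself trained online by denoising score matching on the student's outputs.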

18 of 40

Some Technical Details for Our Method

  • Adversarial distillation (DMD2)
    • GAN loss: enhance sharpness and realism
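
One standard way to write such a loss over diffused samples (a sketch only; the exact formulation and weighting used in DMD2 may differ):

    \mathcal{L}_{\mathrm{GAN}} \;=\; \mathbb{E}_{x \sim p_{\mathrm{real}},\, t,\, \epsilon}\!\left[ \log D(x_t, t) \right]
      \;+\; \mathbb{E}_{z,\, t,\, \epsilon}\!\left[ \log\!\left( 1 - D(\hat{x}_t, t) \right) \right],
      \qquad \hat{x}_t = \alpha_t\, G_\theta(z) + \sigma_t\, \epsilon,

where D is trained to maximize this objective and the student G_\theta to minimize it (or its non-saturating variant); D sees noisy samples and can share a backbone with the fake score model (see the DMD2 spare slide).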

19 of 40

What we have now

  • Teacher -> multiple single-step students
  • Techniques we use: distribution matching distillation + adversarial distillation
  • Students share the teacher's architecture (so far)

20 of 40

Are we done?

  • Same architecture: students can be initialized from the teacher's weights
  • Smaller architecture: random initialization (no teacher weights to reuse)

21 of 40

Some Technical Details for Our Method

  • Solution: teacher score-matching
    • Student scores -> teacher scores
    • Equivalently: distill a smaller diffusion model first
    • Learns good feature maps / initialization weights (objective sketched below)

[Figure: score function of a mixture of Gaussians]
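
A hedged sketch of the teacher score-matching objective in my own notation (s_\phi is the smaller student's score network, s_teacher the frozen teacher's; the exact weighting and parameterization are assumptions):

    \mathcal{L}_{\mathrm{TSM}}(\phi) \;=\; \mathbb{E}_{x,\, t,\, \epsilon}\!\left[ w(t)\, \left\| s_\phi(x_t, t) - s_{\mathrm{teacher}}(x_t, t) \right\|_2^2 \right],
      \qquad x_t = \alpha_t\, x + \sigma_t\, \epsilon .

Minimizing this amounts to training a smaller multi-step diffusion model against the teacher's scores, which is why it doubles as an initialization for the subsequent step distillation.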

22 of 40

Some Technical Details for Our Method

  • Solution: teacher score-matching
    • Equivalently: distill a multi-step student diffusion model first
    • Size-then-step distillation >> step-then-size distillation
    • Step-distillation closes the gap

23 of 40

Some Technical Details for Our Method

  • Three-stage strategy (orchestration sketched below):
    • Stage 1: teacher score matching, to initialize the smaller student
    • Stage 2: distribution matching distillation, down to a single step
    • Stage 3: adversarial fine-tuning, for sharpness and realism
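
A hypothetical orchestration of this recipe (all names and signatures are mine, not the paper's API). Each student is trained independently on its own partition, so the outer loop parallelizes trivially:

    # Sketch: run the three stages per student, training students in parallel.
    from concurrent.futures import ProcessPoolExecutor
    from typing import Any, Callable, Sequence

    Stage = Callable[[Any, Sequence], Any]  # (student, data subset) -> updated student

    def train_one_student(student: Any, subset: Sequence,
                          teacher_score_matching: Stage,
                          distribution_matching: Stage,
                          adversarial_finetune: Stage) -> Any:
        student = teacher_score_matching(student, subset)   # Stage 1: init the small architecture
        student = distribution_matching(student, subset)    # Stage 2: distill to a single step
        student = adversarial_finetune(student, subset)     # Stage 3: add the GAN loss
        return student

    def train_all_students(students: Sequence, subsets: Sequence,
                           stages: Sequence[Stage]) -> list:
        with ProcessPoolExecutor() as pool:
            futures = [pool.submit(train_one_student, s, d, *stages)
                       for s, d in zip(students, subsets)]
            return [f.result() for f in futures]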

24 of 40

Results

25 of 40

Experiments

Stable Diffusion (SD), dog prompts

26 of 40

Experiments

ImageNet

27 of 40

Experiments

FID scores

  • Class-conditional: ImageNet 64x64
  • Text-to-image: COCO 2014

28 of 40

Experiments

Ablation studies

  • MSD is more than a batch-size increase
  • More students yield better performance
  • Partitions should group semantically similar data

29 of 40

Summary

  • We present MSD, a framework that:
    • Distills into multiple specialized students
    • # students ++, capacity ++, robust performance boost
    • Student size --, latency --
    • Flexible partitioning and routing
  • Generality:
    • Works with different distillation methods
    • Compatible with other acceleration techniques (e.g., pruning)

30 of 40

Conclusion & Next Steps

  • Summary:
    • Multi-student distillation (MSD)
      • Each student focuses on a data subset
      • Same model size -> universal performance boost
      • Smaller model size -> faster inference

31 of 40

Conclusion & Next Steps

  • Limitations / future work
    • Better ways to reduce model size (e.g., pruning)
    • Collaboration among students (e.g., weight sharing, hierarchical students?)

32 of 40

Questions

33 of 40

Spare Slides

34 of 40

Spare Slides

35 of 40

Diffusion Acceleration

36 of 40

Diffusion Acceleration

37 of 40

Diffusion Acceleration

    • Rectified Flow: Straighten paths
      • InstaFlow, PeRFlow

38 of 40

Diffusion Acceleration: SOTA

  • Distribution Matching
    • Single step student
    • Minimize KL(student || teacher)
    • Score matching over diffused distributions
    • “Fake” score model to estimate fake data score
  • Adversarial Distillation
    • An adversarial loss: teacher vs student
    • Discriminate over diffused distributions
    • Discriminator + fake score: shared weights

39 of 40

Distribution matching distillation (DMD)

40 of 40

Distribution matching distillation 2 (DMD2)