1 of 56

⚕ Caduceus Distill Report

Tabtab Labs

2 of 56

Agenda

  • Context
  • Experimental Results & Future Work
  • Log & Learnings
  • Appendix

3 of 56

Goal:

Can we leverage distillation on the Caduceus model to reduce inference cost while maintaining good performance?

4 of 56

Distillation

6 of 56

Caduceus Model → Distillation → Learned Model

7 of 56

Tabtab Labs

7

8 of 56

Tabtab Labs

8

9 of 56

About distillation

10 of 56

About distillation

"Hard" Loss

"Soft" Loss

11 of 56

About distillation

L = α("Soft" Loss) + (1 - α)("Hard" Loss)
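
A minimal PyTorch sketch of this combined objective, assuming logits of shape (batch, seq_len, vocab) and MLM-style labels with -100 marking ignored positions. The function and argument names (alpha_soft, temperature) are illustrative, not the actual caduceus-distill code:

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, alpha_soft=0.5, temperature=2.0):
    # "Soft" loss: KL divergence between temperature-softened teacher and student
    # distributions, scaled by T^2 so gradient magnitudes stay comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature**2

    # "Hard" loss: cross entropy against the original MLM targets.
    hard = F.cross_entropy(
        student_logits.reshape(-1, student_logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
    )

    return alpha_soft * soft + (1.0 - alpha_soft) * hard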

12 of 56

The impact of Temperature

13 of 56

The impact of Temperature
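
For reference, the softened softmax behind this (standard distillation temperature scaling; T = 1 recovers the ordinary softmax, larger T flattens the teacher distribution):

softmax_T(z)_i = exp(z_i / T) / Σ_j exp(z_j / T)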

17 of 56

Hidden states loss

19 of 56

Cross Entropy (Distillation Loss)

Cosine (Hidden States Loss)

Masked Language Modeling (Original BERT Loss)
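
Putting the three terms together, the total objective is (weight names follow the hyperparameter slide later in the deck; the exact weighting used in the runs is an assumption here):

L = α_soft(Cross Entropy) + (1 - α_soft)(Masked Language Modeling) + α_sim(Cosine)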

20 of 56

Caduceus

25 of 56

DNA Sequence (from HG38)

Nucleotides (V=12 Tokens!)

26 of 56

Caduceus:

  • Long(ish) sequences: 131k (2^17)
  • ~8M parameters
  • 256 embedding dimensions
  • 16 Mamba blocks (see the illustrative config below)
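
Written out as an illustrative config next to one possible student (half the blocks, same width, in line with the future-work recommendation). The field names are made up and do not follow the real Caduceus config schema:

# Illustrative only; field names do not match the actual Caduceus config.
teacher_cfg = dict(
    seq_length=2**17,  # ~131k nucleotides per sequence (HG38)
    vocab_size=12,     # nucleotide + special tokens
    d_model=256,       # embedding dimension ("width")
    n_layers=16,       # MambaDNA blocks ("height")
)                      # ~8M parameters overall

student_cfg = dict(teacher_cfg, n_layers=8)  # e.g. halve the height, keep the width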

27 of 56

Experimental results

28 of 56

Our Problem Statement

  • OP (the goal above) + try not to break the bank
    • i.e. we use Colab Pro to access GPUs
      • Tried Colab, Lambda, Modal, Vast

29 of 56

Results

  • Caduceus distillation should work
  • We recommend a DistilBERT-style distillation setup

31 of 56

[Plots: train loss, “Global” validation loss, “Local” validation loss]

32 of 56

  • The student reaches ~88% of the teacher's performance
  • The 3-term distillation loss shows a stronger (more negative) correlation between validation loss and NT scores: -0.66 (0.10) vs. -0.48 (0.07).

33 of 56

Results - (Some) Hyperparameters

  • seq-length - DNA sequence length
  • n-layers - number of MambaDNA “blocks”, aka model height
  • d-model - embedding dimension, aka model width
  • lr - learning rate
  • temperature - distillation loss softmax temperature
  • alpha-soft - weight of the soft loss
    • alpha-hard = 1 - alpha-soft, weight of the hard loss
  • alpha-sim - weight of the embedding similarity
  • batch-size / accumulate-grad-batches (see the illustrative invocation below)
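
An illustrative invocation tying these knobs together. The flag names mirror the list above, but the exact CLI surface and the values shown are assumptions; check uv run distill --help (next slide) for the real interface:

uv run distill \
  --seq-length 65536 \
  --n-layers 8 \
  --d-model 256 \
  --lr 1e-3 \
  --temperature 2.0 \
  --alpha-soft 0.5 \
  --alpha-sim 0.1 \
  --batch-size 4 \
  --accumulate-grad-batches 8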

34 of 56

Results - (Some) Hyperparameters

  • Higher temperature doesn’t help
  • Reducing both height and width by half is too aggressive (you end up with ~15% of the teacher's size; see the back-of-envelope sketch below)
  • Consider shorter sequence lengths (< 2^17 bp)
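
A rough back-of-envelope behind that ~15% figure, assuming block parameters scale roughly with n-layers × d-model² (a simplification; real Mamba blocks also have terms linear in d-model, which pushes the ratio up a bit):

def relative_size(layer_ratio: float, width_ratio: float) -> float:
    # Crude estimate: parameters ~ n_layers * d_model^2.
    return layer_ratio * width_ratio**2

print(relative_size(0.5, 0.5))  # 0.125 -> halving height AND width keeps only ~1/8 of the blocks
print(relative_size(0.5, 1.0))  # 0.5   -> halving height alone keeps ~50% (the planned long run)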

35 of 56

Results - Code/Links

uv run distill --help

36 of 56

Results

  • Caduceus distillation should work
  • We recommend a DistilBERT-style distillation setup

Future Work

  • Long run with ~50% of teacher size
  • If you can keep d_model at 256, consider initializing the student from the teacher's weights (as in DistilBERT)

37 of 56

Log

  • Caduceus env init: caduceus#74

38 of 56

Log

  • Handling of the N nucleotide in the Caduceus training data
    • caduceus#77

39 of 56

Log

  • PyTorch Lightning frictions

40 of 56

Log

  • Teacher inference → Distillation → Evaluation
    • Materialize the teacher logits

41 of 56

Log

  • Teacher inference → Distillation → Evaluation
    • Materialize the teacher logits
      • Be careful with GCS egress (cf. R2 storage)

42 of 56

Log

  • Teacher inference → Distillation → Evaluation
    • Combined the stages to avoid materializing teacher embeddings (see the sketch below)
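
A minimal PyTorch Lightning-style sketch of the combined setup (class, batch-key, and optimizer choices are illustrative, not the actual caduceus-distill code): the frozen teacher runs under no_grad inside the training step, so no teacher logits or embeddings ever hit storage, at the cost of re-running the teacher every epoch.

import pytorch_lightning as pl
import torch

class DistillModule(pl.LightningModule):
    def __init__(self, teacher, student, loss_fn):
        super().__init__()
        self.teacher, self.student, self.loss_fn = teacher.eval(), student, loss_fn
        for p in self.teacher.parameters():  # freeze the teacher
            p.requires_grad_(False)

    def training_step(self, batch, batch_idx):
        with torch.no_grad():  # teacher forward pass on the fly, nothing written to disk
            teacher_logits = self.teacher(batch["input_ids"])
        student_logits = self.student(batch["input_ids"])
        loss = self.loss_fn(student_logits, teacher_logits, batch["labels"])
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.student.parameters(), lr=1e-3)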

43 of 56

Log

  • Distillation loss tweaking
    • Combine both hard and soft loss
    • Temperature scaling
    • Nuke non-nucleotide classes (better init; see the sketch below)
    • Support embedding similarity
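
One way to read "nuke non-nucleotide classes": compute the soft loss only over the nucleotide columns of the 12-token vocabulary, so special tokens do not dominate the softened distributions. The ids below are placeholders, not the real Caduceus tokenizer ids:

import torch.nn.functional as F

NUCLEOTIDE_IDS = [7, 8, 9, 10]  # placeholder ids for A, C, G, T; check the actual tokenizer

def soft_loss_nucleotides_only(student_logits, teacher_logits, temperature=2.0):
    s = student_logits[..., NUCLEOTIDE_IDS] / temperature
    t = teacher_logits[..., NUCLEOTIDE_IDS] / temperature
    return F.kl_div(
        F.log_softmax(s, dim=-1), F.softmax(t, dim=-1), reduction="batchmean"
    ) * temperature**2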

44 of 56

Log

  • Hyperparameter tuning
    • Scaling laws

45 of 56

Log

  • W&B logging
    • Log gradients
      • Our custom version (norm and update ratio; see the sketch below)
      • W&B builtin

wandb.watch(model=model, log="all", log_freq=...)

    • Log distillation parameters
    • Log all metrics
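
A sketch of the "custom version" mentioned above: global gradient norm plus a rough update-to-weight ratio. The metric names and the SGD-like approximation (update ≈ lr · grad) are assumptions:

import torch

def grad_stats(model: torch.nn.Module, lr: float) -> dict:
    # Global gradient norm and a crude per-parameter update/weight ratio.
    grad_sq, ratios = 0.0, []
    for p in model.parameters():
        if p.grad is None:
            continue
        g = p.grad.detach().float()
        grad_sq += g.pow(2).sum().item()
        w_norm = p.detach().float().norm().item()
        if w_norm > 0:
            ratios.append(lr * g.norm().item() / w_norm)
    return {
        "grad/global_norm": grad_sq**0.5,
        "grad/update_ratio": sum(ratios) / max(len(ratios), 1),
    }

# e.g. from on_before_optimizer_step: wandb.log(grad_stats(model, lr=1e-3))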

46 of 56

Log

  • torch.compile
    • Issues with mamba_ssm
    • What worked did not yield perf improvements

47 of 56

Log

  • Problematic Batch 8590 (caduceus-distill#38)

48 of 56

Log

  • Embedding similarity term
    • Easy if same d_model
    • If the student has a lower d_model, we introduce a projection after the last MambaDNA block (see the sketch below)
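
A minimal sketch of that similarity term (module and argument names are illustrative): cosine distance between teacher and student hidden states, with a learned linear projection when the student is narrower.

import torch.nn as nn
import torch.nn.functional as F

class HiddenStateLoss(nn.Module):
    def __init__(self, d_student: int, d_teacher: int):
        super().__init__()
        # Identity when widths match; otherwise project the student hidden states
        # up to the teacher width after the last MambaDNA block.
        self.proj = nn.Identity() if d_student == d_teacher else nn.Linear(d_student, d_teacher)

    def forward(self, student_hidden, teacher_hidden):
        sim = F.cosine_similarity(self.proj(student_hidden), teacher_hidden, dim=-1)
        return (1.0 - sim).mean()  # average over batch and sequence positions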

49 of 56

Debugging Batch 8590

51 of 56

Thank You! Questions?

52 of 56

Scaling Laws
