1 of 51

Cross-Modal Fine-Tuning:

Align then Refine

Junhong Shen1,2, Liam Li2, Lucio Dery1, Corey Staten2, Mikhail Khodak1,

Graham Neubig1, Ameet Talwalkar1,2

1 Carnegie Mellon University 2 Hewlett Packard Enterprise

2 of 51

Outline

  • Motivation: why do we study cross-modal transfer?
  • ORCA: align-then-refine workflow
  • Empirical results
  • Discussion & future directions


3 of 51

Pretrained transformers are changing ML fields

Language, vision, audio


4 of 51

Pretrained transformers are changing ML fields


5 of 51

Pretrained transformers are changing ML fields


6 of 51

What about other modalities?


Science: Data-driven PDE solvers

Healthcare: Drug discovery, cancer detection

Manufacturing: Anomaly detection

Finance: Fraud detection, quantitative trading

AdTech: Improved bidding, recommendation

Education: Personalized lesson plans

Software: Automated customer service

Agriculture: Aerial imagery analysis

….

Internet: Search

Mobile: Voice assistants

Transportation: Autonomous vehicles

7 of 51

What about other modalities?


Science: Data-driven PDE solvers

Healthcare: Drug discovery, cancer detection

Manufacturing: Anomaly detection

Finance: Fraud detection, quantitative trading

AdTech: Improved bidding, recommendation

Education: Personalized lesson plans

Software: Automated customer service

Agriculture: Aerial imagery analysis

….

Everything Else

— “Diverse Tasks”

Internet: Search

Mobile: Assistants

Transportation: Autonomous vehicles

Language, Vision, Audio

8 of 51

Solving diverse tasks in less-studied modalities

Challenges:

  • Design specialized networks ➔ domain knowledge & ML expertise
    • AutoML (NAS)
    • General-purpose architectures

9 of 51

Existing approaches: neural architecture search

  • Search for task-specific networks from a predefined search space and then train the network, e.g., XD [Roberts et al., 2021], DASH [Shen et al., 2022]

9

10 of 51

Existing approaches: general-purpose models

  • Provide a general architecture that can be trained with a variety of data formats, e.g., Perceiver and Perceiver IO [Jaegle et al., 2021]


(Figure: Perceiver IO. Inputs: image, point cloud, optical flow, audio, video. Outputs: classification labels, audiovisual sequences, optical flow fields.)

11 of 51

Solving diverse tasks in less-studied modalities

Challenges:

  • Design specialized networks ➔ domain knowledge & ML expertise
    • AutoML (NAS)
    • General-purpose architectures
    ➔ both require training from scratch
  • Limited labeled data

12 of 51

How can existing pretrained models help?

  • No need for training from scratch
  • Alleviate modeling & data concerns
  • Reduce the human effort needed to develop high-quality task-specific models


Applying models pretrained in data-rich modalities to new problems

13 of 51

How can existing pretrained models help?

  • No need for training from scratch
  • Alleviate modeling & data concerns
  • Reduce the human effort needed to develop high-quality task-specific models


Applying models pretrained in data-rich modalities to new problems

Cross-Modal Transfer Learning

?

14 of 51

Early evidence: language models can be adapted to…

  • Recognize images [Kiela et al., 2019]
  • Solve tabular tasks [Dinh et al., 2022]
  • Play referential games [Li et al., 2020]
  • Reinforcement learning [Reid et al., 2022]


15 of 51

But these methods cannot be applied to arbitrary tasks

Many of them are task-specific or ad hoc:

  • Hand-crafted architecture components, manual prompt engineering


16 of 51

A more general method

Frozen Pretrained Transformers (FPT) [Lu et al., 2022]

  • Image/protein sequence/math symbol ➔ sequence features
  • Fine-tune layernorms ➔ prevent overfitting
  • No task/semantic-specific component ➔ not effective enough


17 of 51

Our goal

General

+

Effective (accounts for modality difference)

=

ORCA


18 of 51

ORCA: motivation

  • Different data have distinct patterns
  • Given a pretrained model, we know which datasets it performs well on. Can we manipulate the target data into something similar to those datasets?

Insight: data alignment

  • Dimensionality
  • Distribution


19 of 51

Notation

  • Target data {x^t, y^t} ~ P^t in target domain D^t
  • Pretrained transformer m^s in source domain D^s
  • Reference data {x^s, y^s} ~ P^s in source domain D^s

Goal: learn a new model m^t based on m^s that minimizes the expected loss on P^t
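In symbols, the goal can be written as an expected-risk minimization (standard notation; the task loss ℓ is not named on the slide):

```latex
m^t \in \arg\min_{m}\; \mathbb{E}_{(x^t,\,y^t)\sim P^t}\!\left[\ell\big(m(x^t),\, y^t\big)\right],
\qquad m \text{ initialized from the pretrained } m^s .
```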


20 of 51

ORCA: workflow


Dimensionality alignment ➔ general-purpose

Distribution alignment ➔ task-specific

21 of 51

Stage 1: dimensionality alignment

Decompose a transformer

  • Embedder f^t : raw input ➔ sequence features
    • R^(c×h×w) ➔ R^(seq_len × embed_dim)
  • Model body g^t : sequence ➔ sequence
    • R^(seq_len × embed_dim) ➔ R^(seq_len × embed_dim)
  • Prediction head h^t : sequence features ➔ output
    • Classification: 1D adaptive pooling + linear
    • Dense prediction: reshape + adaptive pooling
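A minimal PyTorch sketch of this decomposition for 2D inputs and classification (module and parameter names are illustrative; ORCA's actual embedder and head designs may differ):

```python
import torch
import torch.nn as nn

class Embedder(nn.Module):
    """f^t: raw input (batch, c, h, w) -> sequence features (batch, seq_len, embed_dim)."""
    def __init__(self, in_channels, embed_dim, patch_size):
        super().__init__()
        # Patchify with a strided convolution, then flatten patches into a sequence.
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                    # x: (batch, c, h, w)
        z = self.proj(x)                     # (batch, embed_dim, h', w')
        return z.flatten(2).transpose(1, 2)  # (batch, seq_len, embed_dim)

class ClassificationHead(nn.Module):
    """h^t for classification: 1D adaptive pooling + linear, as on the slide."""
    def __init__(self, embed_dim, num_classes):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, z):                    # z: (batch, seq_len, embed_dim)
        pooled = self.pool(z.transpose(1, 2)).squeeze(-1)  # (batch, embed_dim)
        return self.fc(pooled)               # (batch, num_classes)
```

The pretrained model body g^t (e.g. RoBERTa or Swin layers) slots between the two, consuming and producing (batch, seq_len, embed_dim) tensors.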


22 of 51

Stage 2: embedder learning for data alignment

If we have a metric d to measure the distance between two distributions

  • Minimize d to align the embedded feature distribution of the target dataset with that of the reference dataset
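This stage can be sketched as a small training loop that updates only the embedder to shrink d between embedded target features and fixed reference features (function and argument names are illustrative, not ORCA's API):

```python
import torch

def align_embedder(embedder, target_loader, ref_features, distance_fn,
                   epochs=5, lr=1e-3):
    """Stage-2 sketch: train only the embedder so the embedded target
    features move toward the (fixed) reference feature distribution."""
    opt = torch.optim.Adam(embedder.parameters(), lr=lr)
    for _ in range(epochs):
        for x_t, _ in target_loader:
            # distance_fn is any differentiable distance between feature sets.
            loss = distance_fn(embedder(x_t), ref_features)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return embedder
```

Here `distance_fn` stands in for the metric d; the deck later settles on OTDD as the concrete choice.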


23 of 51

Stage 2: embedder learning for data alignment

If we have a metric d to measure the distance between two distributions

  • Minimize d to align the embedded feature distribution of the target dataset with that of the reference dataset


?

24 of 51

Stage 2: embedder learning for data alignment

Candidate metrics: pairwise ℓ2, cosine distance, maximum mean discrepancy, optimal transport, …

Our choice: Optimal Transport Dataset Distance (OTDD) [Alvarez-Melis & Fusi, 2020]

  • Uses label information: model each class as a distribution over features
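The label-aware idea can be illustrated with a toy version: represent each class by a Gaussian over its embedded features and compare classes with the closed-form 2-Wasserstein distance. This is a diagonal-covariance simplification for illustration only; the actual OTDD couples these label-to-label distances with feature distances in an outer optimal-transport problem:

```python
import torch

def class_gaussian_stats(features, labels):
    """Per-class mean and diagonal variance of embedded features."""
    stats = {}
    for c in labels.unique():
        f = features[labels == c]
        stats[int(c)] = (f.mean(0), f.var(0, unbiased=False))
    return stats

def gaussian_w2_diag(m1, v1, m2, v2):
    """Closed-form 2-Wasserstein distance between diagonal Gaussians:
    W2^2 = ||m1 - m2||^2 + sum((sqrt(v1) - sqrt(v2))^2)."""
    return ((m1 - m2).pow(2).sum()
            + (v1.sqrt() - v2.sqrt()).pow(2).sum()).sqrt()
```

In OTDD proper, these class distances define the ground cost over labels that enters the dataset-level optimal-transport objective.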


25 of 51

Stage 3: refine the remaining weights

Compute the task loss and backpropagate to update

  • embedder
  • model body
  • prediction head
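A sketch of this refinement stage, assuming the three components from Stage 1 (names illustrative):

```python
import torch

def refine(embedder, body, head, loader, task_loss, epochs=10, lr=1e-4):
    """Stage-3 sketch: full fine-tuning -- the task loss updates the
    embedder, the pretrained model body, and the prediction head together."""
    params = (list(embedder.parameters()) + list(body.parameters())
              + list(head.parameters()))
    opt = torch.optim.AdamW(params, lr=lr)
    for _ in range(epochs):
        for x, y in loader:
            loss = task_loss(head(body(embedder(x))), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
```

Passing all three parameter groups to one optimizer is what distinguishes this full fine-tuning from layernorm-only variants such as FPT, discussed later.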


26 of 51

Comparison with previous work


27 of 51

Empirical results

Two backbones: RoBERTa [Liu et al., 2019], Swin Transformer [Liu et al., 2021]

Reference proxy datasets: CoNLL-2003 (entity recognition), CIFAR-10

  • A breadth of modalities: NAS-Bench-360
  • Two benchmarks that we analyze in depth
    • PDEBench
    • OpenML-CC18 Tabular Classification
  • Compare with existing cross-modal work on drug response & tabular prediction


28 of 51

Can pretrained models transfer across diverse modalities?

NAS-Bench-360: 10 tasks with diverse inputs (1D/2D), outputs (point/dense), and modalities (vision, audio, electrocardiogram, PDE, protein, genomics, cosmic-ray)

Three classes of baselines: hand-designed architectures, general-purpose models, and AutoML (NAS) methods


29 of 51

Wait…which distance metric should we use?

ORCA + different embedder learning metrics on NAS-Bench-360


30 of 51

Evaluate aggregate performance

Performance profile [Dolan and Moré, 2002]

  • Normalize metrics for different tasks (0-1 error, MAP, F1 score, MAE, …)
  • Compute fraction of tasks where a method has distance from optimality less than τ
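A sketch of how such a profile can be computed from a matrix of per-task errors, using the error ratio to the best method as the distance from optimality (an assumption about the exact normalization used in the talk):

```python
import numpy as np

def performance_profile(errors, taus):
    """errors: (n_methods, n_tasks), lower is better (assumed nonzero).
    Returns rho with rho[m, i] = fraction of tasks on which method m is
    within a factor taus[i] of the best method for that task."""
    errors = np.asarray(errors, dtype=float)
    ratios = errors / errors.min(axis=0, keepdims=True)
    return np.stack([(ratios <= t).mean(axis=1) for t in taus], axis=1)
```

A method whose curve rises quickly toward 1 at small τ is rarely far from the best on any task, which is how the aggregate plots later in the deck are read.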


31 of 51

Wait…which distance metric should we use?

ORCA + different embedder learning metrics on NAS-Bench-360


32 of 51

Can pretrained models transfer across diverse modalities?

Yes. ORCA achieves the lowest error on 7/10 tasks

  • Outperforms/matches hand-designed architectures on all tasks

  • Outperforms all AutoML baselines on 8/10 tasks
    • 2nd on DeepSEA, 3rd on NinaPro

  • Improvements come at a small computational cost
    • Embedder learning time is a small portion (11%) of the fine-tuning time


33 of 51

ORCA attains the best aggregate performance

ORCA's curve sitting far in the top-left corner indicates that it is rarely suboptimal and is often the best method


34 of 51

Behind the scenes

  • Distribution alignment is the key
  • Full fine-tuning is better than partial fine-tuning
  • Suitable pretrained model selection is important


35 of 51

Key 1: distribution alignment

  • ORCA consistently outperforms naive fine-tuning ➔ we need alignment
  • Training from scratch is worse than ORCA but better than naive fine-tuning on ECG, Satellite, and DeepSEA ➔ fine-tuning without alignment may even hurt performance


36 of 51

Key 1: distribution alignment

How does optimizing the OTDD to various levels of convergence affect performance?


37 of 51

Key 2: full fine-tuning is better than partial tuning

Recall FPT:

  1. Does not pretrain the embedding layers for task-specific adaptation
  2. Only fine-tunes the layer norms


38 of 51

Key 2: full fine-tuning is better than partial tuning

  • ORCA-layernorm outperforms FPT, but its gains are smaller than with full fine-tuning
    • Data alignment boosts the cross-modal performance of FPT
    • Full fine-tuning takes better advantage of the aligned embeddings
  • Fine-tuning only the layer norms yields less than a 2x speedup


39 of 51

Key 3: pretraining modality affects performance

  • To select a pretrained model from a predefined model hub for each task: compare the optimized distribution distances and pick the model with the smallest value
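This selection rule is a one-liner once each candidate's optimized distance is available (a sketch; the model names are placeholders, not a real hub):

```python
def select_pretrained(optimized_distances):
    """optimized_distances maps each hub model's name to its distribution
    distance after embedder learning; pick the smallest (predicted best)."""
    return min(optimized_distances, key=optimized_distances.get)
```

For example, `select_pretrained({"roberta": 0.8, "swin": 0.3})` returns `"swin"`, predicting the vision backbone transfers better to that task.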


40 of 51

Application: data-limited regimes

Recall motivation: utilize existing model resources to help regimes where training models from scratch is challenging


41 of 51

Application: data-limited regimes

Hypothesis: obtaining a good feature embedder reduces the difficulty of fine-tuning

  • Fine-tuning suffers from limited data, but ORCA considerably alleviates the problem
  • ORCA matches the performance of standard fine-tuning trained with 3x the amount of data


42 of 51

Empirical results

Two backbones: RoBERTa [Liu et al., 2019], Swin Transformer [Liu et al., 2021]

Reference proxy datasets: CoNLL-2003 (entity recognition), CIFAR-10

  • A breadth of modalities: NAS-Bench-360
  • Two benchmarks that we analyze in depth
    • PDEBench
    • OpenML-CC18 Tabular Classification
  • Compare with existing cross-modal work on drug response & tabular prediction


43 of 51

PDEBench for scientific ML

  • ORCA outperforms PINN and U-Net on all evaluated datasets and beats FNO on half


44 of 51

PDEBench for scientific ML

  • ORCA achieves zero-shot super-resolution when using RoBERTa + pointwise conv embedder


45 of 51

OpenML Tabular Classification

  • ORCA ranks 1st on 12/30 tasks and matches the performance of AutoGluon


46 of 51

Empirical results

Two backbones: RoBERTa [Liu et al., 2019], Swin Transformer [Liu et al., 2021]

Reference proxy datasets: CoNLL-2003 (entity recognition), CIFAR-10

  • A breadth of modalities: NAS-Bench-360
  • Two benchmarks that we analyze in depth
    • PDEBench
    • OpenML-CC18 Tabular Classification
  • Compare with existing cross-modal work on drug response & tabular prediction


47 of 51

Compare with task-specific cross-modal methods

  • IGTD [Zhu et al., 2021]: converts tabular drug-gene features into a pixel representation & applies a CNN to predict drug response


48 of 51

Compare with task-specific cross-modal methods

  • LIFT [Dinh et al., 2022] transforms tabular data into text to prompt a pretrained GPT-3


49 of 51

Future directions

  • Why does ORCA work?
  • High-dimensional problems and reinforcement learning
  • ORCA 2.0: automating architecture & weight selection


50 of 51

Discussion: pretrained models for scientific ML

  • Develop modality-specific models
    • Architecture?
    • Annotated data?
    • Unsupervised learning scheme?
    • Synthetic data?

  • Exploit resources in other modalities
    • Models (e.g. LLMs for protein sequence representation learning [Vinod et al., 2023])
    • Data


51 of 51

Thanks for listening!
