Cross-Modal Fine-Tuning:
Align then Refine
Junhong Shen1,2, Liam Li2, Lucio Dery1, Corey Staten2, Mikhail Khodak1,
Graham Neubig1, Ameet Talwalkar1,2
1 Carnegie Mellon University 2 Hewlett Packard Enterprise
Outline
Pretrained transformers are changing ML fields
Language, vision, audio
What about other modalities?
Language, Vision, Audio:
Internet: search
Mobile: voice assistants
Transportation: autonomous vehicles
Everything else ("diverse tasks"):
Science: data-driven PDE solvers
Healthcare: drug discovery, cancer detection
Manufacturing: anomaly detection
Finance: fraud detection, quantitative trading
AdTech: improved bidding, recommendation
Education: personalized lesson plans
Software: automated customer service
Agriculture: aerial imagery analysis
…
Language, Vision, Audio
Solving diverse tasks in less-studied modalities
Challenges:
Existing approaches: neural architecture search
Existing approaches: general-purpose models
(Figure: general-purpose models map diverse inputs, e.g., image, point cloud, optical flow, audio, video, to diverse outputs, e.g., classification labels, audiovisual sequences, optical flow fields)
Solving diverse tasks in less-studied modalities
Challenge: these existing approaches require training from scratch
How can existing pretrained models help?
Applying models pretrained in data-rich modalities to new problems
Cross-Modal Transfer Learning
Early evidence: language models can be adapted to…
Recognize images [Kiela et al., 2019]
Solve tabular tasks [Dinh et al., 2022]
Play referential games [Li et al., 2020]
Perform reinforcement learning [Reid et al., 2022]
But these methods cannot be applied to arbitrary tasks
Many of them are task-specific or ad hoc
A more general method
Frozen Pretrained Transformers (FPT) [Lu et al., 2022]
Our goal
General + Effective (accounts for modality differences) = ORCA
ORCA: motivation
Insight: data alignment
Notation
Goal: learn a target model m_t, initialized from the pretrained source model m_s, that minimizes the expected loss on the target distribution P_t
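Written out as an objective (notation follows the slide; the task loss ℓ is left generic):

```latex
m_t \;=\; \arg\min_{m}\; \mathbb{E}_{(x,\,y)\sim P_t}\big[\,\ell\big(m(x),\, y\big)\,\big],
\qquad m \text{ initialized from the pretrained source model } m_s.
```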
ORCA: workflow
Dimensionality alignment ➔ general-purpose
Distribution alignment ➔ task-specific
Stage 1: dimensionality alignment
Decompose the pretrained transformer into an embedder, a model body, and a predictor
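The decomposition can be sketched as follows. This is a minimal illustration, not the paper's exact implementation: all module names, dimensions, and the stand-in transformer body are placeholders. The embedder and predictor are replaced with small task-specific modules so that arbitrary input/output shapes match the pretrained body's hidden dimension:

```python
import torch
import torch.nn as nn

class CrossModalModel(nn.Module):
    """Sketch of the embedder -> body -> predictor decomposition."""

    def __init__(self, pretrained_body, in_dim, hidden_dim, num_classes):
        super().__init__()
        # Task-specific embedder: maps raw target inputs to hidden_dim tokens.
        self.embedder = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )
        self.body = pretrained_body  # pretrained transformer layers
        # Task-specific predictor: maps pooled features to task outputs.
        self.predictor = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):                     # x: (batch, seq_len, in_dim)
        h = self.embedder(x)                  # (batch, seq_len, hidden_dim)
        h = self.body(h)                      # pretrained body, shape-preserving
        return self.predictor(h.mean(dim=1))  # mean-pool over the sequence

# Tiny usage example with a stand-in "pretrained" body.
body = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True),
    num_layers=2,
)
model = CrossModalModel(body, in_dim=12, hidden_dim=64, num_classes=5)
out = model(torch.randn(8, 16, 12))
print(tuple(out.shape))  # (8, 5)
```

In ORCA's workflow, Stage 2 trains only this new embedder while the body stays frozen; Stage 3 then fine-tunes everything.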
Stage 2: embedder learning for data alignment
If we have a metric d measuring the distance between two distributions, we can train the embedder to minimize the distance between the embedded target data and the source reference data
Optimal Transport Dataset Distance (OTDD) [Alvarez-Melis & Fusi, 2020]
Candidate metrics: pairwise ℓ2 distance, cosine distance, maximum mean discrepancy, optimal transport, …
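A minimal sketch of the alignment step, using maximum mean discrepancy (one of the candidate metrics above) instead of OTDD for brevity; the embedder, data, and hyperparameters are all toy placeholders. The embedder is trained so its output distribution moves toward a reference (source-modality) feature distribution:

```python
import torch

def mmd2_rbf(x, y, sigma=1.0):
    """Biased MMD^2 estimate with an RBF kernel; always >= 0."""
    def k(a, b):
        return torch.exp(-torch.cdist(a, b) ** 2 / (2 * sigma ** 2))
    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()

torch.manual_seed(0)
embedder = torch.nn.Linear(12, 64)           # stand-in task-specific embedder
opt = torch.optim.Adam(embedder.parameters(), lr=1e-2)

target_x = torch.randn(64, 12)               # toy target-modality batch
reference = torch.randn(64, 64)              # toy source-modality embeddings

for _ in range(100):
    opt.zero_grad()
    loss = mmd2_rbf(embedder(target_x), reference)  # distribution distance d
    loss.backward()                          # update only the embedder
    opt.step()
print(float(loss))
```

OTDD, the metric ORCA ultimately adopts, additionally compares label distributions, so it rewards embeddings whose class structure resembles the reference dataset's.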
Stage 3: refine the remaining weights
Compute the task loss and backpropagate to update all remaining weights
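This stage is a standard full fine-tuning loop; the sketch below uses a toy model and data (every parameter, embedder, body, and predictor alike, is handed to the optimizer):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-in for the full embedder/body/predictor stack.
model = nn.Sequential(nn.Linear(12, 64), nn.ReLU(), nn.Linear(64, 5))
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)  # all weights trainable
loss_fn = nn.CrossEntropyLoss()

x, y = torch.randn(32, 12), torch.randint(0, 5, (32,))  # toy labeled batch
first_loss = None
for _ in range(20):
    opt.zero_grad()
    loss = loss_fn(model(x), y)   # task loss
    if first_loss is None:
        first_loss = loss.item()
    loss.backward()               # backpropagate through the whole network
    opt.step()
```

Partial tuning in the style of FPT would instead pass only a subset of parameters to the optimizer; the results later in the talk indicate full fine-tuning performs better.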
Comparison with previous work
Empirical results
Two backbones: RoBERTa [Liu et al., 2019], Swin Transformer [Liu et al., 2021]
Reference proxy datasets: CoNLL-2003 (entity recognition), CIFAR-10
Can pretrained models transfer across diverse modalities?
NAS-Bench-360: 10 tasks with diverse inputs (1D/2D), outputs (point/dense), and modalities (vision, audio, electrocardiogram, PDE, protein, genomics, cosmic-ray)
Three classes of baselines: hand-designed architectures, general-purpose models, and AutoML (NAS) methods
Wait…which distance metric should we use?
ORCA + different embedder learning metrics on NAS-Bench-360
Evaluate aggregate performance
Performance profile [Dolan and Moré, 2002]
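A performance profile can be computed from a methods-by-tasks error matrix as follows (a small illustrative implementation, not the benchmark's official code):

```python
import numpy as np

def performance_profile(errors, taus):
    """Dolan-Moré profile: rho_m(tau) = fraction of tasks on which method m's
    error is within a factor tau of the best error achieved on that task."""
    errors = np.asarray(errors, dtype=float)
    ratios = errors / errors.min(axis=0)   # per-task ratio to the best method
    return np.array([[(ratios[m] <= tau).mean() for tau in taus]
                     for m in range(errors.shape[0])])

# Toy example: 2 methods on 3 tasks (lower error is better).
errs = [[0.10, 0.20, 0.30],   # method A: best or tied on every task
        [0.20, 0.20, 0.60]]   # method B
profiles = performance_profile(errs, taus=[1.0, 1.5, 2.0])
print(profiles)
# method A: rho = [1.0, 1.0, 1.0]; method B: rho = [1/3, 1/3, 1.0]
```

A method whose curve reaches the top of the plot at small τ is rarely far from the best on any task.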
Can pretrained models transfer across diverse modalities?
Yes. ORCA achieves the lowest error on 7/10 tasks
ORCA attains the best aggregate performance
ORCA's curve lying far in the top-left corner indicates it is rarely suboptimal and is most often the best method
Behind the scenes
Key 1: distribution alignment
How does optimizing the OTDD to various levels of convergence affect performance?
Key 2: full fine-tuning is better than partial tuning
Recall FPT:
Key 3: pretraining modality affects performance
Application: data-limited regimes
Recall motivation: utilize existing model resources to help regimes where training models from scratch is challenging
Hypothesis: obtaining a good feature embedder reduces the difficulty of fine-tuning
Empirical results
PDEBench for scientific ML
OpenML Tabular Classification
Compare with task-specific cross-modal methods
Future directions
Discussion: pretrained models for scientific ML
Thanks for listening!