1 of 23

REPAIR: REnormalizing Permuted Activations for Interpolation Repair

Keller Jordan

keller@kellerjordan.com

With Hanie Sedghi, Olga Saukh, Rahim Entezari, & Behnam Neyshabur

2 of 23

Linear interpolation between weights

  • Has recently been used to set state-of-the-art ImageNet accuracy [1] and to improve the out-of-distribution (OOD) robustness of fine-tuned models [2].
  • In both cases, the interpolation is between checkpoints fine-tuned from the same pretrained weights.
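The operation itself is simple. Here is a minimal sketch. It assumes each checkpoint is available as a dictionary of same-shaped numpy arrays, a hypothetical stand-in for a framework's state dict, and computes θ(α) = (1 − α)·θ_A + α·θ_B one parameter at a time:

```python
import numpy as np

def interpolate_weights(params_a, params_b, alpha):
    """Return (1 - alpha) * theta_A + alpha * theta_B, parameter by parameter.

    params_a, params_b: dicts mapping parameter names to arrays of equal shape
    (a hypothetical stand-in for a framework state dict).
    """
    if params_a.keys() != params_b.keys():
        raise ValueError("checkpoints must have the same parameter names")
    return {k: (1 - alpha) * params_a[k] + alpha * params_b[k] for k in params_a}

# Toy example: two "checkpoints" of the same one-layer architecture
theta_a = {"fc.weight": np.array([[1.0, 2.0]]), "fc.bias": np.array([0.0])}
theta_b = {"fc.weight": np.array([[3.0, 0.0]]), "fc.bias": np.array([1.0])}
mid = interpolate_weights(theta_a, theta_b, 0.5)
print(mid["fc.weight"], mid["fc.bias"])  # [[2. 1.]] [0.5]
```

Sweeping α from 0 to 1 and evaluating each θ(α) traces the interpolation path between the two checkpoints.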

[1] Wortsman, Mitchell, et al. "Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time." International Conference on Machine Learning. PMLR, 2022.

[2] Wortsman, Mitchell, et al. "Robust fine-tuning of zero-shot models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.

3 of 23

What happens when we linearly interpolate between independently trained models?
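Later slides refer to the size of the interpolation "barrier." One common definition comes from Entezari et al. (2021). It may not be exactly the metric used in these experiments. It is the worst-case gap between the loss of the interpolated network and the linear interpolation of the two endpoint losses:

```latex
\theta_\alpha = (1-\alpha)\,\theta_A + \alpha\,\theta_B, \qquad \alpha \in [0, 1]

B(\theta_A, \theta_B) = \max_{\alpha \in [0,1]} \Big[ \mathcal{L}(\theta_\alpha) - \big( (1-\alpha)\,\mathcal{L}(\theta_A) + \alpha\,\mathcal{L}(\theta_B) \big) \Big]
```

A barrier near zero means the straight line between the two solutions stays in low-loss territory.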

4 of 23

Why does this happen?

  1. Feature misalignment (prior and concurrent work)
  2. Variance collapse (our work)

5 of 23

Feature misalignment

6 of 23

Permutation invariance

How do we address feature misalignment?

By leveraging permutation invariance: we can permute the features (neurons/channels) of a hidden layer, together with the matching incoming and outgoing weights, without changing the function the network represents.
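The invariance can be checked on a toy one-hidden-layer ReLU MLP with random weights, which are illustrative only. Permuting the hidden units means reordering the rows of the first weight matrix and bias, and reordering the columns of the second weight matrix in the same way:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden, d_out = 4, 6, 3
W1, b1 = rng.standard_normal((d_hidden, d_in)), rng.standard_normal(d_hidden)
W2, b2 = rng.standard_normal((d_out, d_hidden)), rng.standard_normal(d_out)

def mlp(x, W1, b1, W2, b2):
    # One hidden ReLU layer: x -> ReLU(W1 x + b1) -> W2 h + b2
    h = np.maximum(x @ W1.T + b1, 0.0)
    return h @ W2.T + b2

perm = rng.permutation(d_hidden)
# Permute the hidden units: rows of W1 and entries of b1 ...
W1_p, b1_p = W1[perm], b1[perm]
# ... and the matching columns of W2, so each unit still feeds the same outputs
W2_p = W2[:, perm]

x = rng.standard_normal((10, d_in))
same = np.allclose(mlp(x, W1, b1, W2, b2), mlp(x, W1_p, b1_p, W2_p, b2))
print(same)  # True
```

Every network therefore has d_hidden! functionally identical copies per layer. Two independently trained networks may land in different copies, which is what alignment tries to undo.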

7 of 23

Finding a permutation

Li, Yixuan, et al. "Convergent learning: Do different neural networks learn the same representations?" arXiv preprint arXiv:1511.07543 (2015).
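One way to find the permutation follows the spirit of Li et al.'s correlation-based matching. First correlate every unit of network A with every unit of network B on the same inputs. Then solve a linear assignment problem that maximizes the total correlation of the matched pairs. This is a minimal sketch that assumes the activations have already been collected. `match_units` is a hypothetical helper, and it may differ in detail from the linked REPAIR code:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def match_units(acts_a, acts_b):
    """Return perm such that unit perm[i] of network B best matches unit i of A.

    acts_a, acts_b: (num_samples, num_units) activations of the same layer of
    networks A and B, recorded on the same inputs.
    """
    n = acts_a.shape[1]
    # Cross-correlation block of the joint correlation matrix: corr[i, j] = corr(A_i, B_j)
    corr = np.corrcoef(acts_a, acts_b, rowvar=False)[:n, n:]
    # Maximize the summed correlation of matched pairs (optimal linear assignment)
    _, col_ind = linear_sum_assignment(corr, maximize=True)
    return col_ind

# Synthetic check: B's units are a noisy, shuffled copy of A's units
rng = np.random.default_rng(0)
acts_a = rng.standard_normal((1000, 8))
true_perm = rng.permutation(8)
acts_b = acts_a[:, true_perm] + 0.1 * rng.standard_normal((1000, 8))
perm = match_units(acts_a, acts_b)
print(np.array_equal(perm, np.argsort(true_perm)))  # True
```

The returned `perm` is then applied to network B as in the permutation-invariance check. Reorder the rows of that layer's weights and bias by `perm`, and the columns of the next layer's weights by the same `perm`.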

8 of 23

Filter alignment

9 of 23

For shallow+wide MLPs, this works

Entezari, Rahim, et al. "The role of permutation invariance in linear mode connectivity of neural networks." arXiv preprint arXiv:2110.06296 (2021).

10 of 23

For deep networks, something breaks down

11 of 23

‘Trick’: just re-estimate the BatchNorm statistics!
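Here is a minimal numpy sketch of what "re-estimating BatchNorms" means for a single layer. It assumes, as happens when a full state dict is interpolated, that the BN running statistics were averaged along with the weights. Independent random matrices stand in for the two endpoint layers, so matched channels are essentially uncorrelated. This is an exaggerated case. The helper names are hypothetical:

```python
import numpy as np

def batchnorm(x, mean, var, gamma, beta, eps=1e-5):
    # Standard BN transform using fixed (running) statistics
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def reset_bn_stats(preacts):
    # Hypothetical helper: re-estimate per-channel stats from pre-BN activations
    return preacts.mean(axis=0), preacts.var(axis=0)

rng = np.random.default_rng(0)
X = rng.standard_normal((5000, 32))          # stand-in for training inputs
W_a, W_b = rng.standard_normal((2, 16, 32))  # two "endpoint" layers (random here)
W_mid = 0.5 * (W_a + W_b)                    # interpolated layer at alpha = 0.5

# Interpolated running stats, as if the BN buffers were averaged with the weights
mean_a, var_a = reset_bn_stats(X @ W_a.T)
mean_b, var_b = reset_bn_stats(X @ W_b.T)
stale_mean, stale_var = 0.5 * (mean_a + mean_b), 0.5 * (var_a + var_b)

h = X @ W_mid.T
gamma, beta = np.ones(16), np.zeros(16)
stale = batchnorm(h, stale_mean, stale_var, gamma, beta)
fresh = batchnorm(h, *reset_bn_stats(h), gamma, beta)  # the BN reset
# Stale stats leave the output std well below 1; reset stats bring it back to ~1
print(stale.std(axis=0).mean(), fresh.std(axis=0).mean())
```

In a real network, the reset is a forward pass over training data with the BN layers in training mode. The running buffers are recomputed, and all weights stay frozen.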

12 of 23

Why?

Why does this work?

And how can we solve the remaining VGG11 and deep MLP cases?

13 of 23

Network internals

In the interpolated network, the variance of hidden units/channels appears to collapse.
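A back-of-the-envelope derivation shows why variance can collapse. It is exact only for the first layer, where pre-activations are linear in the weights for a fixed input x. It assumes a matched channel has pre-activations a and b in the two endpoints, with equal std σ and correlation ρ:

```latex
h_{1/2}(x) = \tfrac{1}{2}(w_A + w_B)^\top x = \tfrac{1}{2}\big(a(x) + b(x)\big)

\operatorname{Var}\big[h_{1/2}\big] = \tfrac{1}{4}\big(\sigma^2 + \sigma^2 + 2\rho\sigma^2\big) = \tfrac{1+\rho}{2}\,\sigma^2
\quad\Longrightarrow\quad
\operatorname{std}\big[h_{1/2}\big] = \sigma\sqrt{\tfrac{1+\rho}{2}} < \sigma \ \text{ whenever } \rho < 1
```

For example, ρ = 0.5 gives std ≈ 0.87σ. Deeper layers are no longer exactly linear in the interpolated weights, so beyond the first layer this is only an intuition.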

14 of 23

What is BN reset doing?

BatchNorm refresher:
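For reference, the standard BatchNorm transform of a channel x uses running mean μ, running variance σ², learned weight γ, learned bias β, and a small ε:

```latex
\mathrm{BN}(x) = \gamma \, \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta
```

When μ and σ² are re-estimated on the network's own activations, the output has mean β and std ≈ γ, no matter how much the input variance has shrunk.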

15 of 23

What is BN reset doing?

16 of 23

Can we generalize the BN reset?

Goal: rescale each channel in the interpolated network so that its mean and std are the midpoints of the means and stds of the same channel in the two endpoint networks.

17 of 23

REPAIR

Goal: rescale each channel in the interpolated network so that its mean and std are the midpoints of the means and stds of the same channel in the two endpoint networks.

  1. Estimate the mean/std of each channel that we want to renormalize in the two endpoint networks. For a particular channel, call these (μ_A, σ_A) and (μ_B, σ_B).
  2. Add a BN layer after the same channel in the interpolated network, with weight (σ_A + σ_B)/2 and bias (μ_A + μ_B)/2.
  3. Reset the running statistics of the added BN layers on the training distribution.
  4. Optionally, fuse the added BNs back into their preceding layers.

https://github.com/KellerJordan/REPAIR/blob/master/notebooks/Train-Merge-REPAIR-VGG11.ipynb
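Here is a minimal numpy sketch of steps 1-4 for a single linear layer. It makes two assumptions. First, the layer's input X is the same for both endpoints, which holds only for the first layer; deeper layers would use each network's own inputs in step 1. Second, channels are already aligned. `repair_linear` and the random "endpoint" weights are hypothetical illustrations, not code from the linked notebook:

```python
import numpy as np

def repair_linear(W, b, X, mu_a, sd_a, mu_b, sd_b, eps=1e-5):
    """Fold a REPAIR-style correction into the interpolated layer y = X W^T + b."""
    # Step 2: the added BN's weight/bias are the midpoints of the endpoint stds/means
    gamma = (sd_a + sd_b) / 2
    beta = (mu_a + mu_b) / 2
    # Step 3: reset the added BN's statistics on the interpolated layer's outputs
    y = X @ W.T + b
    mu, var = y.mean(axis=0), y.var(axis=0)
    # Step 4: fuse gamma * (y - mu) / sqrt(var + eps) + beta back into W and b
    scale = gamma / np.sqrt(var + eps)
    return W * scale[:, None], (b - mu) * scale + beta

rng = np.random.default_rng(0)
X = rng.standard_normal((5000, 32))  # stand-in for training inputs
W_a, W_b = rng.standard_normal((2, 16, 32))
b_a, b_b = rng.standard_normal((2, 16))

# Step 1: per-channel mean/std of each endpoint's outputs
y_a, y_b = X @ W_a.T + b_a, X @ W_b.T + b_b
mu_a, sd_a = y_a.mean(axis=0), y_a.std(axis=0)
mu_b, sd_b = y_b.mean(axis=0), y_b.std(axis=0)

W_mid, b_mid = (W_a + W_b) / 2, (b_a + b_b) / 2
W_rep, b_rep = repair_linear(W_mid, b_mid, X, mu_a, sd_a, mu_b, sd_b)
y_rep = X @ W_rep.T + b_rep
# Ratio of repaired std to target std; each entry should be ~1.0
print(np.round(y_rep.std(axis=0)[:3] / ((sd_a + sd_b) / 2)[:3], 3))
```

Fusing in step 4 means the repaired network has the same architecture and inference cost as the plain interpolated network.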

18 of 23

REPAIR

19 of 23

REPAIR

20 of 23

REPAIRing residual blocks

Idea: Let’s also use REPAIR to correct the statistics of residual block outputs.

=> Gives another ~15% reduction in the size of the barrier.

21 of 23

A special case: LayerNorm

LayerNorm-based networks are unique in already having a low interpolation barrier without BN reset / REPAIR.
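One possible intuition is offered here as an assumption, not as a claim from the talk. LayerNorm normalizes each sample across its channels. So if interpolation shrinks all channels of a layer by a roughly common factor, the next LayerNorm largely undoes that shrinkage. The toy check below omits LayerNorm's affine weight and bias, and `layer_norm` is a hypothetical helper:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each sample (row) across its channels; affine weight/bias omitted
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
h = rng.standard_normal((4, 64))  # hypothetical hidden activations
collapsed = 0.5 * h               # every channel shrunk by the same factor
print(np.allclose(layer_norm(h), layer_norm(collapsed), atol=1e-3))  # True
```

BatchNorm with stale running statistics has no such per-sample renormalization at inference time. Even LayerNorm would not fully correct a collapse whose size differs across channels.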

22 of 23

Network width

  • As network width increases, the barrier to aligned interpolation has been conjectured to decrease towards zero [1].
  • When using BN reset / REPAIR, this does seem to be the case.

[1] Entezari, Rahim, et al. "The role of permutation invariance in linear mode connectivity of neural networks." arXiv preprint arXiv:2110.06296 (2021).

23 of 23

Thanks!