REPAIR: REnormalizing Permuted Activations for Interpolation Repair
Keller Jordan
keller@kellerjordan.com
With Hanie Sedghi, Olga Saukh, Rahim Entezari, & Behnam Neyshabur
Linear interpolation between weights
[1] Wortsman, Mitchell, et al. "Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time." International Conference on Machine Learning. PMLR, 2022.
[2] Wortsman, Mitchell, et al. "Robust fine-tuning of zero-shot models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
What happens when we linearly interpolate between independently trained models?
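In code, the interpolation itself is just an element-wise average of the two networks' parameters. A minimal PyTorch sketch (assuming model_a and model_b share an architecture and were trained from different random initializations; the helper name is illustrative):

```python
import copy

def interpolate_weights(model_a, model_b, alpha):
    """Return a model whose parameters are (1 - alpha) * theta_A + alpha * theta_B."""
    model_interp = copy.deepcopy(model_a)
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    sd_interp = {}
    for k in sd_a:
        if sd_a[k].is_floating_point():
            sd_interp[k] = (1 - alpha) * sd_a[k] + alpha * sd_b[k]
        else:
            sd_interp[k] = sd_a[k]  # integer buffers, e.g. BatchNorm's num_batches_tracked
    model_interp.load_state_dict(sd_interp)
    return model_interp

# Sweeping alpha from 0 to 1 and evaluating test accuracy at each point traces out
# the "barrier": for independently trained models, the midpoint (alpha = 0.5)
# typically performs far worse than either endpoint.
```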
Why does this happen?
Feature misalignment
Permutation invariance
How do we address feature misalignment?
By leveraging permutation invariance: we can freely permute the features (neurons/channels) of a network without affecting the function it represents, as long as the matching permutation is also applied to the weights of the adjacent layers.
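A minimal sketch of this invariance for a hypothetical two-layer MLP in PyTorch: permuting one layer's output units, together with the corresponding input columns of the next layer, leaves the network's outputs unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
perm = torch.randperm(16)  # arbitrary permutation of the 16 hidden units

permuted = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
with torch.no_grad():
    # Permute the rows of the first layer (its output units) ...
    permuted[0].weight.copy_(mlp[0].weight[perm])
    permuted[0].bias.copy_(mlp[0].bias[perm])
    # ... and the columns of the second layer (its input units) to compensate.
    permuted[2].weight.copy_(mlp[2].weight[:, perm])
    permuted[2].bias.copy_(mlp[2].bias)

x = torch.randn(5, 8)
assert torch.allclose(mlp(x), permuted(x), atol=1e-6)  # same function, different weights
```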
Finding a permutation
Li, Yixuan, et al. "Convergent learning: Do different neural networks learn the same representations?." arXiv preprint arXiv:1511.07543 (2015).
Filter alignment
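A sketch of activation-based filter alignment in the spirit of the Convergent Learning reference above (the use of correlation as the similarity measure and the helper name are illustrative; the resulting permutation must then also be applied to the weights, as in the previous sketch):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def match_units(acts_a, acts_b):
    """Match units of network B to units of network A by activation correlation.

    acts_a, acts_b: arrays of shape [num_samples, num_units] holding one layer's
    activations in each network on the same inputs (for conv layers, flatten the
    batch and spatial dimensions into num_samples).
    Returns perm such that unit perm[i] of B is aligned with unit i of A.
    """
    a = (acts_a - acts_a.mean(0)) / (acts_a.std(0) + 1e-8)
    b = (acts_b - acts_b.mean(0)) / (acts_b.std(0) + 1e-8)
    corr = a.T @ b / len(a)                   # cross-correlation matrix [units_a, units_b]
    _, perm = linear_sum_assignment(-corr)    # maximize total correlation (Hungarian algorithm)
    return perm
```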
For shallow+wide MLPs, this works
Entezari, Rahim, et al. "The role of permutation invariance in linear mode connectivity of neural networks." arXiv preprint arXiv:2110.06296 (2021).
For deep networks, something breaks down
‘Trick’: just re-estimate the BatchNorm statistics!
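Concretely, in PyTorch terms (a sketch, assuming the interpolated model contains BatchNorm layers): wipe the running statistics and re-estimate them with a few forward passes over training data.

```python
import torch
import torch.nn as nn

def reset_bn_stats(model, loader, num_batches=100):
    """Re-estimate the BatchNorm running statistics of an interpolated model."""
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.reset_running_stats()   # zero running mean, unit running variance
            m.momentum = None         # None => use a cumulative (exact) average
    model.train()                     # BN only updates its statistics in train mode
    with torch.no_grad():
        for i, (x, _) in enumerate(loader):
            if i >= num_batches:
                break
            model(x)
    model.eval()
    return model
```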
Why?
Why does this work?
And how can we solve the remaining VGG11 and deep MLP cases?
Network internals
In the interpolated network, the variance of the hidden units/channels collapses relative to the endpoint networks.
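One way to see this (a sketch using forward hooks; the choice of Conv2d outputs is illustrative): compare per-channel activation standard deviations in the endpoint networks against those of the interpolated network.

```python
import torch
import torch.nn as nn

def channel_stds(model, x, layer_type=nn.Conv2d):
    """Per-channel activation std after each layer of the given type, on a batch x."""
    stds, handles = [], []
    def hook(_, __, out):
        stds.append(out.transpose(0, 1).flatten(1).std(dim=1))  # one std per channel
    for m in model.modules():
        if isinstance(m, layer_type):
            handles.append(m.register_forward_hook(hook))
    with torch.no_grad():
        model(x)
    for h in handles:
        h.remove()
    return stds

# Comparing channel_stds(model_interp, x) against the two endpoints, layer by layer,
# typically shows the interpolated network's channel-wise variance shrinking with depth.
```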
What is BN reset doing?
BatchNorm refresher:
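At train time, BatchNorm normalizes each channel by its batch statistics and then applies a learned per-channel affine map; at test time the batch statistics are replaced by the running estimates:

```latex
\mathrm{BN}(x) \;=\; \gamma \,\frac{x - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2} + \epsilon}} \;+\; \beta
```

Here mu_B and sigma_B^2 are the per-channel mean and variance over the batch, and gamma, beta are the learned per-channel scale and shift.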
What is BN reset doing?
Can we generalize the BN reset?
Goal: rescale each channel in our interpolated network such that its mean/std is the midpoint of the means/stds of the same channel in the two endpoint networks.
REPAIR
Goal: rescale each channel in our interpolated network such that its mean/std is the midpoint of the means/stds of the same channel in the two endpoint networks.
https://github.com/KellerJordan/REPAIR/blob/master/notebooks/Train-Merge-REPAIR-VGG11.ipynb
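A minimal sketch of the idea, not the linked notebook's exact implementation (the notebook instead inserts temporary BatchNorm layers to track and fuse these statistics; the helper and argument names below are illustrative): estimate per-channel activation statistics in all three networks, then fold a per-channel affine correction into the interpolated one.

```python
import torch
import torch.nn as nn

def collect_channel_stats(model, layers, loader, num_batches=50):
    """Per-channel activation (mean, std) after each listed layer, over a few batches."""
    sums, sqsums, counts = {}, {}, {}
    def make_hook(idx):
        def hook(_, __, out):
            flat = out.transpose(0, 1).flatten(1)          # [C, batch * spatial]
            sums[idx] = sums.get(idx, 0) + flat.sum(1)
            sqsums[idx] = sqsums.get(idx, 0) + (flat ** 2).sum(1)
            counts[idx] = counts.get(idx, 0) + flat.shape[1]
        return hook
    handles = [layer.register_forward_hook(make_hook(i)) for i, layer in enumerate(layers)]
    with torch.no_grad():
        for i, (x, _) in enumerate(loader):
            if i >= num_batches:
                break
            model(x)
    for h in handles:
        h.remove()
    stats = []
    for i in range(len(layers)):
        mean = sums[i] / counts[i]
        std = (sqsums[i] / counts[i] - mean ** 2).clamp(min=0).sqrt()
        stats.append((mean, std))
    return stats

def repair(model_a, model_b, model_interp, layers_a, layers_b, layers_i, loader):
    """Drive each channel of the interpolated model toward the average of the
    endpoints' channel-wise mean/std, via a per-channel affine correction."""
    stats_a = collect_channel_stats(model_a, layers_a, loader)
    stats_b = collect_channel_stats(model_b, layers_b, loader)
    stats_i = collect_channel_stats(model_interp, layers_i, loader)
    for layer, (mu_a, sd_a), (mu_b, sd_b), (mu_i, sd_i) in zip(layers_i, stats_a, stats_b, stats_i):
        goal_mu, goal_sd = 0.5 * (mu_a + mu_b), 0.5 * (sd_a + sd_b)
        scale = goal_sd / (sd_i + 1e-8)
        shift = goal_mu - scale * mu_i
        # Correction x -> scale * x + shift, applied to the layer's output
        # (shapes below assume 4D conv activations [N, C, H, W]).
        layer.register_forward_hook(
            lambda _, __, out, s=scale, b=shift: out * s.view(1, -1, 1, 1) + b.view(1, -1, 1, 1))
    # Note: correcting one layer shifts the statistics of all later layers, so in
    # practice the corrections are made consistent by tracking statistics with
    # temporarily inserted BatchNorm layers (as in the linked notebook) or by
    # re-estimating layer by layer.
```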
REPAIR
REPAIRing residual blocks
Idea: Let’s also use REPAIR to correct the statistics of residual block outputs.
=> Gives another ~15% reduction in the size of the barrier.
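A sketch of where these extra correction points live, assuming a torchvision-style ResNet built from BasicBlock modules (the helper name is illustrative):

```python
from torchvision.models.resnet import BasicBlock

def residual_block_modules(model):
    """The residual blocks of a torchvision-style ResNet, so their (post-addition)
    outputs can be tracked and corrected just like individual conv outputs."""
    return [m for m in model.modules() if isinstance(m, BasicBlock)]

# Appending these modules to the layers_* lists passed to repair(...) above also
# drives the block outputs' channel-wise mean/std to the midpoint of the endpoints'.
```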
A special case: LayerNorm
LayerNorm-based networks are unique in already having a low interpolation barrier without BN reset / REPAIR, presumably because LayerNorm renormalizes each example's activations at inference time, which undoes the variance collapse on its own.
Network width
[1] Entezari, Rahim, et al. "The role of permutation invariance in linear mode connectivity of neural networks." arXiv preprint arXiv:2110.06296 (2021).
Thanks!