Foundations of Optimisation & Generalisation in Neural Networks
Jeremy Bernstein
What makes deep learning theory hard?
A common pattern of study:
✅ Results look qualitatively complete.
❌ Results often quantitatively vacuous.
❌ Results do not address important practical questions.
Something better is out there.
…Why must I tune ten different hyperparameters?
…How can I control my network’s inductive bias?
Neural architecture as the starting point
Our pattern of study:
✅ Results are quantitatively non-vacuous.
✅ Theory makes useful predictions.
❌ Complete theory is still a work-in-progress.
Two ideas in this talk
Idea 1. Theoretical payoff: Bayesian risk bounds for single classifiers. Practical payoff: inductive bias control.
Idea 2. Theoretical payoff: deep network descent lemma. Practical payoff: hyperparameter transfer.
Part I
Max-Margin Neural Networks as Bayes Point Machines
The modern machine learning paradigm
The punchline of this section
Driving question
Why do neural nets generalise… even when non-generalising functions exist in the function space?
Information theory
Voting theory
Convex geometry
Bayes point machine (noun)
A single classifier that approximates the aggregate prediction of an ensemble of classifiers.
Herbrich, Graepel & Campbell (JMLR 2001)
Connections to other disciplines
We propose the NN–BPM correspondence
Ensemble of NNs (input → majority vote):
✅ Good theory. ❌ Expensive to construct.
≈
One large-margin NN (input → sign):
❌ Poor theory. ✅ Used in practice.
Part I: NN–BPM correspondence
Introduction ←
Experiments
Theory
Outlook
Part I: NN–BPM correspondence
Introduction
Experiments ←
Theory
Outlook
Let’s take control of normalised margin
Linear classifier: x → sign(wᵀx); normalised margin = margin / (‖w‖ · ‖x‖)
Deep network: x → sign(WL ∘ 𝜙 ∘ WL−1 ∘ … ∘ 𝜙 ∘ W1 x); normalised margin = margin / (Πk ‖Wk‖ · ‖x‖)
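To make the definition concrete, here is a minimal numpy sketch (the helper name and the ReLU choice are mine) computing the normalised margin of a deep network. Because ReLU is positively homogeneous, the quantity is invariant to rescaling any single layer:

```python
import numpy as np

def normalised_margin(weights, x, y, phi=lambda z: np.maximum(z, 0.0)):
    """Normalised margin of a ReLU MLP on example (x, y), y in {-1, +1}:
    margin divided by (product of spectral norms of the layers) * ||x||."""
    h = x
    for W in weights[:-1]:
        h = phi(W @ h)                       # hidden layers with nonlinearity
    f = weights[-1] @ h                      # raw scalar network output
    margin = y * f.item()
    scale = np.prod([np.linalg.norm(W, 2) for W in weights]) * np.linalg.norm(x)
    return margin / scale
```

Since each layer is Lipschitz with constant ‖Wk‖, the normalised margin always lies in [−1, 1].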
Normalised margin control
Normalised margin control carefully scales:
[Diagram: input x passes through carefully scaled layers W1, W2, W3, W4 to output f(x; w).]
This yields an interpolator with normalised margin ∝ a.
Effect of normalised margin control
For networks with 100% train accuracy:
Part I: NN–BPM correspondence
Introduction
Experiments
Theory ←
Outlook
Our theoretical results
Result 1. (Concentration)
Conditioned on fitting a training set, as normalised margin → ∞, the function space of a width-∞ NN concentrates on a BPM.
Result 2. (Generalisation)
Bayes point machines have nice generalisation properties.
Proof sketch.
weighted Grünbaum’s inequality (Caplin & Nalebuff, 1991)
Prob[halfspace( )] > 1/e
Result 2. (Generalisation)
Bayes point machines have nice generalisation properties.
Given an ensemble Q of classifiers, we can classify x:
risk(BPM) ≈ risk(majority vote)
 ≪ 𝔼 risk(randomised predictor)  (C-bound; Lacasse et al., 2006)
 ≲ # bits in Q / # train examples  (PAC-Bayes; McAllester, 1999)
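For reference, the last step can be made precise via one standard form of the PAC-Bayes bound (the KL term plays the role of "# bits in Q"; notation mine):

```latex
% PAC-Bayes (McAllester-style): for a prior P fixed before seeing data,
% with probability at least 1 - \delta over an i.i.d. sample of size m,
% simultaneously for all posteriors Q:
\mathbb{E}_{h \sim Q}\,\mathrm{risk}(h)
  \;\le\; \mathbb{E}_{h \sim Q}\,\widehat{\mathrm{risk}}(h)
  \;+\; \sqrt{\frac{\mathrm{KL}(Q \,\|\, P) + \ln\tfrac{2\sqrt{m}}{\delta}}{2m}}
```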
Part I: NN–BPM correspondence
Introduction
Experiments
Theory
Outlook ←
NN–BPM correspondence
Controlled optimisation as a scientific tool
Application: Uncertainty quantification
Anti-regularise for good generalisation; regularise for good uncertainty.
References
📚 Kernel interpolation as a Bayes point machine
📚 Investigating generalization by controlling normalized margin
Part II
Neural Architecture Aware Optimisation
Optimisation theory versus practice
Practice
Perturb every operator in a composition!
Theory
Flatten weights; neglect compositional structure!
At each iteration, the theory view flattens the weights and solves:
min_ΔW gᵀΔW + λ‖ΔW‖²
(first-order approximation of the loss plus a quadratic trust region)
[Diagram: in practice, perturbations ΔW1, …, ΔWL are applied to layers W1, …, WL.]
Optimisation theory versus practice
New theory
Practice
Perturb every operator in a composition!
💡 A weight perturbation induces a functional perturbation:
Δf(x) := f(W + ΔW; x) − f(W; x)
[Diagram: perturbations ΔW1, …, ΔWL applied to layers W1, …, WL.]
💡 NN architecture connects the structure of ΔW to Δf:
‖Δf(x)‖ / ‖f(x)‖ ≤ C · [Πk (1 + ‖ΔWk‖ / ‖Wk‖) − 1]
💡 Minimise the loss subject to bounded functional distance:
min_ΔW gᵀΔW   s.t.   ‖Δf(x)‖ / ‖f(x)‖ small
‖f(x)‖
Deep network descent lemma
Result 1. (Deep network descent lemma)
An update will decrease the objective function provided that, for all layers k = 1, …, L, the perturbation stays within deep relative trust:
‖Δf(x)‖ / ‖f(x)‖ ≤ C · [Πk (1 + ‖ΔWk‖ / ‖Wk‖) − 1]
[Figure: adding layers deforms the quadratic trust region (depth L = 1) into the deep relative trust region (depth L ≫ 1).]
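The depth dependence of the bound is easy to evaluate numerically; a small sketch (function name mine):

```python
import numpy as np

def deep_relative_trust(rel_changes, C=1.0):
    """Bound on ||Δf|| / ||f||: C * (prod_k (1 + ||ΔW_k||/||W_k||) - 1)."""
    return C * (np.prod([1.0 + r for r in rel_changes]) - 1.0)

# The bound compounds with depth: the same 1% per-layer change matters
# far more for a 100-layer network than for a 1-layer one.
shallow = deep_relative_trust([0.01])        # 0.01
deep = deep_relative_trust([0.01] * 100)     # (1.01)^100 - 1 ≈ 1.70
```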
Training deeper MLPs with Fromage
[Plot: training MLPs of increasing depth with Fromage.]
Hyperparameter transfer
“may unlock a simpler workflow for training deeper and more complex neural networks”
Connection to neuroscience
Per-synapse relative update
Each synapse either grows by a factor (1 + η) or shrinks by a factor (1 − η).
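A minimal sketch of such a per-synapse relative update (the function name is mine; see the authors' multiplicative weight update papers for the full rule):

```python
import numpy as np

def per_synapse_relative_update(w, g, eta=0.01):
    """Each weight changes multiplicatively: it shrinks by (1 - eta) when the
    gradient pushes its magnitude down, and grows by (1 + eta) otherwise."""
    return w * (1.0 - eta * np.sign(g) * np.sign(w))
```

For example, `per_synapse_relative_update(np.array([1.0, -2.0]), np.array([1.0, 1.0]))` shrinks the first weight to 0.99 and grows the second to −2.02.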
References
📚 On the distance between two neural networks and the stability of learning
📚 Learning by turning: Neural architecture aware optimisation
📚 Learning compositional functions via multiplicative weight updates
What’s next?
Making deep learning a mature technology
Neural hardware
Artificial hardware (“weight update”): ✅ access to all microscopic details; ✅ can do numerical analysis; ❌ energy intensive.
Biological hardware (“synaptic plasticity”): ❌ hard to experiment on; ✅ energy efficient; ✅ that’s us!
Alex Farhang
Kushal Tirumala
Jiawei Zhao
Yang Liu
Arash Vahdat
Ming-Yu Liu
Anima Anandkumar
Markus Meister
Yisong Yue
Thank you!
🛝 Slides on my website: https://jeremybernste.in.