1 of 38

Foundations of Optimisation & Generalisation in Neural Networks

Jeremy Bernstein

https://jeremybernste.in


3 of 38

What makes deep learning theory hard?

A common pattern of study:

  • Take well-known classical theory;
  • Try to bend it to match deep learning.

✅ Results look qualitatively complete.

❌ Results often quantitatively vacuous.

❌ Results do not address important practical questions.


Something better is out there.

…Why must I tune ten different hyperparameters?

…How can I control my network’s inductive bias?

4 of 38

Neural architecture as the starting point

Our pattern of study:

  1. Perform targeted explorations of the NN function space;
  2. Introduce theoretical tools selectively.

✅ Results are quantitatively non-vacuous.

✅ Theory makes useful predictions.

❌ Complete theory is still a work-in-progress.


5 of 38

Two ideas in this talk

  1. Neural network–Bayes point machine correspondence
  2. Neural architecture aware optimisation


  1. NN–BPM correspondence. Theoretical payoff: Bayesian risk bounds for single classifiers. Practical payoff: inductive bias control.
  2. Neural architecture aware optimisation. Theoretical payoff: deep network descent lemma. Practical payoff: hyperparameter transfer.

6 of 38

Part I

Max-Margin Neural Networks as Bayes Point Machines

7 of 38

The modern machine learning paradigm

The punchline of this section

  • Interpolating classifiers form an ensemble.
  • To predict, we can use the majority vote.
  • A NN can fit its own majority vote in one shot.


Driving question

Why do neural nets generalise… even when non-generalising functions exist in the function space?

Information theory

Voting theory

Convex geometry

8 of 38


Bayes point machine (noun)

A single classifier that approximates the aggregate prediction of an ensemble of classifiers.

Herbrich, Graepel & Campbell (JMLR 2001)

9 of 38

Connections to other disciplines

  • Mean voter theorem in social choice theory. When can a single voter speak for an electorate? 📚 Caplin & Nalebuff, 1991
  • Tukey median in high-dimensional statistics. How can we generalise the median to high dimensions? 📚 Tukey, 1977
  • Mass partitions in geometry. How can we divide a set into more-or-less equal halves? 📚 Grünbaum, 1960


10 of 38

We propose the NN–BPM correspondence


Ensemble of NNs (input → majority vote):

  ✅ Good theory.
  ❌ Expensive to construct.

One large-margin NN (input → sign):

  ❌ Poor theory.
  ✅ Used in practice.

11 of 38

Part I: NN–BPM correspondence

Introduction

Experiments

Theory

Outlook


13 of 38

Let’s take control of normalised margin

  • Train ensembles of neural networks to 100% train accuracy.
  • Control the normalised margin of each ensemble member.
  • Measure the test accuracy.


Linear classifier: x → sign(wᵀx), with

  normalised margin = margin / (‖w‖ · ‖x‖).

Deep network: x → sign(WL ∘ 𝜙 ∘ WL−1 ∘ … ∘ 𝜙 ∘ W1 ∘ x), with

  normalised margin = margin / (Πk ‖Wk‖ · ‖x‖).
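To make the two definitions concrete, here is a minimal NumPy sketch (not from the talk) that computes the normalised margin of a toy ReLU MLP on one example; the choice of Frobenius norms for the layers is an assumption of this sketch:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy depth-3 ReLU MLP: x -> W3 @ phi(W2 @ phi(W1 @ x)).
Ws = [rng.standard_normal((16, 8)),
      rng.standard_normal((16, 16)),
      rng.standard_normal((1, 16))]

def f(x, Ws):
    h = x
    for W in Ws[:-1]:
        h = np.maximum(W @ h, 0.0)      # phi = ReLU
    return (Ws[-1] @ h).item()

x, y = rng.standard_normal(8), 1.0      # one example with label y in {-1, +1}

# normalised margin = margin / (prod_k ||W_k|| * ||x||),
# using Frobenius norms for the layers.
raw_margin = y * f(x, Ws)
scale = np.prod([np.linalg.norm(W) for W in Ws]) * np.linalg.norm(x)
normalised_margin = raw_margin / scale
```

Since ReLU is 1-Lipschitz and the operator norm is bounded by the Frobenius norm, the normalised margin always lands in [−1, 1].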

14 of 38

Normalised margin control


Normalised margin control carefully scales:

  1. the outputs, by choice of loss function: La(w) ← ∑(x,y) (f(x;w) − a⋅y)².
  2. the weights, by projected gradient descent: Wk ← Wk / ‖Wk‖F for each layer k.
  3. the inputs, by projection: x ← x / ‖x‖2 for each input x.

[Diagram: layers W1, …, W4 mapping input x to output f(x;w).]

This yields an interpolator with normalised margin a.
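The three ingredients can be sketched in a few lines of NumPy. This toy version (an illustration, not the talk's implementation) uses a one-layer linear model so the gradient can be written by hand; the data, learning rate, and iteration count are arbitrary assumptions:

```python
import numpy as np

rng = np.random.default_rng(1)
a = 0.1                                  # target normalised margin

# Toy separable data; ingredient 3: project inputs onto the unit sphere.
X = rng.standard_normal((32, 8))
X /= np.linalg.norm(X, axis=1, keepdims=True)
y = np.sign(X @ rng.standard_normal(8))  # labels from a random direction

w = rng.standard_normal(8)
for _ in range(2000):
    w /= np.linalg.norm(w)               # ingredient 2: project weights
    grad = 2 * X.T @ (X @ w - a * y)     # ingredient 1: loss sum (f - a*y)^2
    w -= 0.05 * grad
w /= np.linalg.norm(w)

margins = y * (X @ w)                    # normalised, since ||w|| = ||x|| = 1
```

With ‖w‖ = ‖x‖ = 1, the entries of `margins` are already normalised margins, and the regression target a·y pushes them toward the chosen value a.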

15 of 38

Effect of normalised margin control

For networks with 100% train accuracy:


  • Ensemble size ↑ implies test accuracy ↑.
  • Normalised margin ↑ implies test accuracy ↑.

16 of 38

Part I: NN–BPM correspondence

Introduction

Experiments

Theory

Outlook

17 of 38

Our theoretical results

Result 1. (Concentration)

Conditioned on fitting a training set, as normalised margin → ∞, the function space of a width-∞ NN concentrates on a BPM.

Result 2. (Generalisation)

Bayes point machines have nice generalisation properties.


Bayes point machine (noun)

A single classifier that approximates the aggregate prediction of an ensemble of classifiers.

Herbrich, Graepel & Campbell (JMLR 2001)

18 of 38

Result 1. (Concentration)

Conditioned on fitting a training set, as normalised margin → ∞, the function space of a width-∞ NN concentrates on a BPM.


Proof sketch.

  1. The function space of a width-∞ NN, conditioned on fitting a training set, is a GP posterior distribution.
  2. Taking normalised margin → ∞ causes this GP posterior distribution to concentrate on its mean.
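The two steps can be illustrated numerically: condition a GP on binary labels and read off the posterior mean. The RBF kernel below stands in for the infinite-width network's kernel, an assumption made purely for this illustration:

```python
import numpy as np

rng = np.random.default_rng(2)

# A small training set with binary labels.
X = rng.standard_normal((20, 4))
y = np.sign(rng.standard_normal(20))

# RBF kernel standing in for the infinite-width network's kernel.
def rbf(A, B):
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2)

K = rbf(X, X) + 1e-8 * np.eye(20)        # jitter for numerical stability

# Posterior mean of the GP conditioned on (nearly) interpolating y;
# per the concentration result, the high-margin network's function
# approaches this mean, which plays the role of the Bayes point machine.
alpha = np.linalg.solve(K, y)
fit = rbf(X, X) @ alpha                  # posterior mean at the train points
```

At the training points the posterior mean interpolates the labels up to the jitter, so its sign reproduces the training classifications.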

19 of 38

Result 1. (Concentration)

Conditioned on fitting a training set, as normalised margin → ∞, the function space of a width-∞ NN concentrates on a BPM.


Proof sketch.

  • The function space of a width-∞ NN, conditioned on fitting a training set, is a GP posterior distribution.
  • Taking normalised margin → ∞ causes this GP posterior distribution to concentrate on its mean.
  • The mean prediction approximates the majority vote.


weighted Grünbaum’s inequality (Caplin & Nalebuff, 1991): Prob[halfspace( )] > 1/e

20 of 38

Result 2. (Generalisation)

Bayes point machines have nice generalisation properties.

20

Given an ensemble Q of classifiers, we can classify x:

  • randomly, x ↦ sign f(w;x) for w ~ Q (aka the Gibbs classifier);
  • by majority vote, x ↦ sign 𝔼w sign f(w;x) (aka the Bayes classifier).

risk(BPM) ≈ risk(majority vote)
          ≪ 𝔼 risk(randomised predictor)    (C-bound; Lacasse et al., 2006)
          ≲ # bits in Q / # train examples    (PAC-Bayes; McAllester, 1999)
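The two classification rules, Gibbs and majority vote, can be sketched directly. Here the ensemble Q is a stand-in cloud of linear classifiers; the distribution and its parameters are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(3)

# Stand-in ensemble Q: linear classifiers scattered around one direction.
w_star = np.array([1.0, 0.5, -0.25])
Q = w_star + 0.3 * rng.standard_normal((50, 3))

x = np.array([0.8, -0.1, 0.2])
member_votes = np.sign(Q @ x)               # each member's ±1 prediction

# Gibbs classifier: follow a single randomly drawn member.
gibbs = member_votes[rng.integers(len(Q))]

# Bayes (majority-vote) classifier: sign of the average member vote.
bayes = np.sign(member_votes.mean())
```

The Gibbs prediction is a random variable over the draw of w, while the majority vote is a single deterministic classifier, which is the object the BPM approximates.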

21 of 38

Part I: NN–BPM correspondence

Introduction

Experiments

Theory

Outlook

22 of 38

NN–BPM correspondence

  • A new way to understand generalisation in deep learning.
  • Based on old ideas about kernel methods.
  • What’s better than reading old papers? Reading old textbooks.


23 of 38

Controlled optimisation as a scientific tool

  • We are charting the behaviour of a function space that we do not fully understand.
  • Optimisation is our friend.


Normalised margin control carefully scales:

  • the outputs, by choice of loss function: La(w) ← ∑(x,y) (f(x;w) − a⋅y)².
  • the weights, by projected gradient descent: Wk ← Wk / ‖Wk‖F for each layer k.
  • the inputs, by projection: x ← x / ‖x‖2 for each input x.

24 of 38

Application: Uncertainty quantification

  • Let’s train an ensemble of NNs to estimate model uncertainty.
  • We need to avoid the ensemble collapsing on a single function.
  • Normalised margin control can “anti-regularise” the ensemble.
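One way to read off uncertainty from a non-collapsed ensemble is the disagreement between member votes; a minimal sketch with stand-in predictions (the prediction matrix below is an illustrative assumption):

```python
import numpy as np

rng = np.random.default_rng(7)

# Stand-in ±1 predictions: rows are ensemble members, columns test points.
preds = np.sign(rng.standard_normal((10, 5)))

vote = preds.mean(axis=0)                # mean vote, in [-1, 1]
uncertainty = 1.0 - np.abs(vote)         # 0 = unanimous, 1 = an even split
```

Anti-regularising the ensemble keeps the members spread out, so this disagreement signal stays informative instead of collapsing to zero.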


[Diagram: an axis from regularise to anti-regularise, trading off good generalisation against good uncertainty.]

25 of 38


26 of 38

Part II

Neural Architecture Aware Optimisation

27 of 38

Optimisation theory versus practice

Practice: perturb every operator in a composition!

Theory: flatten weights; neglect compositional structure!


At each iteration:

  1. minimise the first-order Taylor approximation of the loss;
  2. subject to a quadratic penalty:

  minΔW gᵀΔW + λ‖ΔW‖²
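The quadratic-penalty step has a closed form: setting the derivative g + 2λ·ΔW to zero gives ΔW = −g/(2λ), i.e. plain gradient descent with step size 1/(2λ). A short numerical check, with illustrative dimensions and λ:

```python
import numpy as np

rng = np.random.default_rng(4)
g = rng.standard_normal(10)              # flattened gradient
lam = 0.5                                # quadratic penalty strength

def objective(dW):
    return g @ dW + lam * dW @ dW        # g^T dW + lam * ||dW||^2

# Setting the derivative g + 2*lam*dW to zero gives the minimiser:
# a gradient step with learning rate 1 / (2*lam).
dW_star = -g / (2 * lam)

# No random perturbation of the minimiser does better.
worse = [objective(dW_star + 0.1 * rng.standard_normal(10))
         for _ in range(100)]
```

The optimal value is −‖g‖²/(4λ), and because the objective is strictly convex, every perturbed point scores worse.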

[Diagram: layers W1, …, WL with perturbations ΔW1, …, ΔWL; the loss is modelled by a first-order approximation inside a quadratic trust region.]

28 of 38

Optimisation theory versus practice

Practice: perturb every operator in a composition!

New theory:

💡 A weight perturbation induces a functional perturbation:

  Δf(x) := f(W+ΔW; x) − f(W; x)

💡 NN architecture connects the structure of ΔW to Δf:

  ‖Δf(x)‖ / ‖f(x)‖ ≤ C · [Πk(1 + ‖ΔWk‖/‖Wk‖) − 1]

💡 Minimise the loss subject to bounded functional distance:

  minΔW gᵀΔW  s.t.  ‖Δf(x)‖/‖f(x)‖ small

[Diagram: layers W1, …, WL with perturbations ΔW1, …, ΔWL.]
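The architectural inequality connecting ΔW to Δf can be checked numerically for deep linear networks, where a telescoping argument gives ‖Δf(x)‖ ≤ [Πk(1 + ‖ΔWk‖/‖Wk‖) − 1] · Πk‖Wk‖ · ‖x‖ exactly. The sketch below assumes that linear special case, with C = 1 and operator norms:

```python
import numpy as np

rng = np.random.default_rng(5)

def opnorm(M):
    return np.linalg.norm(M, 2)          # spectral (operator) norm

# A deep *linear* network and a layerwise perturbation.
Ws  = [rng.standard_normal((8, 8)) for _ in range(5)]
dWs = [0.05 * rng.standard_normal((8, 8)) for _ in range(5)]
x = rng.standard_normal(8)

def forward(mats, v):
    for M in mats:
        v = M @ v
    return v

delta_f = (forward([W + dW for W, dW in zip(Ws, dWs)], x)
           - forward(Ws, x))

# Telescoping bound:
# ||delta_f|| <= [prod_k (1 + r_k) - 1] * prod_k ||W_k|| * ||x||,
# with r_k = ||dW_k|| / ||W_k|| the per-layer relative perturbation.
ratios = [opnorm(dW) / opnorm(W) for W, dW in zip(Ws, dWs)]
bound = ((np.prod([1 + r for r in ratios]) - 1)
         * np.prod([opnorm(W) for W in Ws]) * np.linalg.norm(x))
```

Expanding the perturbed product into 2⁵ − 1 cross terms and bounding each by the corresponding product of norms yields the bound, which is why the per-layer *relative* sizes ‖ΔWk‖/‖Wk‖ are the natural trust-region variables.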

29 of 38

Deep network descent lemma

Result 1. (Deep network descent lemma)

An update will decrease the objective function provided that for all layers k = 1, …, L:

  1. direction: ΔWk ∝ −gk;
  2. magnitude: ‖ΔWk‖ < ‖Wk‖ · O(1/L).

[Diagram: adding layers replaces the quadratic trust region (depth L = 1) with deep relative trust (depth L ≫ 1), whose shape follows ‖Δf(x)‖/‖f(x)‖ ≤ C · [Πk(1 + ‖ΔWk‖/‖Wk‖) − 1].]

30 of 38

Training deeper MLPs with Fromage

  • We explored “Frobenius matched gradient descent” 🧀.
  • It enforces ‖ΔWk‖ < ‖Wk‖ · η, where η sets the learning rate.
  • It is similar to LARS (You et al., 2017).
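The relative-update constraint ‖ΔWk‖ < ‖Wk‖ · η can be sketched as a per-layer rule. This illustration omits Fromage's additional details (such as its normalising prefactor), so treat it as a sketch rather than the published optimiser:

```python
import numpy as np

def relative_update(W, g, eta=0.01):
    """Move against the gradient by a fixed fraction eta of the layer's
    own norm, so that ||W_new - W|| = eta * ||W|| exactly."""
    g_norm = np.linalg.norm(g)
    if g_norm == 0.0:
        return W                         # no gradient, no update
    return W - eta * (np.linalg.norm(W) / g_norm) * g

rng = np.random.default_rng(6)
W = rng.standard_normal((4, 4))
g = rng.standard_normal((4, 4))
W_new = relative_update(W, g)
delta = np.linalg.norm(W_new - W)        # equals eta * ||W|| by construction
```

Scaling the step by ‖Wk‖/‖gk‖ makes the update size relative to each layer's own scale, which is exactly what the deep network descent lemma asks for.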


  • Fromage trains deeper MLPs than Adam or SGD.
  • Deeper nets required a smaller learning rate.

[Plot: training curves across increasing depth.]

31 of 38

Hyperparameter transfer

  • Fromage-variant optimiser
  • Tested on 6 deep learning tasks
  • Same hyperparameters worked on 5 out of 6 tasks


“may unlock a simpler workflow for training deeper and more complex neural networks”

32 of 38

Connection to neuroscience

Per-synapse relative update: at each step, every synapse either grows by a factor of (1 + η) or shrinks by a factor of (1 − η).

33 of 38


34 of 38

What’s next?

35 of 38

Making deep learning a mature technology

  • Learning theory & optimisation theory need to account for the structure of the NN parameter–function map.
  • This should enable new functionality:
    • uncertainty quantification;
    • hyperparameter transfer.
  • Old theory (used selectively) plays an important role:
    • Bayes point machines;
    • trust region methods.
  • More work is needed!

36 of 38

Neural hardware

Artificial hardware (“weight update”): ✅ access to all microscopic details; ✅ can do numerical analysis; ❌ energy intensive.

Biological hardware (“synaptic plasticity”): ❌ hard to experiment on; ✅ energy efficient; ✅ that’s us!

37 of 38

Alex Farhang

Kushal Tirumala

Jiawei Zhao

Yang Liu

Arash Vahdat

Ming-Yu Liu

Anima Anandkumar

Markus Meister

Yisong Yue

38 of 38

Thank you!

  • Deep network descent lemma
  • NN–BPM correspondence

🛝 Slides on my website: https://jeremybernste.in.
