Understanding Deep Learning as a Physicist: What would Einstein do?
(ICML Workshop AI@Scale, 07/22/2024)
Yuhai Tu
IBM T. J. Watson Research Center, NY
“Everything should be made as simple as possible, but not simpler.” -- attributed to Einstein
1905 – Einstein’s miracle year
Particle-Wave Duality
The Einstein relation (Fluctuation-Dissipation Theorem)
Universal constant speed of light c
E = mc²
Stochastic Learning Dynamics, Activity-Weight Duality, and Generalization in Neural Nets
Model
(e.g., MLP; LLM)
Training
Data
Test
Data
The “Central Dogma” of Machine Learning
1) Dynamics of learning: How does the system find solutions? Which solution does it find? -- Topic I
2) Generalization: What makes a solution generalizable (good)? How do we regularize the system to find such solutions? -- Topic II
Two key sets of questions in DL
[Schematic: the optimization-generalization-regularization cycle]
-- Optimization: “learn” a model (solution) by fitting the training data
-- Generalization: verify the learned model by using test data
-- Regularization: find solutions with better generalization by imposing constraints
Ab initio protein folding (an over-constrained problem)
-- The native protein structure: a unique solution
-- Hard problem: how to find THE solution?
-- Energy landscape
Deep learning (an under-constrained problem)
-- Many solutions (relatively easy to find)
-- Hard problem: which solution is more generalizable?
-- Loss landscape
On landscape and solutions (minima) in complex systems: Physics versus ML
The “smart” SGD noise and the inverse variance-flatness relation
Topic I. Stochastic learning dynamics in DL: Descending down a fluctuating landscape
Historical background I: the McCulloch-Pitts model for artificial neurons
Neurons can process information. Each neuron is an input-output device: it receives and provides information in the form of spikes.
[Figure: a single artificial neuron with inputs x1, …, xn, weights w1, …, wn, bias b, and output y]
“A Logical calculus of the ideas immanent in nervous activity”
Warren S. McCulloch and Walter H. Pitts
(Bulletin of Mathematical Biophysics 5, 115-133, 1943)
[Figure: the Heaviside step activation H(z), jumping from 0 to 1 at the threshold]
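The McCulloch-Pitts neuron can be sketched in a few lines of Python (a minimal illustration, not code from the talk; the AND-gate weights are a standard textbook choice):

```python
def mp_neuron(x, w, b):
    # McCulloch-Pitts neuron: weighted sum of the inputs plus a bias b,
    # passed through the Heaviside step function H(z)
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1 if z >= 0 else 0

# a single neuron implements an AND gate: it fires only when both inputs are 1
outputs = [mp_neuron([a, c], [1.0, 1.0], -1.5) for a in (0, 1) for c in (0, 1)]
print(outputs)  # -> [0, 0, 0, 1]
```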
Historical background II: The Rosenblatt perceptron model
[Schematic: inputs → hidden layer(s) → output]
The simplest perceptron has no hidden layer
In a 1958 press conference organized by the US Navy, Rosenblatt made statements about the perceptron that caused a heated controversy among the fledgling AI community; based on Rosenblatt's statements, The New York Times reported the perceptron to be "the embryo of an electronic computer that [the Navy] expects will be able to walk, talk, see, write, reproduce itself and be conscious of its existence."
What is new this time for deep learning to make it so powerful?
The feedforward neural network was invented more than 60 years ago…
A simple supervised learning problem: Classification of handwritten digits
-- Number of hidden layers: L = 2-5 (100's for complex problems)
-- Input: 784 (= 28×28) pixels per image
-- Training data: M = 60,000 images
This “horse” has been beaten to death, many times over…
MNIST dataset: each image k (k = 1, 2, …, M) comes with its correct classification; the weights are the parameters to be learned.
We want to find out how the horse died.
What is machine learning for a statistical physicist?
The predicted output depends on the input and the weights (parameters): ŷ_k = f(x_k; w).
Machine learning is to find the weights that minimize the loss function
L(w) = (1/M) Σ_{k=1..M} d(ŷ_k, y_k),
where d is a positive-definite distance measure, e.g., the cross-entropy loss.
The Gradient Descent (GD) learning algorithm: w ← w − α ∇_w L(w).
Two problems with GD: 1) too clumsy for large training datasets (M ≫ 1); 2) it gets stuck in local minima.
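In code, the GD update is a loop over the full training set at every step, which is exactly why it is clumsy for M ≫ 1. A minimal sketch for a one-parameter model y = w·x with squared loss (toy data, not from the talk):

```python
def grad_L(w, data):
    # gradient of the loss L(w) = (1/M) * sum_k (w*x_k - y_k)^2
    M = len(data)
    return sum(2 * (w * x - y) * x for x, y in data) / M

def gradient_descent(data, w0=0.0, lr=0.1, steps=100):
    # plain GD: every update touches ALL M training samples
    w = w0
    for _ in range(steps):
        w = w - lr * grad_L(w, data)
    return w

data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # generated from y = 2x
print(round(gradient_descent(data), 4))  # converges to w = 2.0
```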
SGD-based Learning as a stochastic dynamical system
Minibatch and Stochastic Gradient Descent (SGD)
The corresponding Langevin equation for SGD
The “unusual” noise in the SGD Langevin equation: it is not thermal, so the stationary state is not simply a Boltzmann distribution (Chaudhari & Soatto, 2017).
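The stripped equation can be reconstructed in the generic form used in this literature (a sketch in my own notation, not necessarily the slide's): with learning rate α and minibatch size B,

```latex
\frac{dw}{dt} = -\nabla L(w) + \eta(t),
\qquad
\langle \eta(t)\,\eta^{\top}(t') \rangle = \frac{\alpha}{B}\, C(w)\,\delta(t-t') ,
```

where C(w) is the covariance of the per-sample gradients. For isotropic, constant-strength noise (C = 2T·I) the stationary state would be the Boltzmann distribution ∝ e^{−L/T}; the SGD noise is anisotropic and w-dependent, so in general it is not.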
[Schematic: trajectories in weight space, from the starting weights through a fast-learning phase into an exploration phase among the solutions (weights) for different minibatches]
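The “unusual” character of the noise is easy to see numerically: the SGD noise is the minibatch gradient minus the full-batch gradient, and its strength depends on where you are in weight space. A toy one-parameter sketch (illustrative data, not from the talk):

```python
import random

def per_sample_grad(w, x, y):
    # gradient of a single sample's squared loss (w*x - y)^2
    return 2 * (w * x - y) * x

def sgd_noise_var(w, data, batch_size, rng, trials=500):
    # SGD noise = minibatch gradient minus full-batch gradient;
    # estimate its variance at the point w
    M = len(data)
    full = sum(per_sample_grad(w, x, y) for x, y in data) / M
    noises = []
    for _ in range(trials):
        batch = rng.sample(data, batch_size)
        g = sum(per_sample_grad(w, x, y) for x, y in batch) / batch_size
        noises.append(g - full)
    mean = sum(noises) / trials
    return sum((n - mean) ** 2 for n in noises) / trials

data = [(x / 10.0, 2 * x / 10.0) for x in range(1, 101)]  # y = 2x exactly

# far from the minimum (w = 0) the per-sample gradients disagree, while at
# the minimum (w = 2) they all vanish: the noise is inhomogeneous in w
rng = random.Random(0)
print(sgd_noise_var(0.0, data, 10, rng) > sgd_noise_var(2.0, data, 10, rng))  # True
```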
An intuitive picture of the SGD-based learning dynamics
High-dimensional drift-diffusion with a funny non-homogeneous, anisotropic (“smart”) noise; the motion is effectively low-dimensional.
PCA analysis in the exploration phase (window = 10 epochs, 12,000 points)
SGD drives a drift-diffusion motion in low dimensions
Two hidden layers (H = 50) → 2500 parameters in each layer
The effective dimension (D99) is much smaller than the dimension of the weight space:
99% of the variance is in the first ~40 principal components out of N = 50×50 = 2500 degrees of freedom (40 ≪ 2500!)
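The D99 measure itself is simple: the number of leading PCA variances needed to capture 99% of the total. A sketch (the toy spectrum below is illustrative, chosen to mimic the ~40-of-2500 result):

```python
def d99(variances):
    # effective dimension: number of leading PCA variances needed
    # to capture 99% of the total variance
    vs = sorted(variances, reverse=True)
    total = sum(vs)
    cum = 0.0
    for k, v in enumerate(vs, start=1):
        cum += v
        if cum >= 0.99 * total:
            return k
    return len(vs)

# toy spectrum: 40 large variances and 2460 tiny ones,
# mimicking N = 2500 weights in one layer
spectrum = [1.0] * 40 + [1e-4] * 2460
print(d99(spectrum))  # -> 40
```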
The loss landscape in each PC direction
[Figure: loss profile along a PC direction; the flatness is the width over which the loss doubles (Δ ln L = ln 2)]
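One operational way to measure flatness along a direction, assuming the “loss doubles” (Δ ln L = ln 2) criterion suggested by the figure (my assumption; the paper's precise definition may differ):

```python
def flatness(loss_fn, theta_min, step=1e-3, factor=2.0, max_steps=100000):
    # width of the interval around the minimum within which the loss
    # stays below factor * L_min (the loss "doubles" at the edges)
    L_min = loss_fn(theta_min)
    left = right = theta_min
    while loss_fn(right + step) < factor * L_min and max_steps > 0:
        right += step
        max_steps -= 1
    while loss_fn(left - step) < factor * L_min and max_steps > 0:
        left -= step
        max_steps -= 1
    return right - left

# two 1-D toy loss profiles with the same minimum but different curvature
sharp = lambda t: 1.0 + (t / 0.1) ** 2   # flatness ~ 0.2
flat_ = lambda t: 1.0 + (t / 1.0) ** 2   # flatness ~ 2.0
print(flatness(sharp, 0.0) < flatness(flat_, 0.0))  # True: sharper => smaller F
```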
The inverse Variance – Flatness relation
In SGD, we find an inverse variance-flatness relation (“Inverse Einstein relation”):
Breakdown of fluctuation-dissipation theorem (relation)!
The inverse V-F relation in SGD is the opposite of
the fluctuation-response relation in equilibrium systems
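The relation can be written as a power law (a reconstruction of the stripped formula; σ_i² denotes the SGD weight variance and F_i the flatness along PC direction i):

```latex
\sigma_i^2 \propto F_i^{-\psi}, \qquad \psi > 0 ,
```

i.e., the weight fluctuations are larger along the sharper directions, the opposite of the equilibrium fluctuation-response relation.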
The Einstein relation for Brownian motion: D = μ k_B T
-- the diffusion constant D (fluctuation) is proportional to the mobility μ (response)
-- the thermal noise has a constant strength
Origin of the inverse variance-flatness relation: the flatness-dependent SGD noise
The mini-loss-function (MLF) ensemble
[Equations: the diffusion constant, the velocity correlation time, and the variance along each PC direction]
The fluctuating landscape
The spatial-temporal structure of the SGD noise
The search intensity in direction i can be described by an effective temperature:
higher along sharper directions, lower along flatter ones
Smart noise in SGD: a self-tuned annealing algorithm to find flat minima
Ning Yang, C. Tang, YT, Physical Review Letters, 130, 2023.
The anisotropic, landscape-dependent SGD noise introduces an effective loss that favors flat minima
Implicit regularizer
-- SGD drives a low-dimensional drift-diffusion motion in weight space
-- The inverse relation between weight fluctuation and landscape flatness.
“The inverse Einstein relation”
-- The landscape-dependent anisotropic SGD noise serves as a self-tuned
annealing algorithm for finding flat minima
-- The SGD noise introduces an effective flatness-dependent loss as an implicit regularization
Recap for topic I
Two co-acting geometric determinants for generalization
obtained by an exact activity-weight duality
Topic II. What makes a solution “better” (more generalizable)?
The generalization gap measures the generalization of a solution
[Figure: the training-loss and testing-loss landscapes over the (high-dimensional) weight space; a solution minimizes the training loss]
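As a definition, the gap is just the difference of the two loss averages; a minimal sketch with hypothetical toy data (not from the talk):

```python
def generalization_gap(model, train_set, test_set, loss):
    # gap = average test loss minus average training loss;
    # a small gap means the solution generalizes well
    avg = lambda data: sum(loss(model(x), y) for x, y in data) / len(data)
    return avg(test_set) - avg(train_set)

# toy linear model w*x with squared loss (hypothetical data)
model = lambda x: 2.0 * x
sq_loss = lambda yhat, y: (yhat - y) ** 2
train = [(1.0, 2.0), (2.0, 4.0)]   # fitted exactly: training loss = 0
test = [(3.0, 6.3), (4.0, 7.8)]    # noisy held-out samples
print(generalization_gap(model, train, test, sq_loss) > 0)  # True
```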
The Loss equivalence and data-parameter duality
Dual-weight:
The minimal A-W duality
The exact activity-weight dualities in a fully connected layer (FCL)
The duality condition
The activity-weight (A-W) duality for equal loss
feed forward
Computing the generalization gap with the A-W duality
-- the minimal dual weight for sample k
The effective gradient
[Figure: loss over the (high-dimensional) weight space]
The generalization gap for sample k
(the test and training data are drawn from the same distribution)
The two co-acting determinants for generalization
-- Data-dependent distance metric
-- Hessian eigenvalues
Strategy I: Increasing learning rate (or decreasing batch size) in SGD
improves generalization by finding flatter solutions
Strategy II: Weight decay improves generalization by finding smaller solutions
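Strategy II corresponds to adding a shrinkage term to the update; a minimal sketch (standard L2 weight decay, not code from the talk):

```python
def sgd_step(w, grad, lr=0.1, weight_decay=0.0):
    # gradient step with L2 weight decay: the decay term shrinks the
    # weights each step, biasing the search toward smaller solutions
    return w - lr * (grad + weight_decay * w)

# with zero gradient, decay alone contracts the weight toward 0
w = 1.0
for _ in range(10):
    w = sgd_step(w, grad=0.0, weight_decay=0.5)
print(round(w, 6))  # about 0.598737 (= 0.95**10)
```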
Recap for topic II
1) The two co-acting determinants for generalization: flatness of the loss landscape and size of the solution.
I. Flatness is a strong indicator for generalization (Jiang et al, ICLR 2022); useful for developing better algorithms for enhancing generalization.
II. But flatness is not the only generalization determinant (Dinh et al, ICML 2017) -- it has to work together with the size of the solution to determine generalization.
2) The activity-weight (A-W) duality: a direct connection between neuron activity (data, input) and synaptic weights (parameters) of neural network models.
Maybe helpful for understanding: 1) the relation between sloppy spectra in data and the flat directions in the loss landscape; 2) the effects of noise in data and of adversarial attacks.
Take-home messages
Some next steps in using the statistical physics approach for understanding ML/AI:
-- Is it caused by the structure/size of the data, the structure of the network, or both?
-- Does it have anything to do with scaling in critical phenomena?
-- Can we understand it with statistical physics techniques like Renormalization Group (RG) theory?
-- What characterizes the learning process/dynamics of attention?
-- How does in-context learning happen?
-- What are the more efficient learning algorithms for LLMs, and why? Adam versus SGD (with momentum)
-- What determines generalization in transformer-based models?
-- Is the notion of a loss landscape useful?
-- How does grokking happen?