Two analyses of gradient-based optimization
for wide two-layer neural networks
Lénaïc Chizat*, joint work with Francis Bach+ and Edouard Oyallon§
Nov. 5th 2019 - Theory of Neural Networks Seminar - EPFL
*CNRS and Université Paris-Sud   +INRIA and ENS Paris   §Centrale Paris
Introduction
Setting
Supervised machine learning
• given input/output training data (x^(1), y^(1)), …, (x^(n), y^(n))
• build a function f such that f(x) ≈ y for unseen data (x, y)
Gradient-based learning paradigm
• choose a parametric class of functions f(w, ·) : x ↦ f(w, x)
• a convex loss ℓ to compare outputs: squared/logistic/hinge...
• starting from some w_0, update parameters using gradients
Example: Stochastic Gradient Descent with step-sizes (η^(k))_{k≥1}
w^(k) = w^(k−1) − η^(k) ∇_w ℓ(f(w^(k−1), x^(k)), y^(k))
[Refs]:
Robbins, Monro (1951). A Stochastic Approximation Method.
LeCun, Bottou, Bengio, Haffner (1998). Gradient-Based Learning Applied to Document Recognition.
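To make the update above concrete, here is a minimal numpy sketch of this SGD loop for a generic differentiable model; the helper names (`sgd`, `grad_w`) and the linear-model toy example are illustrative assumptions, not taken from the slides.

```python
import numpy as np

def sgd(w0, samples, grad_w, step_size):
    """Stochastic gradient descent: one update per sample (x^(k), y^(k)).

    grad_w(w, x, y) must return the gradient of l(f(w, x), y) with respect to w.
    """
    w = w0.copy()
    for k, (x, y) in enumerate(samples, start=1):
        w = w - step_size(k) * grad_w(w, x, y)
    return w

# Toy usage: linear model f(w, x) = w . x with squared loss l(z, y) = (z - y)^2.
rng = np.random.default_rng(0)
w_star = rng.normal(size=5)
samples = [(x, w_star @ x) for x in rng.normal(size=(500, 5))]
grad = lambda w, x, y: 2.0 * (w @ x - y) * x          # gradient of (w.x - y)^2
w_hat = sgd(np.zeros(5), samples, grad, step_size=lambda k: 0.05 / np.sqrt(k))
print(np.linalg.norm(w_hat - w_star))                 # much smaller than ||w_star||
```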
Models
Linear in the parameters
Linear regression, prior/random features, kernel methods:
f(w, x) = w · φ(x)
• convex optimization
Neural networks
Vanilla NN with activation σ & parameters (W_1, b_1), …, (W_L, b_L):
f(w, x) = W_L^T σ(W_{L−1}^T σ(… σ(W_1^T x + b_1) …) + b_{L−1}) + b_L
• interacting & interchangeable units/filters
• compositional
Wide two-layer neural networks
Two-layer neural networks
[Network diagram: input layer (x[1], x[2]) → hidden layer → output layer (y)]
• With activation σ, define φ(w_i, x) = c_i σ(a_i · x + b_i) and
f_m(w, x) = (1/m) ∑_{i=1}^m φ(w_i, x)   with w_i = (a_i, b_i, c_i) ∈ R^p
• Estimate the parameters w = (w_1, …, w_m) by solving
min_w F_m(w) := R(f_m(w, ·)) + G_m(w)
where R(f_m(w, ·)) is the empirical/population risk and G_m(w) the regularization
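As a minimal sketch, the finite-width predictor above can be written in a few lines of numpy with a ReLU activation; the Gaussian initialization and array shapes are illustrative assumptions.

```python
import numpy as np

def two_layer(a, b, c, X):
    """f_m(w, x) = (1/m) * sum_i c_i * relu(a_i . x + b_i), evaluated on the rows of X.

    a: (m, d) input weights, b: (m,) biases, c: (m,) output weights, X: (n, d).
    """
    hidden = np.maximum(X @ a.T + b, 0.0)    # relu(a_i . x + b_i), shape (n, m)
    return hidden @ c / len(c)               # average over the m units

m, d = 100, 2
rng = np.random.default_rng(0)
a, b, c = rng.normal(size=(m, d)), rng.normal(size=m), rng.normal(size=m)
X = rng.normal(size=(10, d))
print(two_layer(a, b, c, X).shape)           # (10,)
```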
Infinitely wide two-layer networks
• Parameterize the predictor with a probability measure µ ∈ P(R^p)
f(µ, x) = ∫_{R^p} φ(w, x) dµ(w)
• Estimate the measure µ by solving
min_µ F(µ) = R(f(µ, ·)) + G(µ)
where R(f(µ, ·)) is the empirical/population risk and G(µ) the regularization
• lifted version of “convex” neural networks
[Refs]:
Bengio et al. (2006). Convex neural networks.
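One way to see the lift: sampling the units i.i.d. from µ and averaging recovers the finite-width network, and by the law of large numbers f_m(w, x) → f(µ, x) as m → ∞. A small numpy sketch, where the Gaussian choice of µ and the ReLU feature are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.array([0.3, -1.2])

def phi(w, x):
    """phi(w, x) = c * relu(a . x + b) with w = (a, b, c) stacked as one vector."""
    a, b, c = w[:2], w[2], w[3]
    return c * max(a @ x + b, 0.0)

def f_m(ws, x):
    """Finite-width predictor: average of phi over the m sampled units."""
    return np.mean([phi(w, x) for w in ws])

# With w_i ~ mu = N(0, I_4), f_m(w, x) concentrates around
# f(mu, x) = E_{w ~ mu}[phi(w, x)] as m grows.
for m in (10, 1_000, 100_000):
    ws = rng.normal(size=(m, 4))
    print(m, f_m(ws, x))
```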
Adaptivity of neural networks
Goal: Estimate a 1-Lipschitz function y : R^d → R given n i.i.d.
samples from ρ ∈ P(R^d). Error bound on ∫ (f(x) − y(x))^2 dρ(x)?
• Ω(n^{−1/d}) (curse of dimensionality)
Same question, if moreover y(x) = g(Ax) for some A ∈ R^{s×d}?
• O(n^{−1/d}) for kernel methods (some lower bounds too)
• O(d^{1/2} n^{−1/(s+3)}) for 2-layer ReLU networks with weight decay
obtained with G(µ) = ∫ V dµ where V(w) = ‖a‖^2 + |b|^2 + |c|^2
no a priori bound on the number m of units required
Connecting theory and practice:
Is it related to the predictor learnt by gradient descent?
[Refs]:
Barron (1993). Approximation and estimation bounds for artificial neural networks.
Bach (2014). Breaking the curse of dimensionality with convex neural networks.
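For the particle parameterization µ = (1/m) ∑_i δ_{w_i}, the weight-decay regularizer above reduces to an average of V over the units. A short sketch; the regularization strength `lam` and the function name are illustrative assumptions:

```python
import numpy as np

def weight_decay(a, b, c, lam=1.0):
    """Particle version of G(mu) = integral of V dmu with V(w) = ||a||^2 + |b|^2 + |c|^2:
    G_m(w) = (lam/m) * sum_i (||a_i||^2 + b_i^2 + c_i^2).
    a: (m, d), b: (m,), c: (m,); lam is an illustrative regularization strength."""
    return lam / len(c) * (np.sum(a ** 2) + np.sum(b ** 2) + np.sum(c ** 2))

# The full objective is then F_m(w) = R(f_m(w, .)) + weight_decay(a, b, c).
```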
Mean-field dynamic and global convergence
Continuous time dynamics
Gradient flow
Initialize w(0) = (w_1(0), …, w_m(0)).
Small step-size limit of (stochastic) gradient descent:
w(t + η) = w(t) − η ∇F_m(w(t))   ⟹ (as η → 0)   d/dt w(t) = −m ∇F_m(w(t))
Measure representation
Corresponding dynamics in the space of probabilities P(R^p):
µ_{t,m} = (1/m) ∑_{i=1}^m δ_{w_i(t)}
Technical note: in what follows, P_2(R^p) is the Wasserstein space
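A minimal sketch of the discretized particle gradient flow for the two-layer ReLU network with squared loss on a fixed dataset, i.e. explicit Euler steps on d/dt w = −m ∇F_m(w); the toy data, initialization, and step size are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, m = 200, 2, 50
X = rng.normal(size=(n, d))
y = np.maximum(X[:, 0], 0.0) - np.maximum(X[:, 1], 0.0)     # toy target
a, b, c = rng.normal(size=(m, d)), rng.normal(size=m), rng.normal(size=m)
eta, steps = 0.1, 2000

for _ in range(steps):
    pre = X @ a.T + b                        # a_i . x_j + b_i, shape (n, m)
    act = np.maximum(pre, 0.0)
    resid = act @ c / m - y                  # f_m(w, x_j) - y_j, shape (n,)
    # gradients of F_m(w) = (1/n) sum_j resid_j^2
    grad_c = 2.0 / (n * m) * act.T @ resid
    mask = (pre > 0).astype(float) * c       # c_i * relu'(a_i . x_j + b_i)
    grad_b = 2.0 / (n * m) * mask.T @ resid
    grad_a = 2.0 / (n * m) * (mask * resid[:, None]).T @ X
    # Euler step on dw/dt = -m * grad F_m  (the mean-field time scaling)
    a -= eta * m * grad_a
    b -= eta * m * grad_b
    c -= eta * m * grad_c

print(np.mean((np.maximum(X @ a.T + b, 0.0) @ c / m - y) ** 2))  # training risk
```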
Many-particle / mean-field limit
Theorem
Assume that w_1(0), w_2(0), … are such that µ_{0,m} → µ_0 in P_2(R^p)
and some regularity. Then µ_{t,m} → µ_t in P_2(R^p), uniformly on
[0, T], where µ_t is the unique Wasserstein gradient flow of F
starting from µ_0.
Wasserstein gradient flows are characterized by
∂_t µ_t = −div(−∇F′_{µ_t} µ_t)
where F′_µ ∈ C^1(R^p) is the Fréchet derivative of F at µ.
[Refs]:
Nitanda, Suzuki (2017). Stochastic particle gradient descent for infinite ensembles.
Mei, Montanari, Nguyen (2018). A Mean Field View of the Landscape of Two-Layers Neural Networks.
Rotskoff, Vanden-Eijnden (2018). Parameters as Interacting Particles [...].
Sirignano, Spiliopoulos (2018). Mean Field Analysis of Neural Networks.
Chizat, Bach (2018). On the Global Convergence of Gradient Descent for Over-parameterized Models [...]
Global convergence (Chizat & Bach 2018)
Theorem (2-homogeneous case)
Assume that φ is positively 2-homogeneous and some regularity. If
the support of µ_0 covers all directions (e.g. Gaussian) and if
µ_t → µ_∞ in P_2(R^p), then µ_∞ is a global minimizer of F.
Non-convex landscape: initialization matters
Corollary
Under the same assumptions, if at initialization µ_{0,m} → µ_0, then
lim_{t→∞} lim_{m→∞} F(µ_{m,t}) = lim_{m→∞} lim_{t→∞} F(µ_{m,t}) = inf F.
Generalization properties, if F is ...
• the regularized empirical risk: statistical adaptivity!
• the population risk: need convergence speed (?)
• the unregularized empirical risk: need implicit bias (?)
[Refs]:
Chizat, Bach (2018). On the Global Convergence of Gradient Descent for Over-parameterized Models [...].
Numerical Illustrations
ReLU, d = 2, optimal predictor has 5 neurons (population risk)
[Figure: particle gradient flow in parameter space vs. optimal positions and limit measure, shown for 5, 10, and 100 neurons]
Performance
Population risk at convergence vs m
ReLU, d = 100, optimal predictor has 20 neurons
[Figure, log-log: population risk at convergence vs. m; legend: particle gradient flow, convex minimization, below optim. error, m_0]
Convex optimization on measures
Sparse deconvolution on T^2 with Dirichlet kernel
(white) sources, (red) particles.
Computational guarantees for the regularized case, but m
exponential in the dimension d
[Refs]:
Chizat (2019). Sparse Optimization on Measures with Over-parameterized Gradient Descent.
Lazy Training
Neural Tangent Kernel (Jacot et al. 2018)
Infinite width limit of standard neural networks
For infinitely wide fully connected neural networks of any depth
with “standard” initialization and no regularization: the gradient
flow implicitly performs kernel ridge(less) regression with the
neural tangent kernel
K(x, x′) = lim_{m→∞} ⟨∇_w f_m(w_0, x), ∇_w f_m(w_0, x′)⟩.
Reconciling the two views:
f_m(w, x) = (1/√m) ∑_{i=1}^m φ(w_i, x)   vs.   f_m(w, x) = (1/m) ∑_{i=1}^m φ(w_i, x)
This behavior is not intrinsically due to over-parameterization
but to an exploding scale
[Refs]:
Jacot, Gabriel, Hongler (2018). Neural Tangent Kernel: Convergence and Generalization in Neural Networks.
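A sketch of the empirical (finite-m) tangent kernel for the two-layer ReLU network in the 1/√m parameterization, computed from the analytic gradient at initialization; as m grows the inner product stabilizes, consistent with the limit above. The initialization and sizes are illustrative assumptions.

```python
import numpy as np

def ntk(a, b, c, x, xp):
    """Empirical tangent kernel <grad_w f_m(w_0, x), grad_w f_m(w_0, xp)> for
    f_m(w, x) = (1/sqrt(m)) * sum_i c_i * relu(a_i . x + b_i)."""
    m = len(c)

    def grads(x):
        pre = a @ x + b                              # (m,)
        act = np.maximum(pre, 0.0)
        ind = (pre > 0).astype(float)
        g_c = act / np.sqrt(m)                       # d f / d c_i
        g_b = c * ind / np.sqrt(m)                   # d f / d b_i
        g_a = (c * ind)[:, None] * x / np.sqrt(m)    # d f / d a_i, shape (m, d)
        return g_a, g_b, g_c

    ga, gb, gc = grads(x)
    ga2, gb2, gc2 = grads(xp)
    return np.sum(ga * ga2) + gb @ gb2 + gc @ gc2

rng = np.random.default_rng(0)
d = 3
x, xp = rng.normal(size=d), rng.normal(size=d)
for m in (100, 10_000):
    a, b, c = rng.normal(size=(m, d)), rng.normal(size=m), rng.normal(size=m)
    print(m, ntk(a, b, c, x, xp))                    # stabilizes as m grows
```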
Linearized model and scale
• let h(w) = f(w, ·) be a differentiable model
• let h̄(w) = h(w_0) + Dh_{w_0}(w − w_0) be its linearization at w_0
Compare 2 training trajectories starting from w_0, with scale α > 0:
• w_α(t): gradient flow of F_α(w) = R(α h(w))/α^2
• w̄_α(t): gradient flow of F̄_α(w) = R(α h̄(w))/α^2
If h(w_0) ≈ 0 and α is large, then w_α(t) ≈ w̄_α(t)
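A small numerical sketch of this comparison on a toy 2-parameter model: discretized gradient flows on R(α h(w))/α² for h and for its linearization h̄, tracking sup_t |α h(w_α(t)) − α h̄(w̄_α(t))|, which shrinks as α grows. The model h(w) = w_1 w_2, the target, and the step size are illustrative assumptions.

```python
import numpy as np

# Toy 2-homogeneous model h(w) = w1 * w2 with w0 = (1, 0), so h(w0) = 0; its
# linearization at w0 is hbar(w) = w2.  Loss R(z) = 0.5 * (z - 1)^2.

def gap(alpha, steps=500, lr=0.1):
    """sup_t |alpha*h(w_alpha(t)) - alpha*hbar(wbar_alpha(t))| along the discretized flows."""
    w, wbar, worst = np.array([1.0, 0.0]), np.array([1.0, 0.0]), 0.0
    for _ in range(steps):
        h, grad_h = w[0] * w[1], np.array([w[1], w[0]])
        hbar, grad_hbar = wbar[1], np.array([0.0, 1.0])
        worst = max(worst, abs(alpha * h - alpha * hbar))
        # gradient step on F_alpha(w) = R(alpha * model(w)) / alpha^2
        w = w - lr * (alpha * h - 1.0) / alpha * grad_h
        wbar = wbar - lr * (alpha * hbar - 1.0) / alpha * grad_hbar
    return worst

for alpha in (1.0, 10.0, 100.0):
    print(alpha, gap(alpha))     # the gap shrinks as alpha grows (lazy regime)
```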
Lazy training theorems
Theorem (Non-asymptotic)
If h(w_0) = 0 and R is possibly non-convex, then for any T > 0, it holds
lim_{α→∞} sup_{t∈[0,T]} ‖α h(w_α(t)) − α h̄(w̄_α(t))‖ = 0
Theorem (Strongly convex)
If h(w_0) = 0 and R is strongly convex, it holds
lim_{α→∞} sup_{t≥0} ‖α h(w_α(t)) − α h̄(w̄_α(t))‖ = 0
• instance of implicit bias: lazy because parameters hardly move
• may replace the model by its linearization
[Refs]:
Chizat, Oyallon, Bach (2018). On Lazy Training in Differentiable Programming.
When does lazy training occur (without α)?
Relative scale criterion
For R(y) = (1/2)‖y − y⋆‖^2, the relative error at normalized time t is
err ≲ t^2 κ_h(w_0)   where   κ_h(w_0) := (‖h(w_0) − y⋆‖ / ‖∇h(w_0)‖) · (‖∇^2 h(w_0)‖ / ‖∇h(w_0)‖)
Examples (h(w) = f(w, ·)):
• Homogeneous models with f(w_0, ·) = 0.
If for λ > 0, f(λw, x) = λ^L f(w, x), then κ_f(w_0) ≍ 1/‖w_0‖^L
• Wide two-layer NNs with i.i.d. weights, E Φ(w_i, ·) = 0.
If f(w, x) = α(m) ∑_{i=1}^m Φ(w_i, x), then κ_f(w_0) ≍ (m α(m))^{−1}
Numerical Illustrations
Training paths (ReLU, d = 2, optimal predictor has m = 3 neurons)
[Figure: training paths; legend: circle of radius 1, gradient flow (+), gradient flow (−); panels: (a) Lazy, (b) Not lazy]
Performance
Teacher-student setting, generalization in 100-d vs. initialization scale τ:
[Figure: test loss vs. τ; legend: end of training, best throughout training]
Perf. of ConvNets (VGG11) for image classification (CIFAR10):
[Figure 3: VGG-11 on CIFAR10 — train accuracy, test accuracy, and stability of activations (%) vs. scale of the model]
• similar gaps observed for widened ConvNets & ResNets
• CKN (Convolutional Kernel Networks) and tailored NTK (Neural Tangent Kernel) perform well on this (not so hard) task
Conclusion
• Gradient descent on infinitely wide 2-layer networks converges
to global minimizers
• Generalization behavior depends on initialization, loss,
stopping time, signal scale, regularization...
• For the regularized empirical risk, it breaks the statistical
curse of dimensionality
• But not (yet) the computational curse of dimensionality
[Refs]:
- Chizat, Bach (2018). On the Global Convergence of Gradient Descent for Over-parameterized Models using Optimal Transport.
- Chizat, Oyallon, Bach (2019). On Lazy Training in Differentiable Programming.
- Chizat (2019). Sparse Optimization on Measures with Over-parameterized Gradient Descent.