Lecture 2: Gradient Estimators. CSC 2547 Spring 2018, David Duvenaud. Based mainly on slides by Will Grathwohl, Dami Choi, Yuhuai Wu and Geoff Roeder


Page 1:

Lecture 2: Gradient Estimators

CSC 2547 Spring 2018, David Duvenaud

Based mainly on slides by Will Grathwohl, Dami Choi, Yuhuai Wu and Geoff Roeder

Page 2:

Where do we see this guy?

• Just about everywhere!

• Variational Inference

• Reinforcement Learning

• Hard Attention

• And so many more!

\mathcal{L}(\theta) = \mathbb{E}_{p(b|\theta)}[f(b)]
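As a concrete reading of this objective, here is a minimal numpy sketch that Monte Carlo estimates L(theta) for a single Bernoulli b and the toy f(b) = (t - b)^2 used later in the talk; the sample count and seed are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def f(b, t=0.499):
    # toy black-box objective used later in the talk: f(b) = (t - b)^2
    return (t - b) ** 2

def objective(theta, n_samples=10_000):
    # Monte Carlo estimate of L(theta) = E_{p(b|theta)}[f(b)], b ~ Bernoulli(theta)
    b = (rng.random(n_samples) < theta).astype(float)
    return f(b).mean()

print(objective(0.3))  # close to 0.3 * (0.499 - 1)**2 + 0.7 * 0.499**2
```

Estimating the value of L(theta) is easy; the rest of the talk is about estimating its gradient with respect to theta.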

Page 3:

Gradient-based optimization

• Gradient-based optimization is the standard method used today to optimize expectations

• Necessary if models are neural-net based

• Very rarely can this gradient be computed analytically

Page 4:

Otherwise, we estimate…

• A number of approaches exist to estimate this gradient

• They make varying levels of assumptions about the distribution and function being optimized

• Most popular methods either make strong assumptions or suffer from high variance

Page 5:

REINFORCE (Williams, 1992)

• Unbiased

• Has few requirements

• Easy to compute

• Suffers from high variance

g_{\mathrm{REINFORCE}}[f] = f(b)\,\frac{\partial}{\partial\theta}\log p(b|\theta), \qquad b \sim p(b|\theta)
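For the same Bernoulli toy problem, a minimal numpy sketch of the REINFORCE estimator, using the analytic score function of a Bernoulli; f is treated as a black box whose values are all we use.

```python
import numpy as np

rng = np.random.default_rng(0)

def f(b, t=0.499):
    return (t - b) ** 2

def reinforce_grad(theta, n_samples=10_000):
    # g = f(b) * d/dtheta log p(b|theta),  b ~ Bernoulli(theta)
    # For a Bernoulli, d/dtheta log p(b|theta) = b/theta - (1 - b)/(1 - theta).
    b = (rng.random(n_samples) < theta).astype(float)
    dlogp = b / theta - (1.0 - b) / (1.0 - theta)
    samples = f(b) * dlogp
    return samples.mean(), samples.std()

print(reinforce_grad(0.3))  # unbiased, but per-sample noise dwarfs the tiny true gradient
```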

Page 6:

Reparameterization (Kingma & Welling, 2014)

• Lower variance empirically

• Unbiased

• Makes stronger assumptions

• Requires f(b) is known and differentiable

• Requires p(b|\theta) is reparameterizable

g_{\mathrm{reparam}}[f] = \frac{\partial f}{\partial b}\,\frac{\partial b}{\partial\theta}, \qquad b = T(\theta, \epsilon), \; \epsilon \sim p(\epsilon)
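A Bernoulli is not reparameterizable, so as an illustration of this estimator assume instead p(b|theta) = N(theta, 1) with b = T(theta, eps) = theta + eps; the quadratic f is again an assumption made for the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

t = 0.499

def df_db(b):
    # derivative of f(b) = (t - b)^2 with respect to b
    return -2.0 * (t - b)

def reparam_grad(theta, n_samples=10_000):
    # b = T(theta, eps) = theta + eps, eps ~ N(0, 1), so db/dtheta = 1
    eps = rng.standard_normal(n_samples)
    b = theta + eps
    samples = df_db(b) * 1.0   # df/db * db/dtheta
    return samples.mean(), samples.std()

print(reparam_grad(0.3))  # estimates d/dtheta E_{N(theta,1)}[f(b)] = -2 * (t - theta)
```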

Page 7:

Concrete (Maddison et al., 2016)

• Works well in practice

• Low variance from reparameterization

• Biased

• Adds temperature hyper-parameter

• Requires that f(b) is known and differentiable

• Requires that p(z|\theta) is reparameterizable

• Requires that f(b) behaves predictably outside of its domain

g_{\mathrm{concrete}}[f] = \frac{\partial f}{\partial \sigma(z/t)}\,\frac{\partial \sigma(z/t)}{\partial\theta}, \qquad z = T(\theta, \epsilon), \; \epsilon \sim p(\epsilon)
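A numpy sketch of the binary Concrete relaxation for the Bernoulli case, assuming f is differentiable and is evaluated on the relaxed sample sigma(z/t); the temperature value is an arbitrary choice for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

t_target = 0.499

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def df_db(b):
    return -2.0 * (t_target - b)   # derivative of the relaxed objective (t_target - b)^2

def concrete_grad(theta, temp=0.5, n_samples=10_000):
    # binary Concrete / Gumbel-Softmax relaxation of Bernoulli(theta):
    #   z = log(theta/(1-theta)) + log(u/(1-u)), u ~ Uniform(0,1);  b_relaxed = sigmoid(z / temp)
    u = rng.random(n_samples)
    z = np.log(theta / (1 - theta)) + np.log(u / (1 - u))
    b_relaxed = sigmoid(z / temp)
    db_dz = b_relaxed * (1 - b_relaxed) / temp        # d sigmoid(z/temp) / dz
    dz_dtheta = 1.0 / (theta * (1 - theta))           # d z / d theta
    samples = df_db(b_relaxed) * db_dz * dz_dtheta
    return samples.mean(), samples.std()

print(concrete_grad(0.3))  # low variance, but biased: it is the gradient of the relaxed objective
```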

Page 8:

Control Variates

• Allow us to reduce variance of a Monte Carlo estimator

• Variance is reduced if \mathrm{corr}(g, c) > 0

• Does not change bias

g_{\mathrm{new}}(b) = g(b) - c(b) + \mathbb{E}_{p(b)}[c(b)]
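A small numpy illustration of this identity (not from the slides): estimate E[g(b)] for b ~ N(0,1) with g(b) = exp(b), using c(b) = 1 + b as a control variate, whose expectation is known to be 1.

```python
import numpy as np

rng = np.random.default_rng(0)

# g(b) = exp(b), c(b) = 1 + b, E[c(b)] = 1 for b ~ N(0, 1)
b = rng.standard_normal(100_000)
g = np.exp(b)
g_new = g - (1.0 + b) + 1.0

print(g.mean(), g.var())          # plain Monte Carlo: mean near exp(0.5), larger variance
print(g_new.mean(), g_new.var())  # same mean, noticeably smaller variance
```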

Page 9:

Putting it all together

• We would like a general gradient estimator that is

• unbiased

• low variance

• usable when f(b) is unknown

• usable when p(b|\theta) is discrete

Page 10:

Backpropagation Through the Void

Page 11:

Backpropagation Through the Void

Page 12:

Backpropagation Through the Void

Will Grathwohl Dami Choi Yuhuai Wu

Geoff Roeder David Duvenaud

Page 13:

Our Approach

g_{\mathrm{LAX}} = g_{\mathrm{REINFORCE}}[f] - g_{\mathrm{REINFORCE}}[c_\phi] + g_{\mathrm{reparam}}[c_\phi]

Page 14:

Our Approach

• Start with the REINFORCE estimator for f(b)

• We introduce a new function c_\phi(b)

• We subtract the REINFORCE estimator of its gradient and add the reparameterization estimator

• Can be thought of as using the REINFORCE estimator of c_\phi(b) as a control variate

g_{\mathrm{LAX}} = g_{\mathrm{REINFORCE}}[f] - g_{\mathrm{REINFORCE}}[c_\phi] + g_{\mathrm{reparam}}[c_\phi]

= [f(b) - c_\phi(b)]\,\frac{\partial}{\partial\theta}\log p(b|\theta) + \frac{\partial}{\partial\theta} c_\phi(b), \qquad b = T(\theta, \epsilon), \; \epsilon \sim p(\epsilon)
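A minimal PyTorch sketch of this estimator for a continuous Gaussian p(b|theta), treating f as a black box; the quadratic f and the small network used for c_phi are illustrative assumptions, not something specified on the slide.

```python
import torch

torch.manual_seed(0)

def f(b):
    return (0.499 - b) ** 2            # only the values of f are used

theta = torch.tensor(0.3, requires_grad=True)
c_phi = torch.nn.Sequential(torch.nn.Linear(1, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))

def lax_estimate():
    eps = torch.randn(1, 1)
    b = theta + eps                    # b = T(theta, eps), reparameterized
    score = eps.squeeze()              # d/dtheta log N(b | theta, 1) = b - theta = eps
    c = c_phi(b).squeeze()
    (dc_dtheta,) = torch.autograd.grad(c, theta, create_graph=True)
    # g_LAX = [f(b) - c_phi(b)] * d/dtheta log p(b|theta) + d/dtheta c_phi(b)
    return (f(b.detach()).squeeze() - c) * score + dc_dtheta

print(lax_estimate())                  # one-sample, unbiased estimate of dL/dtheta
```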

Page 15:

Optimizing the Control Variate

• For any unbiased estimator g, we can get Monte Carlo estimates of the gradient of the variance of g (since g is unbiased, \mathbb{E}[g] does not depend on \phi, so only the \mathbb{E}[g^2] term contributes)

• Use these estimates to optimize c_\phi

\frac{\partial}{\partial\phi}\,\mathrm{Variance}(g) = \mathbb{E}\!\left[\frac{\partial}{\partial\phi}\, g^2\right]
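A sketch of this training signal under the same assumptions as the previous sketch (Gaussian p(b|theta), quadratic f, small network c_phi): build a single-sample estimate g with create_graph=True, then follow d(g^2)/dphi, a Monte Carlo estimate of d Variance(g)/dphi.

```python
import torch

torch.manual_seed(0)

def f(b):
    return (0.499 - b) ** 2

theta = torch.tensor(0.3, requires_grad=True)
c_phi = torch.nn.Sequential(torch.nn.Linear(1, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))
opt_phi = torch.optim.Adam(c_phi.parameters(), lr=1e-2)

def lax_estimate():
    eps = torch.randn(1, 1)
    b = theta + eps
    score = eps.squeeze()                      # d/dtheta log N(b | theta, 1)
    c = c_phi(b).squeeze()
    (dc_dtheta,) = torch.autograd.grad(c, theta, create_graph=True)
    return (f(b.detach()).squeeze() - c) * score + dc_dtheta

for step in range(200):
    g = lax_estimate()                         # single-sample gradient estimate (keeps the graph)
    # d(g^2)/dphi estimates d Variance(g)/dphi, since E[g] does not depend on phi
    grads = torch.autograd.grad(g ** 2, list(c_phi.parameters()))
    for p, gr in zip(c_phi.parameters(), grads):
        p.grad = gr
    opt_phi.step()
```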

Page 16:

What about discrete b?

Page 17:

Extension to discrete

• When b is discrete with b \sim p(b|\theta), we introduce a relaxed distribution p(z|\theta) and a function H(z) with H(z) = b

• We use the conditioning scheme introduced in REBAR (Tucker et al., 2017)

• Unbiased for all c_\phi

g_{\mathrm{RELAX}} = [f(b) - c_\phi(\tilde{z})]\,\frac{\partial}{\partial\theta}\log p(b|\theta) + \frac{\partial}{\partial\theta} c_\phi(z) - \frac{\partial}{\partial\theta} c_\phi(\tilde{z})

b = H(z), \qquad z \sim p(z|\theta), \qquad \tilde{z} \sim p(z|b, \theta)
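A minimal PyTorch sketch of RELAX for a single Bernoulli variable. The conditional sample z_tilde ~ p(z|b, theta) uses REBAR's Bernoulli conditioning formula; the choice of f, the small network c_phi, and parameterizing by a probability rather than a logit are assumptions made for the sketch.

```python
import torch

torch.manual_seed(0)

def f(b, t=0.499):
    return (t - b) ** 2                          # black-box objective, values only

theta = torch.tensor(0.3, requires_grad=True)
c_phi = torch.nn.Sequential(torch.nn.Linear(1, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))

def relax_estimate():
    u, v = torch.rand(()), torch.rand(())
    z = torch.log(theta / (1 - theta)) + torch.log(u / (1 - u))      # z ~ p(z|theta)
    b = (z > 0).float()                                              # b = H(z)
    # z_tilde ~ p(z|b, theta): REBAR's conditional reparameterization for a Bernoulli
    z_tilde = b * torch.log(v / ((1 - v) * (1 - theta)) + 1) \
        - (1 - b) * torch.log(v / ((1 - v) * theta) + 1)
    score = b / theta - (1 - b) / (1 - theta)                        # d/dtheta log p(b|theta)
    c_z = c_phi(z.reshape(1, 1)).squeeze()
    c_zt = c_phi(z_tilde.reshape(1, 1)).squeeze()
    (d_dtheta,) = torch.autograd.grad(c_z - c_zt, theta, create_graph=True)
    # g_RELAX = [f(b) - c_phi(z_tilde)] * score + d/dtheta [c_phi(z) - c_phi(z_tilde)]
    return (f(b) - c_zt) * score + d_dtheta

print(relax_estimate())
```

Averaging many such single-sample estimates, e.g. torch.stack([relax_estimate() for _ in range(2000)]).mean(), should land near the exact gradient of the toy objective, 1 - 2t.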

Page 18:

A Simple Example

• Used to validate REBAR (which used t = 0.45)

• We use t = 0.499

• REBAR and REINFORCE fail due to noise outweighing signal

• Can RELAX improve?

\mathbb{E}_{p(b|\theta)}[(t - b)^2]
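For reference (a quick closed-form check, not shown on the slide), the toy objective and its exact gradient are

\mathbb{E}_{p(b|\theta)}[(t - b)^2] = \theta (t - 1)^2 + (1 - \theta)\, t^2, \qquad \frac{\partial}{\partial\theta}\,\mathbb{E}_{p(b|\theta)}[(t - b)^2] = (t - 1)^2 - t^2 = 1 - 2t

so at t = 0.499 the true gradient is only 0.002 everywhere, which is why estimator noise easily outweighs the signal.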

Page 19:

• RELAX outperforms baselines

• Considerably reduced variance!

• RELAX learns a reasonable surrogate

Page 20:

Analyzing the Surrogate

• REBAR’s fixed surrogate cannot produce consistent and correct gradients

• RELAX learns to balance REINFORCE variance and reparameterization variance

Page 21:

A More Interesting Application

• Discrete VAE

• Latent state is 200 Bernoulli variables

• Discrete sampling makes reparameterization estimator unusable

\log p(x) \ge \mathcal{L}(\theta) = \mathbb{E}_{q(b|x)}[\log p(x|b) + \log p(b) - \log q(b|x)]

c_\phi(z) = f(\sigma_\lambda(z)) + r_\rho(z)
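A minimal PyTorch sketch (an assumption about the form above, not the authors' code) of such a surrogate: the ELBO evaluated on a temperature-lambda sigmoid relaxation of the 200 Bernoulli latents, plus a free-form correction network r_rho. Here `relaxed_elbo` is a hypothetical callable standing in for the model's ELBO on relaxed latents.

```python
import torch
import torch.nn as nn

class Surrogate(nn.Module):
    # c_phi(z) = f(sigma_lambda(z)) + r_rho(z)
    def __init__(self, relaxed_elbo, dim=200, hidden=200):
        super().__init__()
        self.relaxed_elbo = relaxed_elbo                # callable: relaxed latents -> per-example ELBO
        self.log_temp = nn.Parameter(torch.zeros(()))   # learned temperature lambda
        self.r_rho = nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(),
                                   nn.Linear(hidden, 1))

    def forward(self, z):
        relaxed = torch.sigmoid(z / self.log_temp.exp())        # sigma_lambda(z)
        return self.relaxed_elbo(relaxed) + self.r_rho(z).squeeze(-1)
```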

Page 22:

Results

Page 23:

Reinforcement Learning

• Policy gradient methods are very popular today (A2C, A3C, ACKTR)

• Seeks to find \operatorname{argmax}_\theta \, \mathbb{E}_{\tau \sim \pi(\tau|\theta)}[R(\tau)]

• Does this by estimating \frac{\partial}{\partial\theta}\, \mathbb{E}_{\tau \sim \pi(\tau|\theta)}[R(\tau)]

• R is not known, so many popular estimators cannot be used

Page 24:

Actor Critic

• c_\phi is an estimate of the value function

• This is exactly the REINFORCE estimator using an estimate of the value function as a control variate

• Why not use the action in the control variate?

• Dependence on the action would add bias

g_{\mathrm{AC}} = \sum_{t=1}^{T} \frac{\partial \log \pi(a_t|s_t, \theta)}{\partial\theta}\left[\sum_{t'=t}^{T} r_{t'} - c_\phi(s_t)\right]
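A compact PyTorch sketch of this estimator on a random toy trajectory (illustrative assumptions, not from the slides): the state-value baseline c_phi(s_t) is subtracted from the return-to-go and detached, so the policy update matches the formula above.

```python
import torch

torch.manual_seed(0)

n_states, n_actions, T = 4, 2, 5
policy = torch.nn.Linear(n_states, n_actions)            # policy parameters theta
value = torch.nn.Linear(n_states, 1)                     # c_phi(s_t), a state-value estimate

states = torch.nn.functional.one_hot(torch.randint(n_states, (T,)), n_states).float()
dist = torch.distributions.Categorical(logits=policy(states))
actions = dist.sample()
rewards = torch.randn(T)                                 # stand-in rewards r_t

returns = torch.flip(torch.cumsum(torch.flip(rewards, [0]), 0), [0])   # sum_{t' >= t} r_{t'}
advantage = returns - value(states).squeeze(-1).detach()               # return minus baseline c_phi(s_t)

policy_loss = -(dist.log_prob(actions) * advantage).sum()   # gradient of this is -g_AC
policy_loss.backward()                                       # gradients now sit in policy.weight.grad
```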

Page 25:

LAX for RL

• Allows for action dependence in control variate

• Remains unbiased

• Similar extension available for discrete action spaces

g_{\mathrm{LAX}} = \sum_{t=1}^{T} \frac{\partial \log \pi(a_t|s_t, \theta)}{\partial\theta}\left[\sum_{t'=t}^{T} r_{t'} - c_\phi(s_t, a_t)\right] + \frac{\partial}{\partial\theta} c_\phi(s_t, a_t)
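A sketch of the same idea with an action-dependent surrogate. The correction term d/dtheta c_phi(s_t, a_t) needs the action to be a differentiable function of theta, so a Gaussian policy with reparameterized actions is assumed here; the networks and toy data are placeholders. (The slide notes a similar extension exists for discrete action spaces.)

```python
import torch

torch.manual_seed(0)

state_dim, act_dim, T = 3, 1, 5
mu_net = torch.nn.Linear(state_dim, act_dim)                 # policy mean, parameters theta
log_std = torch.nn.Parameter(torch.zeros(act_dim))
c_phi = torch.nn.Sequential(torch.nn.Linear(state_dim + act_dim, 32),
                            torch.nn.Tanh(), torch.nn.Linear(32, 1))

states = torch.randn(T, state_dim)
dist = torch.distributions.Normal(mu_net(states), log_std.exp())
actions = dist.rsample()                                     # reparameterized actions a_t
rewards = torch.randn(T)                                     # stand-in rewards r_t

returns = torch.flip(torch.cumsum(torch.flip(rewards, [0]), 0), [0])
c = c_phi(torch.cat([states, actions], dim=-1)).squeeze(-1)  # action-dependent surrogate c_phi(s_t, a_t)

log_prob = dist.log_prob(actions.detach()).sum(-1)           # score term uses sampled action values
surrogate = (log_prob * (returns - c.detach())).sum() + c.sum()
# d surrogate / d theta = g_LAX: score term weighted by [return - c_phi] plus d/dtheta c_phi(s_t, a_t)
surrogate.backward()
# c_phi's own parameters are trained separately, by minimizing the estimator's variance as above.
```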

Page 26:

Results

• Improved performance

• Lower variance gradient estimates

Page 27:

Future Work

• What does the optimal surrogate look like?

• Many possible variations of LAX and RELAX

• Which provides the best tradeoff between variance, ease of implementation, scope of application, and performance?

• RL

• Incorporate other variance reduction techniques (GAE, reward bootstrapping, trust-region)

• Ways to train the surrogate off-policy

• Applications

• Inference of graph structure (coming soon)

• Inference of discrete neural network architecture components (coming soon)

Page 28:

Directions

• Surrogate can take any form

• It can rely on global information even if the forward pass only uses local information

• It can depend on ordering even if the forward pass is order-invariant

• Reparameterization can take many forms; ongoing work includes reparameterizing through rejection sampling and through distributions over permutations

Page 29:

Reparameterizing the Birkhoff Polytope for Variational Permutation Inference

Page 30:

Learning Latent Permutations with Gumbel-Sinkhorn Networks

Page 31:

Why are we optimizing policies anyway?

• Next week: Variational optimization