Mohammad Reza Keshtkaran, Chethan Pandarinath
TRANSCRIPT
[Figure: true performance (R2) versus sample validation loss.]
[Figure: sample validation schematic (neurons × time × trials) showing held-out samples and held-out inferred rates.]
[Figure: decoding R2 for X-velocity and Y-velocity versus number of training trials, with panels for 5% and 16% of trials used, comparing Smoothing, Fixed HPs, and HP-optimized. Companion panels show % R2 improvement versus number of training trials.]
Sequential autoencoder (SAE) for modeling neural dynamics
Pathological overfitting in SAEs
Denoising autoencoders do not address overfitting in SAEs
Two solutions to address pathological overfitting in SAEs
1- Coordinated dropout (CD)
2- Sample validation
HP optimization trains accurate models with 5-10x less data
Conclusions
References
Experimental setup - decoding arm velocity
Dataset 1: Curved reaching task
Two tasks: 1- Curved reaching, 2- Random target
HP optimization performed using population based training (PBT) [3]
Dataset 2: Random target task
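The PBT procedure named above can be illustrated with a toy sketch. This is not the poster's training setup: the quadratic objective, the population size, the truncation fraction, and the 0.8/1.2 perturbation factors are all assumptions chosen for illustration. It only shows the exploit-and-explore loop in which poorly performing workers copy a better worker's weights and hyperparameters and then perturb the hyperparameters.

```python
import numpy as np

def loss(theta):
    """Toy objective (assumption): minimized at theta = 3."""
    return (theta - 3.0) ** 2

def toy_pbt(pop_size=8, generations=10, steps=5, seed=0):
    rng = np.random.default_rng(seed)
    theta = np.zeros(pop_size)                      # one "weight" per worker
    lr = 10.0 ** rng.uniform(-4, -1, size=pop_size)  # one hyperparameter per worker
    for _ in range(generations):
        # Train: each worker takes a few gradient steps with its own learning rate
        for _ in range(steps):
            theta -= lr * 2.0 * (theta - 3.0)
        # Rank workers, best first
        order = np.argsort(loss(theta))
        n_cut = max(1, pop_size // 4)
        top, bottom = order[:n_cut], order[-n_cut:]
        for b in bottom:
            src = rng.choice(top)
            theta[b] = theta[src]                   # exploit: copy weights
            lr[b] = lr[src] * rng.choice([0.8, 1.2])  # explore: perturb the hyperparameter
    return theta, lr

theta, lr = toy_pbt()
best_loss = loss(theta).min()
```

Ranking workers requires a validation metric that tracks true model quality, which is why a reliable metric is a prerequisite for this kind of search.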
[Figure: coordinated dropout schematic. At training step n, each training example is split into two sets: train on these samples, to reconstruct these samples. Labels: sample dropped; gradient blocked.]
[Figure: SAE schematic. Input spikes (neurons × time × trials, Neuron 1 to Neuron n). Labels: µ, σ, initial state g0, W_fac, f_t, W_rate, r_t.]
[Figure: experimental pipeline. Record neural activity (observed spikes), estimate spike rates with the SAE, decode arm velocity with a linear decoder. True trajectory versus estimated trajectory.]
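The linear decoding step in the experimental setup can be sketched as an ordinary least-squares fit from firing rates to two-dimensional velocity. The data below are synthetic and every name (`fit_linear_decoder`, `decode`, `r2_per_dim`) is hypothetical; the poster does not specify the decoder's details beyond it being linear, so the bias term and the train/test split are assumptions.

```python
import numpy as np

def add_bias(rates):
    """Append a column of ones so the decoder can learn an offset."""
    return np.hstack([rates, np.ones((rates.shape[0], 1))])

def fit_linear_decoder(rates, velocity):
    """Least-squares weights mapping rates (T x neurons) to velocity (T x 2)."""
    weights, *_ = np.linalg.lstsq(add_bias(rates), velocity, rcond=None)
    return weights

def decode(rates, weights):
    return add_bias(rates) @ weights

def r2_per_dim(truth, estimate):
    """Coefficient of determination, computed separately for X and Y velocity."""
    ss_res = ((truth - estimate) ** 2).sum(axis=0)
    ss_tot = ((truth - truth.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot

# Synthetic stand-in data (assumption): 30 neurons, 500 time bins
rng = np.random.default_rng(0)
rates = rng.uniform(0.0, 5.0, size=(500, 30))
true_weights = rng.standard_normal((31, 2))
velocity = add_bias(rates) @ true_weights + 0.1 * rng.standard_normal((500, 2))

# Fit on the first 400 bins, evaluate on the remaining 100
weights = fit_linear_decoder(rates[:400], velocity[:400])
r2 = r2_per_dim(velocity[400:], decode(rates[400:], weights))
```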
Enabling hyperparameter optimization in sequential autoencoders for spiking neural data
Mohammad Reza Keshtkaran, Chethan Pandarinath
Georgia Tech and Emory University
Email: [email protected]
Simulation setup [2]
[Figure: synthetic data generation and estimation of firing rates. A nonlinear dynamical system produces ground truth firing rates; Poisson sampling (noisy) produces synthetic spikes; the SAE produces inferred firing rates, evaluated against ground truth using R2. Panel titles: Standard training; Trained with coordinated dropout.]
Sample validation: evaluate the network by how well it can predict the rates for the samples it has never seen during training or evaluation.
[Figure: sample validation schematic (neurons × time × trials). The SAE output is compared with the held-out samples. Labels: µ, σ, initial state g0, W_fac, f_t, W_rate, r_t.]
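A minimal sketch of the sample validation idea (scoring the network only on samples it never received) is given here. The constant-rate `model` is a hypothetical stand-in for a trained SAE, and zeroing the held-out samples at the input, the 20% held-out fraction, and the toy data shape are assumptions for illustration.

```python
import numpy as np

def poisson_nll(spikes, rates, mask):
    """Mean Poisson negative log-likelihood (up to a constant) over masked samples."""
    nll = rates - spikes * np.log(rates)
    return nll[mask].mean()

rng = np.random.default_rng(0)
spikes = rng.poisson(2.0, size=(4, 10, 50))   # trials x neurons x time (toy data)

# Held-out samples are chosen once and kept fixed for the whole run
heldout = rng.random(spikes.shape) < 0.2
model_input = np.where(heldout, 0, spikes)    # held-out samples never reach the model

def model(x):
    """Hypothetical stand-in for the SAE: predicts a constant rate."""
    return np.full(x.shape, 2.0)

rates = model(model_input)
train_loss = poisson_nll(spikes, rates, ~heldout)       # used for weight updates
sample_val_loss = poisson_nll(spikes, rates, heldout)   # used to compare models
```

Because the held-out samples are absent from the input, a network cannot lower `sample_val_loss` by passing its input straight through to its output.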
Simulation example, linear over-complete autoencoder:
low-d data: $Q \in \mathbb{R}^{m \times n}$
high-d observation: $X \in \mathbb{R}^{o \times n}$
latent variable: $Z \in \mathbb{R}^{h \times n}$, with $m < o \le h$
encoder: $Z = W_1^T X$
decoder: $X' = W_2 Z$
auto-encoder (reconstruction): $X' = W_2 W_1^T X$
CD is equivalent to preventing the diagonal weights from being trained.
[Figure: two panels showing the weight matrix W.]
$Y = W_p^T Q$
$X = Y + N$ where $N \sim \mathcal{N}(0, \sigma)$
True loss: $\lVert Y - X' \rVert^2$
Train loss: $\lVert X - X' \rVert^2$
[Figure: two panels of loss (MSE) versus epoch (0 to 200), each showing train loss, true loss, and lower bound.]
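The linear simulation can be reproduced in spirit with a short NumPy sketch. Several choices here are assumptions rather than the poster's settings: the sizes m, o, n, the noise level, plain gradient descent, and collapsing the product of decoder and encoder into a single square matrix W (in the over-complete case h ≥ o, that product can represent any o × o matrix). Coordinated dropout is emulated through the stated equivalence, by never training the diagonal of W.

```python
import numpy as np

rng = np.random.default_rng(0)
m, o, n = 3, 20, 200        # latent dim, observed dim, samples (illustrative)
sigma = 1.0                 # noise standard deviation (assumed)

Q = rng.standard_normal((m, n))               # low-d data
Wp = rng.standard_normal((m, o))              # projection to the observed space
Y = Wp.T @ Q                                  # noiseless high-d signal
X = Y + sigma * rng.standard_normal((o, n))   # noisy observation

def train(X, Y, freeze_diagonal, epochs=2000):
    """Gradient descent on ||X - W X||^2; optionally keep diag(W) at zero."""
    o, n = X.shape
    cov = X @ X.T / n
    lr = 0.5 / np.linalg.eigvalsh(cov).max()  # step size chosen for stable descent
    W = np.zeros((o, o))
    for _ in range(epochs):
        grad = 2.0 * (W @ X - X) @ X.T / n
        if freeze_diagonal:
            np.fill_diagonal(grad, 0.0)       # diagonal weights receive no update
        W -= lr * grad
    X_rec = W @ X
    train_loss = np.mean((X - X_rec) ** 2)    # what the optimizer sees
    true_loss = np.mean((Y - X_rec) ** 2)     # distance to the noiseless signal
    return train_loss, true_loss, W

train_std, true_std, W_std = train(X, Y, freeze_diagonal=False)
train_cd, true_cd, W_cd = train(X, Y, freeze_diagonal=True)
```

With unconstrained training, W can approach the identity, which drives the train loss toward zero while reproducing the noise. With the diagonal frozen, each channel has to be predicted from the other channels, so only structure shared across channels can be used.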
[Figure: LFADS architecture. Labels: Encoders, Controller, Decoder, Observed Spikes, Initial State, Inferred Inputs, Latent Factors, Inferred Rates, Reconstruction loss (Poisson).]
Main hyperparameters:
- KL penalties
- L2 penalties
- Dropout
- Learning rate
Maximize Evidence Lower Bound ELBO (L):
Reconstruction log-likelihood:
KL penalty:
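For orientation, a generic form of such an objective in standard variational-autoencoder notation is sketched here. The symbols are assumptions chosen for illustration and may not match the poster's own notation: $x_{1:T}$ denotes the observed spikes, $r_t$ the inferred rates, $z$ the latent variables (such as the initial state and inferred inputs), $q$ the approximate posterior, and $p$ the prior.

```latex
\begin{align}
\mathcal{L} &= \mathcal{L}^{x} - \mathcal{L}^{KL} \\
\mathcal{L}^{x} &= \mathbb{E}_{q(z \mid x_{1:T})}\left[ \sum_{t=1}^{T} \log \mathrm{Poisson}\left(x_t \mid r_t\right) \right] \\
\mathcal{L}^{KL} &= D_{KL}\left( q(z \mid x_{1:T}) \,\|\, p(z) \right)
\end{align}
```

The KL penalties listed among the main hyperparameters would scale the $\mathcal{L}^{KL}$ term in a formulation of this kind.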
Acknowledgements: This work was supported by NSF NCS 1835364 and the DARPA Intelligent Neural Interfaces program.
[1] Pandarinath, Chethan, Daniel J. O’Shea, Jasmine Collins, et al. “Inferring Single-Trial Neural Population Dynamics Using Sequential Auto-Encoders.” Nature Methods, 2018.
[2] Sussillo, David, Rafal Jozefowicz, L. F. Abbott, and Chethan Pandarinath. “LFADS - Latent Factor Analysis via Dynamical Systems.” arXiv preprint arXiv:1608.06315, 2016.
[3] Jaderberg, Max, Valentin Dalibard, Simon Osindero, et al. “Population Based Training of Neural Networks.” arXiv preprint arXiv:1711.09846, 2017.
• SAEs are prone to a particular type of overfitting that cannot be detected using standard validation loss.
• The lack of a reliable validation metric prevented HP optimization in SAEs.
• We developed two solutions, “Coordinated dropout” and “Sample validation”, to address pathological overfitting in SAEs and enable HP optimization.
• HP optimization led to accurate models while using 5-10-fold less training data.
CD forces the network to model only the shared structure underlying the observations. CD first passes a subset of samples to the input (by applying dropout). Then, to update the network weights, CD uses only the gradients calculated for reconstruction of the complementary subset of samples.
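The two steps of coordinated dropout described above can be sketched with a pair of complementary masks. This is an illustrative NumPy sketch, not the authors' implementation: the function names are hypothetical, the constant-rate output stands in for a network, and rescaling the kept inputs by 1/keep_prob follows standard inverted dropout and is an assumption here.

```python
import numpy as np

def coordinated_dropout(spikes, keep_prob, rng):
    """Return the dropped-out input and the mask of samples that carry gradient."""
    keep = rng.random(spikes.shape) < keep_prob
    model_input = spikes * keep / keep_prob   # samples shown to the network
    grad_mask = ~keep                         # complementary samples used in the loss
    return model_input, grad_mask

def masked_poisson_nll(spikes, rates, grad_mask):
    """Poisson NLL (up to a constant) averaged over gradient-carrying samples only."""
    nll = rates - spikes * np.log(rates)
    return (nll * grad_mask).sum() / max(grad_mask.sum(), 1)

rng = np.random.default_rng(0)
spikes = rng.poisson(1.5, size=(10, 50)).astype(float)   # neurons x time (toy data)
model_input, grad_mask = coordinated_dropout(spikes, keep_prob=0.7, rng=rng)

rates = np.full(spikes.shape, 1.5)   # hypothetical stand-in for the network output
loss = masked_poisson_nll(spikes, rates, grad_mask)
```

In an automatic-differentiation framework, multiplying the per-sample loss by `grad_mask` means the samples that were visible at the input contribute no gradient, so reproducing the input directly earns the network nothing.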
Used synthetic data with ground truth to demonstrate pathological overfitting.
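A synthetic dataset with ground truth can be generated along these lines: simulate a nonlinear dynamical system, map its state to firing rates, draw Poisson spikes, and score any rate estimate against the known rates using R2. The specific choices below are assumptions for illustration, not taken from the poster: the Lorenz system, Euler integration, the exponential readout, the rate scale, and the boxcar smoother used as a baseline estimate.

```python
import numpy as np

def simulate_lorenz(n_steps, dt=0.01, start=(1.0, 1.0, 1.0)):
    """Euler integration of the Lorenz system (an assumed choice of dynamics)."""
    sigma, rho, beta = 10.0, 28.0, 8.0 / 3.0
    states = np.empty((n_steps, 3))
    x, y, z = start
    for t in range(n_steps):
        dx, dy, dz = sigma * (y - x), x * (rho - z) - y, x * y - beta * z
        x, y, z = x + dt * dx, y + dt * dy, z + dt * dz
        states[t] = (x, y, z)
    return states

def r2_score(truth, estimate):
    """Coefficient of determination between true and estimated rates."""
    ss_res = np.sum((truth - estimate) ** 2)
    ss_tot = np.sum((truth - truth.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

rng = np.random.default_rng(0)
n_steps, n_neurons = 1000, 30

states = simulate_lorenz(n_steps)
states = (states - states.mean(axis=0)) / states.std(axis=0)   # z-score each dimension
readout = 0.5 * rng.standard_normal((3, n_neurons))
true_rates = 0.2 * np.exp(states @ readout)    # expected spike count per time bin
spikes = rng.poisson(true_rates)               # noisy Poisson observations

# Baseline estimate: boxcar smoothing of each neuron's spike train
kernel = np.ones(25) / 25
smoothed = np.stack(
    [np.convolve(spikes[:, i], kernel, mode="same") for i in range(n_neurons)],
    axis=1,
)
baseline_r2 = r2_score(true_rates, smoothed)
```

Since `true_rates` is known, an inferred-rate estimate from any model can be scored the same way, which is what makes overfitting visible even when validation loss looks good.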
[Figure: true performance (R2) versus standard validation loss.]
[Figure: true performance (R2) versus validation loss for denoising autoencoders. Panels: Zero Masking (0%, 10%, 30%, 50%); Salt and Pepper (0%, 10%, 20%, 30%).]
SAE Model: Latent factor analysis via dynamical systems (LFADS) [1]
[Figure: true performance versus standard validation loss for 200 models with random hyperparameters. Regions: Overfit, Good fit, Underfit. Annotations: Better models, Better Loss, Overfitting. Insets compare spikes with the true rate.]
[Figure: SAE schematic. Neural spikes (point process) are modeled by maximizing Poisson likelihood. Labels: µ, σ, initial state g0, W_fac, f_t, W_rate, r_t.]
[Figure: decoding accuracy versus amount of training data (less training data). Annotation: Typical data size in neuroscience.]
[Figure: record neural activity, infer rates, decode behavior. True hand position versus hand position estimated from neural activity.]
[Figure: example rate for one neuron and decoded movements with decreasing data size, comparing Smoothing and SAE (Fixed HPs, HP-optimized). Labels: CD Off, CD On.]