Modeling electronic health records with recurrent neural networks
David C. Kale (1, 2), Zachary C. Lipton (3), Josh Patterson (4)
STRATA - San Jose - 2016
1 University of Southern California
2 Virtual PICU, Children’s Hospital Los Angeles
3 University of California San Diego
4 Patterson Consulting
Outline
• Machine (and deep) learning
• Sequence learning with recurrent neural networks
• Clinical sequence classification using LSTM RNNs
• A real world case study using DL4J
• Conclusion and looking forward
We need functions, brah
Various Inputs and Outputs
“Time underlies many interesting human behaviors”
{0,1}, {A,B,C…}, captions, email mom, fire nukes, eject pop tart
But how do we produce functions? We need a function for that…
One function-generator: Programmers
Which are expensive
When/why does this fail?
• Sometimes the correct function cannot be encoded a priori (what is spam?)
• The optimal solution might change over time
• Programmers are expensive
Sometimes We Need to Learn These Functions From Data
One Class of Learnable Functions: Feedforward Neural Network
Artificial Neurons
Activation Functions
• At internal nodes, common choices for the activation function are the sigmoid, tanh, and ReLU functions.
• At output, activation function could be linear (regression), sigmoid (multilabel classification) or softmax (multi-class classification)
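The activation functions named above are short enough to write out. The following is a minimal sketch in plain Java; the class name `Activations` is hypothetical and is not part of DL4J, which ships its own implementations.

```java
import java.util.Arrays;

class Activations {
    // Sigmoid squashes any real number into (0, 1).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // ReLU passes positive values through and clamps negatives to zero.
    static double relu(double z) {
        return Math.max(0.0, z);
    }

    // Softmax turns a score vector into probabilities that sum to 1.
    // Subtracting the max score first avoids overflow in exp().
    static double[] softmax(double[] z) {
        double max = Arrays.stream(z).max().orElse(0.0);
        double[] out = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.exp(z[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < z.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println("sigmoid(0) = " + sigmoid(0.0));
        System.out.println("tanh(0)    = " + Math.tanh(0.0)); // tanh comes from the standard library
        System.out.println("relu(-2)   = " + relu(-2.0));
        System.out.println("softmax    = " + Arrays.toString(softmax(new double[] {1.0, 2.0, 3.0})));
    }
}
```

Sigmoid at the output treats each label independently (multilabel), while softmax forces the class probabilities to compete (multi-class).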
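Training with Backpropagation
• Goal: calculate the rate of change of the loss function with respect to each parameter (weight) in the model
• Update the weights by following the negative gradient: $w \leftarrow w - \eta \, \partial L / \partial w$, where $\eta$ is the learning rate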
Forward Pass
Backward Pass
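To make the forward and backward passes concrete, here is a minimal sketch for the smallest possible network: one sigmoid neuron with two inputs, trained with cross-entropy loss on the logical OR function. The class name, the toy data, and the learning rate are illustrative assumptions, not taken from the talk.

```java
class SingleNeuronBackprop {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Forward pass: yHat = sigmoid(w . x + b), with params = {w0, w1, b}.
    static double forward(double[] params, double[] x) {
        return sigmoid(params[0] * x[0] + params[1] * x[1] + params[2]);
    }

    // Mean cross-entropy loss over the data set.
    static double loss(double[] params, double[][] xs, double[] ys) {
        double total = 0.0;
        for (int n = 0; n < xs.length; n++) {
            double p = forward(params, xs[n]);
            total -= ys[n] * Math.log(p) + (1.0 - ys[n]) * Math.log(1.0 - p);
        }
        return total / xs.length;
    }

    // Stochastic gradient descent: one forward and one backward pass per example.
    static double[] train(double[][] xs, double[] ys, int epochs, double learningRate) {
        double[] params = new double[3]; // weights and bias start at zero
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int n = 0; n < xs.length; n++) {
                double yHat = forward(params, xs[n]);       // forward pass
                double dz = yHat - ys[n];                    // backward pass: dL/dz for sigmoid + cross-entropy
                params[0] -= learningRate * dz * xs[n][0];   // dL/dw0 = dz * x0
                params[1] -= learningRate * dz * xs[n][1];   // dL/dw1 = dz * x1
                params[2] -= learningRate * dz;              // dL/db  = dz
            }
        }
        return params;
    }

    public static void main(String[] args) {
        double[][] xs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        double[] ys = {0, 1, 1, 1}; // logical OR
        System.out.println("loss before: " + loss(new double[3], xs, ys));
        double[] params = train(xs, ys, 1000, 0.5);
        System.out.println("loss after:  " + loss(params, xs, ys));
    }
}
```

In a multi-layer network the same chain-rule step is repeated layer by layer, moving from the output back toward the input.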
Deep Networks
• Used to be difficult (seemed impossible) to train nets with many hidden layers
• TLDR: Turns out we just needed to do everything 1000x faster…
Feedforward Nets work for Fixed-Size Data
Less Suitable for Text
We would like to capture temporal/sequential dynamics in the data
• Standard approaches address sequential structure:
– Markov models
– Conditional Random Fields
– Linear dynamical systems
• Problem: We desire a system to learn representations, capture nonlinear structure, and capture long term sequential relationships.
To Model Sequential Data: Recurrent Neural Networks
Recurrent Net (Unfolded)
Vanishing / Exploding Gradients
LSTM Memory Cell (Hochreiter & Schmidhuber, 1997)
Memory Cell with Forget Gate (Gers et al., 2000)
LSTM Forward Pass
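The slide's equations did not survive in this transcript. As a rough illustration only, the sketch below implements one step of a commonly used LSTM formulation with a forget gate, reduced to a single scalar cell so the arithmetic is easy to follow. It omits peephole connections (which DL4J's GravesLSTM includes), and the class name and parameter layout are hypothetical.

```java
class LstmCell {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Pre-activation of one gate; p = {input weight, recurrent weight, bias}.
    static double affine(double[] p, double x, double hPrev) {
        return p[0] * x + p[1] * hPrev + p[2];
    }

    // One forward step of a scalar LSTM cell.
    // params rows: 0 = input gate, 1 = forget gate, 2 = output gate, 3 = candidate.
    // Returns {h, c}: the new hidden state and the new cell state.
    static double[] step(double[][] params, double x, double hPrev, double cPrev) {
        double i = sigmoid(affine(params[0], x, hPrev));   // how much new content to write
        double f = sigmoid(affine(params[1], x, hPrev));   // how much old cell state to keep
        double o = sigmoid(affine(params[2], x, hPrev));   // how much of the cell to expose
        double g = Math.tanh(affine(params[3], x, hPrev)); // candidate content
        double c = f * cPrev + i * g;                      // additive cell-state update
        double h = o * Math.tanh(c);
        return new double[] {h, c};
    }

    public static void main(String[] args) {
        // Forget gate saturated open, input gate saturated closed: the cell just remembers.
        double[][] params = {{0, 0, -20}, {0, 0, 20}, {0, 0, 0}, {1, 1, 0}};
        double[] hc = step(params, 0.3, 0.1, 0.7);
        System.out.println("h = " + hc[0] + ", c = " + hc[1]);
    }
}
```

The additive update of the cell state is what lets information (and gradients) survive across many timesteps when the forget gate stays near 1.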
LSTM (full network)
Large Scale Architecture
Standard supervised learning
Image captioning
Sentiment analysis
Video captioning, natural language translation
Part-of-speech tagging
Generative models for text
ICU data generated in hospital
• Patient-level info (e.g., age, gender)
• Physiologic measurements (e.g., heart rate)
– Manually verified observations
– High-frequency measurements
– Waveforms
• Lab results (e.g., glucose)
• Clinical assessments (e.g., Glasgow coma scale)
• Medications and treatments
• Clinical notes
• Diagnoses
• Outcomes
• Billing codes
ICU data gathered in EHR
• Patient-level info (e.g., age, gender)
• Physiologic measurements (e.g., vital signs)
– Manually verified observations
– High-frequency measurements
– Waveforms
• Lab results (e.g., glucose)
• Clinical assessments (e.g., Glasgow coma scale)
• Medications and treatments
• Clinical notes
• Diagnoses (often buried in free text notes)
• Outcomes
• Billing codes
ICU data in our experiments
• Patient-level info (e.g., age, gender)
• Physiologic measurements (e.g., vital signs)
– Manually verified observations
– High-frequency measurements
– Waveforms
• Lab results (e.g., glucose)
• Clinical assessments (e.g., cognitive function)
• One treatment: mechanical ventilation
• Clinical notes
• Diagnoses (often buried in free text notes)
• Outcomes: in-hospital mortality
• Billing codes
Challenges: sampling rates, missingness
• Sparse, irregular, unaligned sampling in time, across variables
• Sample selection bias (e.g., more likely to record abnormal)
• Entire sequences (non-random) missing
[Figure: HR, RR, and ETCO2 time series between Admit and Discharge. Figures courtesy of Ben Marlin, UMass Amherst]
Challenges: alignment, variable length
• Observations begin at time of admission, not at onset of illness
• Sequences vary in length from hours to weeks (or longer)
• Variable dynamics across patients, even with same disease
• Long-term dependencies: future state depends on earlier condition
[Figure: two HR time series, each running from Admit to Discharge. Figures courtesy of Ben Marlin, UMass Amherst]
PhysioNet Challenge 2012
• Task: predict mortality from only first 48 hours of data
• Classic models (SAPS, APACHE, PRISM): expert features + regression
– Useful: quantifying illness at admission, standardized performance
– Not accurate enough to be used for decision support
• Each record includes
– patient descriptors (age, gender, weight, height, unit)
– irregular sequences of ~40 vitals, labs from first 48 hours
– One treatment variable: mechanical ventilation
– Binary outcome: in-hospital survival or mortality (~13% mortality)
• Only 4000 labeled records publicly available (“set A”)
– 4000 unlabeled records (“set B”) used for tuning during competition (we didn’t use)
– 4000 test examples (“set C”) not available
• Very challenging task: temporal outcome, unobserved treatment effects
– Winning entry score: minimum(Precision, Recall) = 0.5353
https://www.physionet.org/challenge/2012/
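The challenge score, minimum(Precision, Recall), can be computed from binary labels and binary predictions as in the sketch below. The class name is hypothetical, and returning 0 when a denominator is zero is an assumed convention, not something the challenge description in this talk specifies.

```java
class MinPrecisionRecall {
    // yTrue and yPred hold 1 for in-hospital death and 0 for survival.
    static double score(int[] yTrue, int[] yPred) {
        int tp = 0, fp = 0, fn = 0;
        for (int i = 0; i < yTrue.length; i++) {
            if (yPred[i] == 1 && yTrue[i] == 1) tp++;
            else if (yPred[i] == 1 && yTrue[i] == 0) fp++;
            else if (yPred[i] == 0 && yTrue[i] == 1) fn++;
        }
        // Assumed convention: an undefined ratio counts as 0.
        double precision = (tp + fp == 0) ? 0.0 : (double) tp / (tp + fp);
        double recall = (tp + fn == 0) ? 0.0 : (double) tp / (tp + fn);
        return Math.min(precision, recall);
    }

    public static void main(String[] args) {
        int[] yTrue = {1, 1, 1, 0, 0, 0, 0, 0};
        int[] cautious = {1, 0, 0, 0, 0, 0, 0, 0}; // precision 1.0, recall 1/3
        int[] balanced = {1, 1, 0, 1, 0, 0, 0, 0}; // precision 2/3, recall 2/3
        System.out.println("cautious: " + score(yTrue, cautious));
        System.out.println("balanced: " + score(yTrue, balanced));
    }
}
```

Taking the minimum penalizes a model that buys high precision by flagging almost no one, which matters when only about 13% of patients die.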
PhysioNet Challenge 2012: predict in-hospital mortality from observations $x_1, x_2, x_3, \ldots, x_T$ during first 48 hours of ICU stay.
Solution: recurrent neural network (RNN)*
$p(y_{mort} = 1 \mid x_1, x_2, x_3, \ldots, x_T) \approx p(y_{mort} = 1 \mid s_T)$, with $s_t = f(s_{t-1}, x_t)$
$s_t = \phi(W s_{t-1} + U x_t + b)$
$y_t = \sigma(V s_t + c)$
• Efficient parameterization: $s_t$ represents exponential # states vs. # nodes
• Can encode (“remember”) longer histories
• During learning, pass future info backward via backprop through time
[Figure: unrolled RNN with inputs x1, x2, …, xT, hidden states s0, s1, s2, …, sT, and outputs y1, y2, …, yT]
* We actually use a long short-term memory network
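The recurrence on this slide can be written out directly. The sketch below is a plain (non-LSTM) RNN that reads a sequence of any length and returns a mortality probability from the final state. It assumes tanh for the hidden nonlinearity φ and a zero initial state s0; the class and method names are hypothetical.

```java
class SimpleRnn {
    // One recurrence step: s_t = tanh(W s_{t-1} + U x_t + b).
    static double[] step(double[][] W, double[][] U, double[] b, double[] sPrev, double[] x) {
        double[] s = new double[b.length];
        for (int i = 0; i < s.length; i++) {
            double z = b[i];
            for (int j = 0; j < sPrev.length; j++) z += W[i][j] * sPrev[j];
            for (int k = 0; k < x.length; k++) z += U[i][k] * x[k];
            s[i] = Math.tanh(z);
        }
        return s;
    }

    // Runs the whole sequence xs[t][feature] and returns y_T = sigmoid(V s_T + c).
    static double predict(double[][] W, double[][] U, double[] b,
                          double[] V, double c, double[][] xs) {
        double[] s = new double[b.length]; // s_0 = 0 (assumption)
        for (double[] x : xs) {
            s = step(W, U, b, s, x);
        }
        double z = c;
        for (int i = 0; i < s.length; i++) z += V[i] * s[i];
        return 1.0 / (1.0 + Math.exp(-z));
    }

    public static void main(String[] args) {
        double[][] W = {{0.5, -0.2}, {0.1, 0.3}};  // hidden-to-hidden
        double[][] U = {{1.0}, {-1.0}};            // input-to-hidden, one input feature
        double[] b = {0.0, 0.0};
        double[] V = {1.5, -0.5};
        double[][] shortStay = {{0.2}, {0.4}};
        double[][] longStay = {{0.2}, {0.4}, {0.9}, {1.0}, {0.8}};
        // The same parameters handle sequences of different lengths.
        System.out.println(predict(W, U, b, V, 0.0, shortStay));
        System.out.println(predict(W, U, b, V, 0.0, longStay));
    }
}
```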
PhysioNet Raw Data
• Set-a
– Directory of single files
– One file per patient
– 48 hours of ICU data
• Format
– Header Line
– 6 Descriptor Values at 00:00
  • Collected at Admission
– 37 Irregularly sampled columns
  • Over 48 hours
Time,Parameter,Value
00:00,RecordID,132601
00:00,Age,74
00:00,Gender,1
00:00,Height,177.8
00:00,ICUType,2
00:00,Weight,75.9
00:15,pH,7.39
00:15,PaCO2,39
00:15,PaO2,137
00:56,pH,7.39
00:56,PaCO2,37
00:56,PaO2,222
01:26,Urine,250
01:26,Urine,635
01:31,DiasABP,70
01:31,FiO2,1
01:31,HR,103
01:31,MAP,94
01:31,MechVent,1
01:31,SysABP,154
01:34,HCT,24.9
01:34,Platelets,115
01:34,WBC,16.4
01:41,DiasABP,52
01:41,HR,102
01:41,MAP,65
01:41,SysABP,95
01:56,DiasABP,64
01:56,GCS,3
01:56,HR,104
01:56,MAP,85
01:56,SysABP,132
…
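A record in this Time,Parameter,Value layout can be grouped into one irregular time series per parameter. The sketch below is one possible way to do that in plain Java; it is not the vectorization code used in the talk, and the class name and the {minutes, value} pair representation are assumptions.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class PhysioNetRecordParser {
    // Returns parameter name -> list of {minutes since admission, value} pairs.
    static Map<String, List<double[]>> parse(List<String> lines) {
        Map<String, List<double[]>> series = new LinkedHashMap<>();
        for (String line : lines) {
            String[] fields = line.trim().split(",");
            if (fields.length != 3 || fields[0].equals("Time")) {
                continue; // skip the header line and anything malformed
            }
            String[] hhmm = fields[0].split(":");
            double minutes = Integer.parseInt(hhmm[0]) * 60 + Integer.parseInt(hhmm[1]);
            double value = Double.parseDouble(fields[2]);
            series.computeIfAbsent(fields[1], k -> new ArrayList<>())
                  .add(new double[] {minutes, value});
        }
        return series;
    }

    public static void main(String[] args) {
        List<String> lines = Arrays.asList(
            "Time,Parameter,Value",
            "00:00,Age,74",
            "00:15,pH,7.39",
            "00:56,pH,7.39",
            "01:31,HR,103",
            "01:41,HR,102",
            "01:56,HR,104");
        Map<String, List<double[]>> series = parse(lines);
        for (Map.Entry<String, List<double[]>> e : series.entrySet()) {
            System.out.println(e.getKey() + ": " + e.getValue().size() + " observation(s)");
        }
    }
}
```

Each parameter ends up with its own timestamps, which makes the irregular and unaligned sampling visible before any resampling or padding is applied.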
Preparing Input Data
• Input was a 3D tensor (3D array)
– Mini-batch as first dimension
– Feature columns as second dimension
– Timesteps as third dimension
• At mini-batch size of 20, 43 columns, and 202 timesteps
– We have 173,720 values per tensor input
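The value count follows from multiplying the three dimensions: 20 × 43 × 202 = 173,720. A minimal sketch using a plain Java array in place of an ND4J tensor (the class name is hypothetical):

```java
class InputTensorShape {
    // Layout from the slide: [mini-batch][feature column][timestep].
    static double[][][] allocate(int miniBatch, int columns, int timesteps) {
        return new double[miniBatch][columns][timesteps];
    }

    static long valueCount(double[][][] tensor) {
        return (long) tensor.length * tensor[0].length * tensor[0][0].length;
    }

    public static void main(String[] args) {
        double[][][] input = allocate(20, 43, 202);
        System.out.println(valueCount(input)); // 20 * 43 * 202
    }
}
```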
A Single Training Example
A single training example gets the added dimension of timesteps for each column.

Without timesteps (vector columns, one value each):
albumin  0.0
alp      1.0
alt      0.5
ast      0.0
…

With timesteps (vector columns × timesteps):
           0    1    2    3    4   …
albumin   0.0  0.0  0.5  0.0  0.0
alp       0.0  0.1  0.0  0.0  0.2
alt       0.0  0.0  0.0  0.9  0.0
ast       0.0  0.0  0.0  0.0  0.4
…
PhysioNet Timeseries Vectorization
@RELATION UnitTest_PhysioNet_Schema_ZUZUV
@DELIMITER ,
@MISSING_VALUE -1

@ATTRIBUTE recordid NOMINAL DESCRIPTOR !SKIP !ZERO
@ATTRIBUTE age NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !AVG
@ATTRIBUTE gender NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !ZERO
@ATTRIBUTE height NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !AVG
@ATTRIBUTE weight NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !AVG
@ATTRIBUTE icutype NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !ZERO
@ATTRIBUTE albumin NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
@ATTRIBUTE alp NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
@ATTRIBUTE alt NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
@ATTRIBUTE ast NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
@ATTRIBUTE bilirubin NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
[ more … ]
Uneven Time Steps and Masking

Single Input (columns + timesteps):
           0    1    2    3    4   …
albumin   0.0  0.0  0.5  0.0  0.0
alp       0.0  0.1  0.0  0.0  0.0
alt       0.0  0.0  0.0  0.9  0.0
ast       0.0  0.0  0.0  0.0  0.0
…

Input Mask (only timesteps):
1.0 1.0 1.0 1.0 0.0 0.0
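The padding and mask shown on this slide can be produced as follows. This is a sketch with plain Java arrays rather than the DL4J data pipeline; the class and method names are hypothetical, and it assumes every example has at least one column and no more than `maxLen` timesteps.

```java
import java.util.Arrays;

class SequenceMasking {
    // sequences[n][column][t]: example n may have fewer timesteps than maxLen.
    // Returns a rectangular copy with zeros appended to the tail of each series.
    static double[][][] padTail(double[][][] sequences, int maxLen) {
        double[][][] padded = new double[sequences.length][][];
        for (int n = 0; n < sequences.length; n++) {
            padded[n] = new double[sequences[n].length][maxLen]; // zero-filled by default
            for (int c = 0; c < sequences[n].length; c++) {
                System.arraycopy(sequences[n][c], 0, padded[n][c], 0, sequences[n][c].length);
            }
        }
        return padded;
    }

    // mask[n][t] is 1.0 where example n has real data and 0.0 where it was padded.
    static double[][] mask(double[][][] sequences, int maxLen) {
        double[][] mask = new double[sequences.length][maxLen];
        for (int n = 0; n < sequences.length; n++) {
            int length = sequences[n][0].length;
            for (int t = 0; t < length; t++) {
                mask[n][t] = 1.0;
            }
        }
        return mask;
    }

    public static void main(String[] args) {
        double[][][] batch = {
            {{0.0, 0.0, 0.5, 0.0}, {0.0, 0.1, 0.0, 0.0}},           // 4 timesteps
            {{0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, {0, 0, 0, 0.9, 0, 0}}  // 6 timesteps
        };
        System.out.println(Arrays.toString(mask(batch, 6)[0]));
        System.out.println(Arrays.toString(padTail(batch, 6)[0][0]));
    }
}
```

The mask is what lets the network tell a padded zero apart from a measured (or normalized) zero, so padded timesteps can be excluded from the loss.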
DL4J
• “The Hadoop of Deep Learning”
– Command line driven
– Java, Scala, and Python APIs
– Apache 2.0 licensed
• Java implementation
– Parallelization (YARN, Spark)
– GPU support
  • Also supports multi-GPU per host
• Runtime neutral
– Local
– Hadoop / YARN + Spark
– AWS
• https://github.com/deeplearning4j/deeplearning4j
RNNs in DL4J
// Two stacked GravesLSTM layers feeding a softmax RNN output layer.
// learningRate, lstmLayerSize, nOut, max_epochs, iter, and dataset_iter are defined elsewhere.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
    .learningRate(learningRate)
    .rmsDecay(0.95)
    .seed(12345)
    .regularization(true)
    .l2(0.001)
    .list(3)
    .layer(0, new GravesLSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
        .updater(Updater.RMSPROP)
        .activation("tanh").weightInit(WeightInit.DISTRIBUTION)
        .dist(new UniformDistribution(-0.08, 0.08)).build())
    .layer(1, new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
        .updater(Updater.RMSPROP)
        .activation("tanh").weightInit(WeightInit.DISTRIBUTION)
        .dist(new UniformDistribution(-0.08, 0.08)).build())
    .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation("softmax")
        .updater(Updater.RMSPROP)
        .nIn(lstmLayerSize).nOut(nOut).weightInit(WeightInit.DISTRIBUTION)
        .dist(new UniformDistribution(-0.08, 0.08)).build())
    .pretrain(false).backprop(true)
    .build();

// Build and initialize the network from the configuration (not shown on the slide).
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

// One pass over the training iterator per epoch.
for (int epoch = 0; epoch < max_epochs; ++epoch) {
    net.fit(dataset_iter);
}
Experimental Results
• Winning entry: min(P,R) = 0.5353 (two others over 0.5)
– Trained on full set A (4K), tuned on set B (4K), tested on set C
– All used extensively hand-engineered features
• Our best model so far: min(P,R) = 0.4907
– 60/20/20 training/validation/test split of set A
– LSTM with 2 x 300-cell layers on inputs
• Different test sets so not directly comparable
– Disadvantage: much smaller training set
– Required no feature engineering or domain knowledge
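For scale, a 60/20/20 split of the 4000 records in set A leaves 2400 for training and 800 each for validation and test. The sketch below shows one way to draw such a split; the seeded random shuffle is an assumption (the talk does not say how records were assigned), and the class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

class DataSplit {
    // Returns three disjoint lists of record indices: train, validation, test.
    static List<List<Integer>> split(int n, double trainFrac, double validFrac, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed)); // fixed seed for a repeatable split
        int nTrain = (int) Math.round(n * trainFrac);
        int nValid = (int) Math.round(n * validFrac);
        return Arrays.asList(
            indices.subList(0, nTrain),
            indices.subList(nTrain, nTrain + nValid),
            indices.subList(nTrain + nValid, n)); // remainder goes to test
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = split(4000, 0.6, 0.2, 12345L);
        System.out.println("train=" + parts.get(0).size()
            + " valid=" + parts.get(1).size()
            + " test=" + parts.get(2).size());
    }
}
```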
Map sequences into fixed vector representation
• Not perfectly separable in 2D but some cluster structure related to mortality
• Can repurpose “representation” for other tasks (e.g., searching for similar patients, clustering, etc.)
Final comments
• We believe we could improve performance to well over 0.5
– overfitting: training min(P,R) > 0.6 (vs. test: 0.49)
– smaller or simpler RNN layers, adding dropout, multitask training
• Flexible NN architectures well suited to complex clinical data
– but likely will demand much larger data sets
– may be better matched to “raw” signals (e.g., waveforms)
• More general challenges
– missing (or unobserved) inputs and outcomes
– treatment effects confound predictive models
– outcomes often have temporal components (posing as binary classification ignores that)
• You can try it out: https://github.com/jpatanooga/dl4j-rnn-timeseries-examples/
See related paper to appear at ICLR 2016: http://arxiv.org/abs/1511.03677
Questions?
Thank you for your time and attention
Gibson & Patterson. Deep Learning: A Practitioner’s Approach. O’Reilly, Q2 2016.
Lipton, et al. A Critical Review of RNNs. arXiv.
Lipton & Kale. Learning to Diagnose with LSTM RNNs. ICLR 2016.
Sepp Hochreiter
Father of LSTMs,* renowned beer thief
S. Hochreiter and J. Schmidhuber. Long short-term memory. Neural Computation, 9 (8): 1735-1780, 1997.