Modeling electronic health records with recurrent neural networks


David C. Kale,[1,2] Zachary C. Lipton,[3] Josh Patterson[4]

STRATA - San Jose - 2016
[1] University of Southern California
[2] Virtual PICU, Children’s Hospital Los Angeles
[3] University of California San Diego
[4] Patterson Consulting

Outline
• Machine (and deep) learning

• Sequence learning with recurrent neural networks

• Clinical sequence classification using LSTM RNNs

• A real world case study using DL4J

• Conclusion and looking forward

We need functions, brah

Various Inputs and Outputs

“Time underlies many interesting human behaviors”

{0,1}, {A,B,C,…}, captions, email mom, fire nukes, eject pop tart

But how do we produce functions? We need a function for that…

One function-generator: programmers (which are expensive)

When/why does this fail?

• Sometimes the correct function cannot be encoded a priori (e.g., what is spam?)

• The optimal solution might change over time

• Programmers are expensive

Sometimes We Need to Learn These Functions From Data

One Class of Learnable Functions: Feedforward Neural Networks

Artificial Neurons

Activation Functions
• At internal nodes, common choices for the activation function are sigmoid, tanh, and ReLU.
• At the output, the activation function could be linear (regression), sigmoid (multilabel classification), or softmax (multi-class classification).
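
For reference (these definitions are not on the slide, but they are the standard forms of the functions named above):

  \sigma(z) = \frac{1}{1 + e^{-z}}, \qquad
  \tanh(z) = \frac{e^{z} - e^{-z}}{e^{z} + e^{-z}}, \qquad
  \mathrm{ReLU}(z) = \max(0, z), \qquad
  \mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}}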

Training with Backpropagation

• Goal: calculate the rate of change of the loss function with respect to each parameter (weight) in the model

• Update the weights by gradient following:
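
The update equation itself did not survive the transcript; the standard gradient-following (gradient descent) step for a weight w, with learning rate \eta and loss L, is:

  w \leftarrow w - \eta \, \frac{\partial L}{\partial w}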

Forward Pass

Backward Pass

Deep Networks
• It used to be difficult (seemed impossible) to train nets with many hidden layers
• TL;DR: turns out we just needed to do everything 1000x faster…

Outline
• Machine (and deep) learning

• Sequence learning with recurrent neural networks

• Clinical sequence classification using LSTM RNNs

• A real world case study using DL4J

• Conclusion and looking forward

Feedforward Nets Work for Fixed-Size Data

Less Suitable for Text

We would like to capture temporal/sequential dynamics in the data
• Standard approaches address sequential structure: Markov models, conditional random fields, linear dynamical systems
• Problem: we desire a system that learns representations, captures nonlinear structure, and captures long-term sequential relationships

To Model Sequential Data: Recurrent Neural Networks

Recurrent Net (Unfolded)

Vanishing / Exploding Gradients

LSTM Memory Cell (Hochreiter & Schmidhuber, 1997)

Memory Cell with Forget Gate

(Gers et al., 2000)

LSTM Forward Pass
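
The forward-pass equations appear on the slide only as a figure; a standard LSTM formulation with input, forget, and output gates (following Gers et al., 2000), where \odot denotes elementwise multiplication, is:

  g_t = \tanh(W_{gx} x_t + W_{gh} h_{t-1} + b_g)   (candidate input)
  i_t = \sigma(W_{ix} x_t + W_{ih} h_{t-1} + b_i)   (input gate)
  f_t = \sigma(W_{fx} x_t + W_{fh} h_{t-1} + b_f)   (forget gate)
  o_t = \sigma(W_{ox} x_t + W_{oh} h_{t-1} + b_o)   (output gate)
  c_t = f_t \odot c_{t-1} + i_t \odot g_t           (memory cell state)
  h_t = o_t \odot \tanh(c_t)                        (hidden output)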

LSTM (full network)

Large Scale Architecture

Standard supervised learning

Image captioning

Sentiment analysis

Video captioning, natural language translation

Part-of-speech tagging

Generative models for text

Outline
• Machine (and deep) learning

• Sequence learning with recurrent neural networks

• Clinical sequence classification using LSTM RNNs

• A real world case study using DL4J

• Conclusion and looking forward

ICU data generated in hospital
• Patient-level info (e.g., age, gender)
• Physiologic measurements (e.g., heart rate)
  – Manually verified observations
  – High-frequency measurements
  – Waveforms
• Lab results (e.g., glucose)
• Clinical assessments (e.g., Glasgow Coma Scale)
• Medications and treatments
• Clinical notes
• Diagnoses
• Outcomes
• Billing codes

ICU data gathered in EHR
• Patient-level info (e.g., age, gender)
• Physiologic measurements (e.g., vital signs)
  – Manually verified observations
  – High-frequency measurements
  – Waveforms
• Lab results (e.g., glucose)
• Clinical assessments (e.g., Glasgow Coma Scale)
• Medications and treatments
• Clinical notes
• Diagnoses (often buried in free-text notes)
• Outcomes
• Billing codes

ICU data in our experiments
• Patient-level info (e.g., age, gender)
• Physiologic measurements (e.g., vital signs)
  – Manually verified observations
  – High-frequency measurements
  – Waveforms
• Lab results (e.g., glucose)
• Clinical assessments (e.g., cognitive function)
• One treatment: mechanical ventilation
• Clinical notes
• Diagnoses (often buried in free-text notes)
• Outcomes: in-hospital mortality
• Billing codes

• Sparse, irregular, unaligned sampling in time, across variables

• Sample selection bias (e.g., more likely to record abnormal)

• Entire sequences (non-random) missing

Challenges: sampling rates, missingness

[Figure: HR, RR, and ETCO2 time series from Admit to Discharge, sampled at different rates with stretches of missing data. Figures courtesy of Ben Marlin, UMass Amherst.]

Challenges: alignment, variable length
• Observations begin at time of admission, not at onset of illness
• Sequences vary in length from hours to weeks (or longer)
• Variable dynamics across patients, even with the same disease
• Long-term dependencies: future state depends on earlier condition

[Figure: HR time series for two patients from Admit to Discharge, differing in length and alignment. Figures courtesy of Ben Marlin, UMass Amherst.]

PhysioNet Challenge 2012
• Task: predict mortality from only the first 48 hours of data
• Classic models (SAPS, APACHE, PRISM): expert features + regression
  – Useful: quantifying illness at admission, standardized performance
  – Not accurate enough to be used for decision support
• Each record includes:
  – patient descriptors (age, gender, weight, height, unit)
  – irregular sequences of ~40 vitals and labs from the first 48 hours
  – one treatment variable: mechanical ventilation
  – binary outcome: in-hospital survival or mortality (~13% mortality)
• Only 4,000 labeled records publicly available (“set A”)
• 4,000 unlabeled records (“set B”) used for tuning during the competition (we didn’t use)
• 4,000 test examples (“set C”) not available
• Very challenging task: temporal outcome, unobserved treatment effects
• Winning entry score: min(Precision, Recall) = 0.5353

https://www.physionet.org/challenge/2012/

y_t = σ(V s_t + c)
s_t = φ(W s_{t-1} + U x_t + b)

PhysioNet Challenge 2012: predict in-hospital mortality from observations x_1, x_2, x_3, …, x_T during the first 48 hours of ICU stay.

Solution: recurrent neural network (RNN)*
p(y_mort = 1 | x_1, x_2, x_3, …, x_T) ≈ p(y_mort = 1 | s_T), with s_t = f(s_{t-1}, x_t)

• Efficient parameterization: s_t represents an exponential # of states vs. # of nodes
• Can encode (“remember”) longer histories
• During learning, pass future info backward via backpropagation through time (see the sketch below)
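
To make the recurrence concrete, here is a minimal ND4J sketch (illustrative only, not the authors’ code) of the update s_t = φ(W s_{t-1} + U x_t + b) and the final prediction y = σ(V s_T + c); the parameter matrices W, U, V, the biases b, c, the initial state s0, and the input list xs are assumed to exist with compatible shapes:

  import org.nd4j.linalg.api.ndarray.INDArray;
  import org.nd4j.linalg.ops.transforms.Transforms;

  INDArray s = s0;                            // hidden state s_0
  for (INDArray x : xs) {                     // inputs x_1, ..., x_T
      // s_t = tanh(W s_{t-1} + U x_t + b)
      s = Transforms.tanh(W.mmul(s).add(U.mmul(x)).add(b));
  }
  // p(y_mort = 1 | s_T) = sigmoid(V s_T + c)
  INDArray y = Transforms.sigmoid(V.mmul(s).add(c));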

[Figure: RNN unrolled over time, with initial state s_0, hidden states s_1, s_2, …, s_T, inputs x_1, x_2, …, x_T, and outputs y_1, y_2, …, y_T.]

* We actually use a long short-term memory network

Outline
• Machine (and deep) learning

• Sequence learning with recurrent neural networks

• Clinical sequence classification using LSTM RNNs

• A real world case study using DL4J

• Conclusion and looking forward

PhysioNet Raw Data
• Set A
  – Directory of single files
  – One file per patient
  – 48 hours of ICU data
• Format
  – Header line
  – 6 descriptor values at 00:00 (collected at admission)
  – 37 irregularly sampled columns over 48 hours

Time,Parameter,Value
00:00,RecordID,132601
00:00,Age,74
00:00,Gender,1
00:00,Height,177.8
00:00,ICUType,2
00:00,Weight,75.9
00:15,pH,7.39
00:15,PaCO2,39
00:15,PaO2,137
00:56,pH,7.39
00:56,PaCO2,37
00:56,PaO2,222
01:26,Urine,250
01:26,Urine,635
01:31,DiasABP,70
01:31,FiO2,1
01:31,HR,103
01:31,MAP,94
01:31,MechVent,1
01:31,SysABP,154
01:34,HCT,24.9
01:34,Platelets,115
01:34,WBC,16.4
01:41,DiasABP,52
01:41,HR,102
01:41,MAP,65
01:41,SysABP,95
01:56,DiasABP,64
01:56,GCS,3
01:56,HR,104
01:56,MAP,85
01:56,SysABP,132
…

Preparing Input Data
• Input was a 3D tensor (3D matrix)
  – Mini-batch as first dimension
  – Feature columns as second dimension
  – Timesteps as third dimension
• At a mini-batch size of 20, 43 columns, and 202 timesteps, we have 173,720 values per tensor input (sketched below)
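
As an illustration (not from the slides), allocating an input tensor with that shape in ND4J might look like the following; the variable names are hypothetical:

  import org.nd4j.linalg.api.ndarray.INDArray;
  import org.nd4j.linalg.factory.Nd4j;

  int miniBatchSize = 20;    // examples per mini-batch
  int numColumns    = 43;    // feature columns per timestep
  int numTimesteps  = 202;   // padded sequence length

  // Shape [miniBatch, featureColumns, timesteps]: 20 * 43 * 202 = 173,720 values
  INDArray features = Nd4j.zeros(miniBatchSize, numColumns, numTimesteps);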

A Single Training Example

A static vector of values, one per column:

  albumin  0.0
  alp      1.0
  alt      0.5
  ast      0.0

A single training example gets the added dimension of timesteps for each column:

  timestep:    0    1    2    3    4   …
  albumin     0.0  0.0  0.5  0.0  0.0
  alp         0.0  0.1  0.0  0.0  0.2
  alt         0.0  0.0  0.0  0.9  0.0
  ast         0.0  0.0  0.0  0.0  0.4

PhysioNet Timeseries Vectorization

@RELATION UnitTest_PhysioNet_Schema_ZUZUV
@DELIMITER ,
@MISSING_VALUE -1

@ATTRIBUTE recordid  NOMINAL DESCRIPTOR !SKIP !ZERO
@ATTRIBUTE age       NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !AVG
@ATTRIBUTE gender    NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !ZERO
@ATTRIBUTE height    NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !AVG
@ATTRIBUTE weight    NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !AVG
@ATTRIBUTE icutype   NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !ZERO
@ATTRIBUTE albumin   NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
@ATTRIBUTE alp       NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
@ATTRIBUTE alt       NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
@ATTRIBUTE ast       NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
@ATTRIBUTE bilirubin NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
[ more … ]

Uneven Time Steps and Masking

Single Input (columns + timesteps):

  timestep:    0    1    2    3    4   …
  albumin     0.0  0.0  0.5  0.0  0.0
  alp         0.0  0.1  0.0  0.0  0.0
  alt         0.0  0.0  0.0  0.9  0.0
  ast         0.0  0.0  0.0  0.0  0.0

Input Mask (only timesteps):

  1.0  1.0  1.0  1.0  0.0  0.0
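
A hedged sketch of how such a mask can be paired with the padded features, assuming ND4J's DataSet class (all shapes and variable names here are illustrative):

  import org.nd4j.linalg.api.ndarray.INDArray;
  import org.nd4j.linalg.dataset.DataSet;
  import org.nd4j.linalg.factory.Nd4j;

  // features: [miniBatch, columns, timesteps]; labels: [miniBatch, numClasses, timesteps]
  INDArray features = Nd4j.zeros(20, 43, 202);
  INDArray labels   = Nd4j.zeros(20, 2, 202);

  // Masks have shape [miniBatch, timesteps]: 1.0 marks a real timestep, 0.0 marks padding
  INDArray featuresMask = Nd4j.zeros(20, 202);
  INDArray labelsMask   = Nd4j.zeros(20, 202);

  DataSet ds = new DataSet(features, labels, featuresMask, labelsMask);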

DL4J
• “The Hadoop of Deep Learning”
  – Command-line driven
  – Java, Scala, and Python APIs
  – ASF 2.0 licensed
• Java implementation
  – Parallelization (YARN, Spark)
  – GPU support (also supports multi-GPU per host)
• Runtime neutral
  – Local
  – Hadoop / YARN + Spark
  – AWS
• https://github.com/deeplearning4j/deeplearning4j

RNNs in DL4J

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .iterations(1)
    .learningRate(learningRate)
    .rmsDecay(0.95)
    .seed(12345)
    .regularization(true)
    .l2(0.001)
    .list(3)
    // First LSTM layer: one input per feature column
    .layer(0, new GravesLSTM.Builder()
        .nIn(iter.inputColumns()).nOut(lstmLayerSize)
        .updater(Updater.RMSPROP)
        .activation("tanh")
        .weightInit(WeightInit.DISTRIBUTION)
        .dist(new UniformDistribution(-0.08, 0.08))
        .build())
    // Second stacked LSTM layer
    .layer(1, new GravesLSTM.Builder()
        .nIn(lstmLayerSize).nOut(lstmLayerSize)
        .updater(Updater.RMSPROP)
        .activation("tanh")
        .weightInit(WeightInit.DISTRIBUTION)
        .dist(new UniformDistribution(-0.08, 0.08))
        .build())
    // Softmax output layer with multi-class cross-entropy loss
    .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT)
        .activation("softmax")
        .updater(Updater.RMSPROP)
        .nIn(lstmLayerSize).nOut(nOut)
        .weightInit(WeightInit.DISTRIBUTION)
        .dist(new UniformDistribution(-0.08, 0.08))
        .build())
    .pretrain(false).backprop(true)
    .build();

for (int epoch = 0; epoch < max_epochs; ++epoch)
    net.fit(dataset_iter);
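
The configuration above still has to be turned into a network before the fit loop can run; a minimal sketch (assuming the standard DL4J MultiLayerNetwork API, not shown on the slide) is:

  import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

  MultiLayerNetwork net = new MultiLayerNetwork(conf);
  net.init();                                        // allocate and initialize parameters
  net.setListeners(new ScoreIterationListener(10));  // log the training score every 10 iterations

  // then the epoch loop above calls net.fit(dataset_iter) once per epoch,
  // resetting dataset_iter between epochs if the iterator requires it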

Experimental Results
• Winning entry: min(P, R) = 0.5353 (two others over 0.5)
  – Trained on full set A (4K), tuned on set B (4K), tested on set C
  – All used extensively hand-engineered features
• Our best model so far: min(P, R) = 0.4907
  – 60/20/20 training/validation/test split of set A
  – LSTM with 2 x 300-cell layers on inputs
• Different test sets, so not directly comparable
  – Disadvantage: much smaller training set
  – Required no feature engineering or domain knowledge
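
For clarity, min(P, R) here is simply the smaller of precision and recall on the positive (mortality) class; a tiny illustrative Java helper (names hypothetical) is:

  // Score = min(precision, recall), computed from raw prediction counts
  static double minPR(int truePositives, int falsePositives, int falseNegatives) {
      double precision = truePositives / (double) (truePositives + falsePositives);
      double recall    = truePositives / (double) (truePositives + falseNegatives);
      return Math.min(precision, recall);   // the winning 2012 entry scored 0.5353
  }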

Map sequences into fixed vector representation

• Not perfectly separable in 2D but some cluster structure related to mortality

• Can repurpose “representation” for other tasks (e.g., searching for similar patients, clustering, etc.)

Final comments
• We believe we could improve performance to well over 0.5
  – Overfitting: training min(P, R) > 0.6 (vs. test: 0.49)
  – Remedies: smaller or simpler RNN layers, adding dropout, multitask training
• Flexible NN architectures are well suited to complex clinical data
  – but likely will demand much larger data sets
  – may be better matched to “raw” signals (e.g., waveforms)
• More general challenges
  – missing (or unobserved) inputs and outcomes
  – treatment effects confound predictive models
  – outcomes often have temporal components (posing them as binary classification ignores that)

• You can try it out: https://github.com/jpatanooga/dl4j-rnn-timeseries-examples/

See related paper to appear at ICLR 2016: http://arxiv.org/abs/1511.03677

Questions? Thank you for your time and attention.

Gibson & Patterson. Deep Learning: A Practitioner’s Approach. O’Reilly, Q2 2016.

Lipton et al. A Critical Review of Recurrent Neural Networks for Sequence Learning. arXiv:1506.00019.

Lipton & Kale. Learning to Diagnose with LSTM Recurrent Neural Networks. ICLR 2016.

Sepp Hochreiter: father of LSTMs,* renowned beer thief

S. Hochreiter and J. Schmidhuber. Long short-term memory. Neural Computation, 9(8): 1735-1780, 1997.
