
Semi-supervised Learning with Deep Generative Models

Diederik P. Kingma∗, Danilo J. Rezende†, Shakir Mohamed†, Max Welling∗

∗Machine Learning Group, Univ. of Amsterdam, {D.P.Kingma, M.Welling}@uva.nl

†Google Deepmind, {danilor, shakir}@google.com

Abstract

The ever-increasing size of modern data sets combined with the difficulty of obtaining label information has made semi-supervised learning one of the problems of significant practical importance in modern data analysis. We revisit the approach to semi-supervised learning with generative models and develop new models that allow for effective generalisation from small labelled data sets to large unlabelled ones. Generative approaches have thus far been either inflexible, inefficient or non-scalable. We show that deep generative models and approximate Bayesian inference exploiting recent advances in variational methods can be used to provide significant improvements, making generative approaches highly competitive for semi-supervised learning.

1 Introduction

Semi-supervised learning considers the problem of classification when only a small subset of the observations have corresponding class labels. Such problems are of immense practical interest in a wide range of applications, including image search (Fergus et al., 2009), genomics (Shi and Zhang, 2011), natural language parsing (Liang, 2005), and speech analysis (Liu and Kirchhoff, 2013), where unlabelled data is abundant, but class labels are expensive or impossible to obtain for the entire data set. The question that is then asked is: how can properties of the data be used to improve decision boundaries and to allow for classification that is more accurate than that based on classifiers constructed using the labelled data alone? In this paper we answer this question by developing probabilistic models for inductive and transductive semi-supervised learning by utilising an explicit model of the data density, building upon recent advances in deep generative models and scalable variational inference (Kingma and Welling, 2014; Rezende et al., 2014).

Amongst existing approaches, the simplest algorithm for semi-supervised learning is based on a self-training scheme (Rosenberg et al., 2005) where the model is bootstrapped with additional labelled data obtained from its own highly confident predictions; this process is repeated until some termination condition is reached. These methods are heuristic and prone to error since they can reinforce poor predictions. Transductive SVMs (TSVM) (Joachims, 1999) extend SVMs with the aim of max-margin classification while ensuring that there are as few unlabelled observations near the margin as possible. These approaches have difficulty extending to large amounts of unlabelled data, and efficient optimisation in this setting is still an open problem. Graph-based methods are amongst the most popular and aim to construct a graph connecting similar observations; label information propagates through the graph from labelled to unlabelled nodes by finding the minimum energy (MAP) configuration (Blum et al., 2004; Zhu et al., 2003). Graph-based approaches are sensitive to the graph structure and require eigen-analysis of the graph Laplacian, which limits the scale to which these methods can be applied, though efficient spectral methods are now available (Fergus et al., 2009). Neural network-based approaches combine unsupervised and supervised learning

For an updated version of this paper, please see http://arxiv.org/abs/1406.5298

arXiv:1406.5298v2 [cs.LG] 31 Oct 2014

by training feed-forward classifiers with an additional penalty from an auto-encoder or other unsupervised embedding of the data (Ranzato and Szummer, 2008; Weston et al., 2012). The Manifold Tangent Classifier (MTC) (Rifai et al., 2011) trains contractive auto-encoders (CAEs) to learn the manifold on which the data lies, followed by an instance of TangentProp to train a classifier that is approximately invariant to local perturbations along the manifold. The idea of manifold learning using graph-based methods has most recently been combined with kernel (SVM) methods in the Atlas RBF model (Pitelis et al., 2014) and provides amongst the most competitive performance currently available.

In this paper, we instead choose to exploit the power of generative models, which recognise the semi-supervised learning problem as a specialised missing data imputation task for the classification problem. Existing generative approaches based on models such as Gaussian mixture or hidden Markov models (Zhu, 2006) have not been very successful due to the need for a large number of mixture components or states to perform well. More recent solutions have used non-parametric density models, either based on trees (Kemp et al., 2003) or Gaussian processes (Adams and Ghahramani, 2009), but scalability and accurate inference for these approaches are still lacking. Variational approximations for semi-supervised clustering have also been explored previously (Li et al., 2009; Wang et al., 2009).

Thus, while a small set of generative approaches have been previously explored, a generalised and scalable probabilistic approach for semi-supervised learning is still lacking. It is this gap that we address through the following contributions:

• We describe a new framework for semi-supervised learning with generative models, employing rich parametric density estimators formed by the fusion of probabilistic modelling and deep neural networks.

• We show for the first time how variational inference can be brought to bear upon the problem of semi-supervised classification. In particular, we develop a stochastic variational inference algorithm that allows for joint optimisation of both model and variational parameters, and that is scalable to large datasets.

• We demonstrate the performance of our approach on a number of data sets, providing state-of-the-art results on benchmark problems.

• We show qualitatively that generative semi-supervised models learn to separate the data classes (content types) from the intra-class variabilities (styles), which makes it straightforward to simulate analogies of images on a variety of datasets.

2 Deep Generative Models for Semi-supervised Learning

We are faced with data that appear as pairs $(\mathbf{X}, \mathbf{Y}) = \{(\mathbf{x}_1, y_1), \ldots, (\mathbf{x}_N, y_N)\}$, with the $i$-th observation $\mathbf{x}_i \in \mathbb{R}^D$ and the corresponding class label $y_i \in \{1, \ldots, L\}$. Observations will have corresponding latent variables, which we denote by $\mathbf{z}_i$. We will omit the index $i$ whenever it is clear that we are referring to terms associated with a single data point. In semi-supervised classification, only a subset of the observations have corresponding class labels; we refer to the empirical distribution over the labelled and unlabelled subsets as $p_l(\mathbf{x}, y)$ and $p_u(\mathbf{x})$, respectively. We now develop models for semi-supervised learning that exploit generative descriptions of the data to improve upon the classification performance that would be obtained using the labelled data alone.

Latent-feature discriminative model (M1): A commonly used approach is to construct a model that provides an embedding or feature representation of the data. Using these features, a separate classifier is thereafter trained. The embeddings allow for a clustering of related observations in a latent feature space that allows for accurate classification, even with a limited number of labels. Instead of a linear embedding, or features obtained from a regular auto-encoder, we construct a deep generative model of the data that is able to provide a more robust set of latent features. The generative model we use is:

$$p(\mathbf{z}) = \mathcal{N}(\mathbf{z} \mid \mathbf{0}, \mathbf{I}); \quad p_\theta(\mathbf{x} \mid \mathbf{z}) = f(\mathbf{x}; \mathbf{z}, \theta), \tag{1}$$

where $f(\mathbf{x}; \mathbf{z}, \theta)$ is a suitable likelihood function (e.g., a Gaussian or Bernoulli distribution) whose probabilities are formed by a non-linear transformation, with parameters $\theta$, of a set of latent variables $\mathbf{z}$. This non-linear transformation is essential to allow for higher moments of the data to be captured by the density model, and we choose these non-linear functions to be deep neural networks.


Approximate samples from the posterior distribution over the latent variables $p(\mathbf{z} \mid \mathbf{x})$ are used as features to train a classifier that predicts class labels $y$, such as a (transductive) SVM or multinomial regression. Using this approach, we can now perform classification in a lower dimensional space since we typically use latent variables whose dimensionality is much less than that of the observations. These low dimensional embeddings should now also be more easily separable since we make use of independent latent Gaussian posteriors whose parameters are formed by a sequence of non-linear transformations of the data. This simple approach results in improved performance for SVMs, and we demonstrate this in section 4.

Generative semi-supervised model (M2): We propose a probabilistic model that describes the data as being generated by a latent class variable $y$ in addition to a continuous latent variable $\mathbf{z}$. The data is explained by the generative process:

$$p(y) = \mathrm{Cat}(y \mid \pi); \quad p(\mathbf{z}) = \mathcal{N}(\mathbf{z} \mid \mathbf{0}, \mathbf{I}); \quad p_\theta(\mathbf{x} \mid y, \mathbf{z}) = f(\mathbf{x}; y, \mathbf{z}, \theta), \tag{2}$$

where $\mathrm{Cat}(y \mid \pi)$ is the multinomial distribution, the class labels $y$ are treated as latent variables if no class label is available and $\mathbf{z}$ are additional latent variables. These latent variables are marginally independent and allow us, in the case of digit generation for example, to separate the class specification from the writing style of the digit. As before, $f(\mathbf{x}; y, \mathbf{z}, \theta)$ is a suitable likelihood function, e.g., a Bernoulli or Gaussian distribution, parameterised by a non-linear transformation of the latent variables. In our experiments, we choose deep neural networks as this non-linear function. Since most labels $y$ are unobserved, we integrate over the class of any unlabelled data during the inference process, thus performing classification as inference. Predictions for any missing labels are obtained from the inferred posterior distribution $p_\theta(y \mid \mathbf{x})$. This model can also be seen as a hybrid continuous-discrete mixture model where the different mixture components share parameters.

Stacked generative semi-supervised model (M1+M2): We can combine these two approaches by first learning a new latent representation $\mathbf{z}_1$ using the generative model from M1, and subsequently learning a generative semi-supervised model M2, using embeddings from $\mathbf{z}_1$ instead of the raw data $\mathbf{x}$. The result is a deep generative model with two layers of stochastic variables: $p_\theta(\mathbf{x}, y, \mathbf{z}_1, \mathbf{z}_2) = p(y)\, p(\mathbf{z}_2)\, p_\theta(\mathbf{z}_1 \mid y, \mathbf{z}_2)\, p_\theta(\mathbf{x} \mid \mathbf{z}_1)$, where the priors $p(y)$ and $p(\mathbf{z}_2)$ equal those of $y$ and $\mathbf{z}$ above, and both $p_\theta(\mathbf{z}_1 \mid y, \mathbf{z}_2)$ and $p_\theta(\mathbf{x} \mid \mathbf{z}_1)$ are parameterised as deep neural networks.

3 Scalable Variational Inference

3.1 Lower Bound Objective

In all our models, computation of the exact posterior distribution is intractable due to the nonlinear, non-conjugate dependencies between the random variables. To allow for tractable and scalable inference and parameter learning, we exploit recent advances in variational inference (Kingma and Welling, 2014; Rezende et al., 2014). For all the models described, we introduce a fixed-form distribution $q_\phi(\mathbf{z} \mid \mathbf{x})$ with parameters $\phi$ that approximates the true posterior distribution $p(\mathbf{z} \mid \mathbf{x})$. We then follow the variational principle to derive a lower bound on the marginal likelihood of the model. This bound forms our objective function and ensures that our approximate posterior is as close as possible to the true posterior.

We construct the approximate posterior distribution $q_\phi(\cdot)$ as an inference or recognition model, which has become a popular approach for efficient variational inference (Dayan, 2000; Kingma and Welling, 2014; Rezende et al., 2014; Stuhlmüller et al., 2013). Using an inference network, we avoid the need to compute per data point variational parameters, but can instead compute a set of global variational parameters $\phi$. This allows us to amortise the cost of inference by generalising between the posterior estimates for all latent variables through the parameters of the inference network, and allows for fast inference at both training and testing time (unlike with VEM, in which we repeat the generalized E-step optimisation for every test data point). An inference network is introduced for all latent variables, and we parameterise them as deep neural networks whose outputs form the parameters of the distribution $q_\phi(\cdot)$. For the latent-feature discriminative model (M1), we use a Gaussian inference network $q_\phi(\mathbf{z} \mid \mathbf{x})$ for the latent variable $\mathbf{z}$. For the generative semi-supervised model (M2), we introduce an inference model for each of the latent variables $\mathbf{z}$ and $y$, which we assume has a factorised form $q_\phi(\mathbf{z}, y \mid \mathbf{x}) = q_\phi(\mathbf{z} \mid y, \mathbf{x})\, q_\phi(y \mid \mathbf{x})$, specified as Gaussian and multinomial distributions respectively.

$$\text{M1:} \quad q_\phi(\mathbf{z} \mid \mathbf{x}) = \mathcal{N}(\mathbf{z} \mid \mu_\phi(\mathbf{x}), \mathrm{diag}(\sigma^2_\phi(\mathbf{x}))), \tag{3}$$

$$\text{M2:} \quad q_\phi(\mathbf{z} \mid y, \mathbf{x}) = \mathcal{N}(\mathbf{z} \mid \mu_\phi(y, \mathbf{x}), \mathrm{diag}(\sigma^2_\phi(\mathbf{x}))); \quad q_\phi(y \mid \mathbf{x}) = \mathrm{Cat}(y \mid \pi_\phi(\mathbf{x})), \tag{4}$$

where $\sigma_\phi(\mathbf{x})$ is a vector of standard deviations, $\pi_\phi(\mathbf{x})$ is a probability vector, and the functions $\mu_\phi(\mathbf{x})$, $\sigma_\phi(\mathbf{x})$ and $\pi_\phi(\mathbf{x})$ are represented as MLPs.

3.1.1 Latent Feature Discriminative Model Objective

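For this model, the variational bound $\mathcal{J}(\mathbf{x})$ on the marginal likelihood for a single data point is:

$$\log p_\theta(\mathbf{x}) \geq \mathbb{E}_{q_\phi(\mathbf{z} \mid \mathbf{x})}\left[\log p_\theta(\mathbf{x} \mid \mathbf{z})\right] - \mathrm{KL}\left[q_\phi(\mathbf{z} \mid \mathbf{x}) \,\|\, p_\theta(\mathbf{z})\right] = -\mathcal{J}(\mathbf{x}). \tag{5}$$

The inference network $q_\phi(\mathbf{z} \mid \mathbf{x})$ (3) is used during training of the model using both the labelled and unlabelled data sets. This approximate posterior is then used as a feature extractor for the labelled data set, and the features are used for training the classifier.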
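As an illustration of equation (5), the following is a minimal sketch of a single-sample estimate of $\mathcal{J}(\mathbf{x})$, assuming a Bernoulli likelihood and a diagonal Gaussian posterior. The KL term uses the standard closed form for a diagonal Gaussian against a standard normal prior. The function names and the `toy_decode` stand-in for the decoder MLP are hypothetical and are not taken from the paper's code.

```python
import math
import random


def kl_diag_gaussian_to_std_normal(mu, sigma):
    """Closed-form KL[N(mu, diag(sigma^2)) || N(0, I)]."""
    return 0.5 * sum(m * m + s * s - 1.0 - 2.0 * math.log(s)
                     for m, s in zip(mu, sigma))


def bernoulli_log_lik(x, probs):
    """log p(x|z) for binary x under independent Bernoulli probabilities."""
    return sum(math.log(p) if xi == 1 else math.log(1.0 - p)
               for xi, p in zip(x, probs))


def neg_bound_m1(x, mu, sigma, decode, rng):
    """Single-sample estimate of J(x) = -(E_q[log p(x|z)] - KL[q(z|x) || p(z)])."""
    z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]
    return -(bernoulli_log_lik(x, decode(z))
             - kl_diag_gaussian_to_std_normal(mu, sigma))


def toy_decode(z):
    """Hypothetical stand-in for the decoder MLP: two pixel probabilities."""
    return [1.0 / (1.0 + math.exp(-z[0])), 1.0 / (1.0 + math.exp(z[0]))]


if __name__ == "__main__":
    rng = random.Random(0)
    print(neg_bound_m1([1, 0], [0.0], [1.0], toy_decode, rng))
```

In a real model, `mu` and `sigma` would be outputs of the inference network for the given `x`, and the estimate would be averaged over a minibatch.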

3.1.2 Generative Semi-supervised Model Objective

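For this model, we have two cases to consider. In the first case, the label corresponding to a data point is observed and the variational bound is a simple extension of equation (5):

$$\log p_\theta(\mathbf{x}, y) \geq \mathbb{E}_{q_\phi(\mathbf{z} \mid \mathbf{x}, y)}\left[\log p_\theta(\mathbf{x} \mid y, \mathbf{z}) + \log p_\theta(y) + \log p(\mathbf{z}) - \log q_\phi(\mathbf{z} \mid \mathbf{x}, y)\right] = -\mathcal{L}(\mathbf{x}, y). \tag{6}$$

For the case where the label is missing, it is treated as a latent variable over which we perform posterior inference and the resulting bound for handling data points with an unobserved label $y$ is:

$$\log p_\theta(\mathbf{x}) \geq \mathbb{E}_{q_\phi(y, \mathbf{z} \mid \mathbf{x})}\left[\log p_\theta(\mathbf{x} \mid y, \mathbf{z}) + \log p_\theta(y) + \log p(\mathbf{z}) - \log q_\phi(y, \mathbf{z} \mid \mathbf{x})\right] = \sum_y q_\phi(y \mid \mathbf{x})\left(-\mathcal{L}(\mathbf{x}, y)\right) + \mathcal{H}(q_\phi(y \mid \mathbf{x})) = -\mathcal{U}(\mathbf{x}). \tag{7}$$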
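Equation (7) can be read as a recipe: evaluate the labelled bound once per class, weight by the classifier probabilities, and subtract the entropy of those probabilities. A minimal sketch follows, assuming the per-class values $\mathcal{L}(\mathbf{x}, y)$ have already been computed. The function name and argument layout are illustrative only.

```python
import math


def unlabelled_bound(per_class_L, q_y):
    """U(x) = sum_y q(y|x) * L(x, y) - H(q(y|x)), rearranged from equation (7)."""
    expected_L = sum(q * L for q, L in zip(q_y, per_class_L))
    # Entropy of the categorical distribution; terms with q = 0 contribute 0.
    entropy = -sum(q * math.log(q) for q in q_y if q > 0.0)
    return expected_L - entropy


if __name__ == "__main__":
    # A confident classifier: U(x) reduces to L(x, y) of the predicted class.
    print(unlabelled_bound([2.0, 5.0], [1.0, 0.0]))
    # A maximally uncertain classifier: the entropy term lowers U(x).
    print(unlabelled_bound([2.0, 2.0], [0.5, 0.5]))
```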

The bound on the marginal likelihood for the entire dataset is now:

$$\mathcal{J} = \sum_{(\mathbf{x}, y) \sim p_l} \mathcal{L}(\mathbf{x}, y) + \sum_{\mathbf{x} \sim p_u} \mathcal{U}(\mathbf{x}). \tag{8}$$

The distribution $q_\phi(y \mid \mathbf{x})$ (4) for the missing labels has the form of a discriminative classifier, and we can use this knowledge to construct the best classifier possible as our inference model. This distribution is also used at test time for predictions of any unseen data.

In the objective function (8), the label predictive distribution $q_\phi(y \mid \mathbf{x})$ contributes only to the second term relating to the unlabelled data, which is an undesirable property if we wish to use this distribution as a classifier. Ideally, all model and variational parameters should learn in all cases. To remedy this, we add a classification loss to (8), such that the distribution $q_\phi(y \mid \mathbf{x})$ also learns from labelled data. The extended objective function is:

$$\mathcal{J}^\alpha = \mathcal{J} + \alpha \cdot \mathbb{E}_{p_l(\mathbf{x}, y)}\left[-\log q_\phi(y \mid \mathbf{x})\right], \tag{9}$$

where the hyper-parameter $\alpha$ controls the relative weight between generative and purely discriminative learning. We use $\alpha = 0.1 \cdot N$ in all experiments. While we have obtained this objective function by motivating the need for all model components to learn at all times, the objective (9) can also be derived directly using the variational principle by instead performing inference over the parameters $\pi$ of the categorical distribution, using a symmetric Dirichlet prior over these parameters.
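The way the terms of equations (8) and (9) combine can be sketched as follows. This assumes that the per-point values $\mathcal{L}(\mathbf{x}, y)$, $\mathcal{U}(\mathbf{x})$ and $\log q_\phi(y \mid \mathbf{x})$ are already available, and that the expectation under the empirical distribution $p_l$ is taken as a mean over the labelled points. The function name and the numbers in the usage example are hypothetical.

```python
import math


def extended_objective(labelled_L, unlabelled_U, labelled_log_q, alpha):
    """J^alpha = sum L(x, y) + sum U(x) + alpha * mean(-log q(y|x)).

    labelled_L     : L(x, y) for each labelled point
    unlabelled_U   : U(x) for each unlabelled point
    labelled_log_q : log q(y|x) of the true label for each labelled point
    alpha          : weight of the discriminative term
    """
    j = sum(labelled_L) + sum(unlabelled_U)
    # Assumption: the empirical expectation is a mean over the labelled set.
    classification_loss = -sum(labelled_log_q) / len(labelled_log_q)
    return j + alpha * classification_loss


if __name__ == "__main__":
    n_points = 3  # two labelled points plus one unlabelled point
    value = extended_objective(
        labelled_L=[1.0, 2.0],
        unlabelled_U=[3.0],
        labelled_log_q=[math.log(0.5), math.log(0.5)],
        alpha=0.1 * n_points,
    )
    print(value)
```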

3.2 Optimisation

The bounds in equations (5) and (9) provide a unified objective function for optimisation of both the parameters $\theta$ and $\phi$ of the generative and inference models, respectively. This optimisation can be done jointly, without resort to the variational EM algorithm, by using deterministic reparameterisations of the expectations in the objective function, combined with Monte Carlo approximation, referred to in previous work as stochastic gradient variational Bayes (SGVB) (Kingma and Welling, 2014) or as stochastic backpropagation (Rezende et al., 2014). We describe the core strategy for the latent-feature discriminative model M1, since the same computations are used for the generative semi-supervised model.

When the prior $p(\mathbf{z})$ is a spherical Gaussian distribution $p(\mathbf{z}) = \mathcal{N}(\mathbf{z} \mid \mathbf{0}, \mathbf{I})$ and the variational distribution $q_\phi(\mathbf{z} \mid \mathbf{x})$ is also a Gaussian distribution as in (3), the KL term in equation (5) can be computed

Algorithm 1 Learning in model M1

    while generativeTraining() do
        D ← getRandomMiniBatch()
        z_i ∼ q_φ(z_i | x_i)  ∀ x_i ∈ D
        J ← Σ_i J(x_i)
        (g_θ, g_φ) ← (∂J/∂θ, ∂J/∂φ)
        (θ, φ) ← (θ, φ) + Γ(g_θ, g_φ)
    end while
    while discriminativeTraining() do
        D ← getLabeledRandomMiniBatch()
        z_i ∼ q_φ(z_i | x_i)  ∀ {x_i, y_i} ∈ D
        trainClassifier({z_i, y_i})
    end while

Algorithm 2 Learning in model M2

    while training() do
        D ← getRandomMiniBatch()
        y_i ∼ q_φ(y_i | x_i)  ∀ {x_i, y_i} ∉ O
        z_i ∼ q_φ(z_i | y_i, x_i)
        J^α ← eq. (9)
        (g_θ, g_φ) ← (∂J^α/∂θ, ∂J^α/∂φ)
        (θ, φ) ← (θ, φ) + Γ(g_θ, g_φ)
    end while

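analytically and the log-likelihood term can be rewritten, using the location-scale transformation for the Gaussian distribution, as:

$$\mathbb{E}_{q_\phi(\mathbf{z} \mid \mathbf{x})}\left[\log p_\theta(\mathbf{x} \mid \mathbf{z})\right] = \mathbb{E}_{\mathcal{N}(\epsilon \mid \mathbf{0}, \mathbf{I})}\left[\log p_\theta(\mathbf{x} \mid \mu_\phi(\mathbf{x}) + \sigma_\phi(\mathbf{x}) \odot \epsilon)\right], \tag{10}$$

where $\odot$ indicates the element-wise product. While the expectation (10) still cannot be solved analytically, its gradients with respect to the generative parameters $\theta$ and variational parameters $\phi$ can be efficiently computed as expectations of simple gradients:

$$\nabla_{\{\theta, \phi\}} \mathbb{E}_{q_\phi(\mathbf{z} \mid \mathbf{x})}\left[\log p_\theta(\mathbf{x} \mid \mathbf{z})\right] = \mathbb{E}_{\mathcal{N}(\epsilon \mid \mathbf{0}, \mathbf{I})}\left[\nabla_{\{\theta, \phi\}} \log p_\theta(\mathbf{x} \mid \mu_\phi(\mathbf{x}) + \sigma_\phi(\mathbf{x}) \odot \epsilon)\right]. \tag{11}$$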
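The identity in equation (11) can be checked numerically on a one-dimensional toy problem. The sketch below estimates the gradients of $\mathbb{E}_{z \sim \mathcal{N}(\mu, \sigma^2)}[f(z)]$ with respect to $\mu$ and $\sigma$ by sampling $\epsilon$, setting $z = \mu + \sigma\epsilon$, and applying the chain rule. The test function $f(z) = z^2$ is an assumption chosen because its exact gradients ($2\mu$ and $2\sigma$) are known; it stands in for $\log p_\theta(\mathbf{x} \mid \mathbf{z})$.

```python
import random


def reparam_grad_estimate(mu, sigma, df, n_samples, rng):
    """Monte Carlo gradients of E_{z ~ N(mu, sigma^2)}[f(z)] w.r.t. mu and sigma.

    Uses the location-scale form z = mu + sigma * eps with eps ~ N(0, 1),
    so that dz/dmu = 1 and dz/dsigma = eps. `df` is the derivative of f.
    """
    g_mu = 0.0
    g_sigma = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps
        g_mu += df(z)            # chain rule with dz/dmu = 1
        g_sigma += df(z) * eps   # chain rule with dz/dsigma = eps
    return g_mu / n_samples, g_sigma / n_samples


if __name__ == "__main__":
    rng = random.Random(0)
    # f(z) = z^2, so E[f] = mu^2 + sigma^2 with exact gradients (2*mu, 2*sigma).
    g_mu, g_sigma = reparam_grad_estimate(1.5, 0.5, lambda z: 2.0 * z, 50000, rng)
    print(g_mu, g_sigma)  # close to 3.0 and 1.0, up to Monte Carlo noise
```

In the models described here, the same construction is applied element-wise to a vector $\mathbf{z}$, and the derivatives are propagated further into the network parameters.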

The gradients of the loss (9) for model M2 can be computed by a direct application of the chain rule and by noting that the conditional bound $\mathcal{L}(\mathbf{x}_n, y)$ contains the same type of terms as the loss (9). The gradients of the latter can then be efficiently estimated using (11).

During optimization we use the estimated gradients in conjunction with standard stochastic gradient-based optimization methods such as SGD, RMSprop or AdaGrad (Duchi et al., 2010). This results in parameter updates of the form: $(\theta^{t+1}, \phi^{t+1}) \leftarrow (\theta^t, \phi^t) + \Gamma^t (\mathbf{g}^t_\theta, \mathbf{g}^t_\phi)$, where $\Gamma$ is a diagonal preconditioning matrix that adaptively scales the gradients for faster minimization. The training procedures for models M1 and M2 are summarised in algorithms 1 and 2, respectively. Our experimental results were obtained using AdaGrad.
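For concreteness, one AdaGrad step with a diagonal preconditioner might look like the sketch below. It is written as a descent step on a loss, so the sign of the update is made explicit here rather than absorbed into $\Gamma$; that choice, the learning rate, and the `eps` stabiliser are assumptions and not values reported in the paper.

```python
import math


def adagrad_step(params, grads, accum, lr=0.01, eps=1e-8):
    """One AdaGrad update (descent direction) on flat lists of floats.

    accum holds the running sum of squared gradients per parameter, which
    defines the diagonal preconditioner lr / (sqrt(accum) + eps).
    """
    for i, g in enumerate(grads):
        accum[i] += g * g
        params[i] -= lr * g / (math.sqrt(accum[i]) + eps)
    return params, accum


if __name__ == "__main__":
    params, accum = adagrad_step([1.0], [2.0], [0.0], lr=0.1)
    print(params, accum)  # first step moves by roughly lr regardless of scale
```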

3.3 Computational Complexity

The overall algorithmic complexity of a single joint update of the parameters $(\theta, \phi)$ for M1 using the estimator (11) is $C_{M1} = M S C_{MLP}$, where $M$ is the minibatch size used, $S$ is the number of samples of the random variate $\epsilon$, and $C_{MLP}$ is the cost of an evaluation of the MLPs in the conditional distributions $p_\theta(\mathbf{x} \mid \mathbf{z})$ and $q_\phi(\mathbf{z} \mid \mathbf{x})$. The cost $C_{MLP}$ is of the form $O(KD^2)$, where $K$ is the total number of layers and $D$ is the average dimension of the layers of the MLPs in the model. Training M1 also requires training a supervised classifier, whose algorithmic complexity, if it is a neural net, will be of the form $C_{MLP}$.

The algorithmic complexity for M2 is of the form $C_{M2} = L C_{M1}$, where $L$ is the number of labels and $C_{M1}$ is the cost of evaluating the gradients of each conditional bound $\mathcal{J}_y(\mathbf{x})$, which is the same as for M1. The stacked generative semi-supervised model has an algorithmic complexity of the form $C_{M1} + C_{M2}$, but with the advantage that the cost $C_{M2}$ is calculated in a low-dimensional space (formed by the latent variables of the model M1 that provides the embeddings).
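The cost formulas above can be turned into a small back-of-the-envelope calculation. The sketch below drops the constant factors hidden in the $O(\cdot)$ notation, and the minibatch size of 100 in the usage example is a hypothetical value, not one reported in the paper.

```python
def cost_mlp(num_layers, avg_width):
    """C_MLP ~ K * D^2 (constants dropped)."""
    return num_layers * avg_width ** 2


def cost_m1(minibatch_size, num_samples, c_mlp):
    """C_M1 = M * S * C_MLP."""
    return minibatch_size * num_samples * c_mlp


def cost_m2(num_classes, c_m1):
    """C_M2 = L * C_M1: one bound evaluation per class."""
    return num_classes * c_m1


if __name__ == "__main__":
    c_mlp = cost_mlp(num_layers=2, avg_width=600)
    c_m1 = cost_m1(minibatch_size=100, num_samples=1, c_mlp=c_mlp)
    c_m2 = cost_m2(num_classes=10, c_m1=c_m1)
    print(c_mlp, c_m1, c_m2)
```

The factor of `num_classes` between `c_m1` and `c_m2` is the linear scaling in the number of classes that the cost model implies for M2.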

These complexities make this approach extremely appealing, since they are no more expensive than alternative approaches based on auto-encoder or neural models, which have the lowest computational complexity amongst existing competitive approaches. In addition, our models are fully probabilistic, allowing for a wide range of inferential queries, which is not possible with many alternative approaches for semi-supervised learning.

Table 1: Benchmark results of semi-supervised classification on MNIST with few labels.

| N | NN | CNN | TSVM | CAE | MTC | AtlasRBF | M1+TSVM | M2 | M1+M2 |
|---|---|---|---|---|---|---|---|---|---|
| 100 | 25.81 | 22.98 | 16.81 | 13.47 | 12.03 | 8.10 (± 0.95) | 11.82 (± 0.25) | 11.97 (± 1.71) | 3.33 (± 0.14) |
| 600 | 11.44 | 7.68 | 6.16 | 6.3 | 5.13 | – | 5.72 (± 0.049) | 4.94 (± 0.13) | 2.59 (± 0.05) |
| 1000 | 10.7 | 6.45 | 5.38 | 4.77 | 3.64 | 3.68 (± 0.12) | 4.24 (± 0.07) | 3.60 (± 0.56) | 2.40 (± 0.02) |
| 3000 | 6.04 | 3.35 | 3.45 | 3.22 | 2.57 | – | 3.49 (± 0.04) | 3.92 (± 0.63) | 2.18 (± 0.04) |

4 Experimental Results

Open source code, with which the most important results and figures can be reproduced, is available at http://github.com/dpkingma/nips14-ssl. For the latest experimental results, please see http://arxiv.org/abs/1406.5298.

4.1 Benchmark Classification

We test performance on the standard MNIST digit classification benchmark. The data set for semi-supervised learning is created by splitting the 50,000 training points between a labelled and unlabelled set, and varying the size of the labelled set from 100 to 3000. We ensure that all classes are balanced when doing this, i.e. each class has the same number of labelled points. We create a number of data sets using randomised sampling to obtain confidence bounds for the mean performance under repeated draws of data sets.
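The class-balanced split described above can be sketched as follows. This is an illustrative helper, not the procedure from the released code: the function name, the use of Python's `random.Random`, and the requirement that the labelled-set size be divisible by the number of classes are all assumptions.

```python
import random
from collections import defaultdict


def balanced_labelled_split(labels, n_labelled, rng):
    """Split indices into (labelled, unlabelled) with equal labelled counts per class."""
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    per_class, remainder = divmod(n_labelled, len(by_class))
    if remainder:
        raise ValueError("n_labelled must be divisible by the number of classes")
    labelled = []
    for y in sorted(by_class):
        # Sample without replacement within each class.
        labelled.extend(rng.sample(by_class[y], per_class))
    chosen = set(labelled)
    unlabelled = [i for i in range(len(labels)) if i not in chosen]
    return labelled, unlabelled


if __name__ == "__main__":
    toy_labels = [i % 10 for i in range(1000)]  # hypothetical stand-in for MNIST labels
    lab, unlab = balanced_labelled_split(toy_labels, 100, random.Random(0))
    print(len(lab), len(unlab))
```

Repeating the call with different seeds gives the repeated draws over which a mean and a confidence bound can be computed.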

For model M1 we used a 50-dimensional latent variable $\mathbf{z}$. The MLPs that form part of the generative and inference models were constructed with two hidden layers, each with 600 hidden units, using softplus $\log(1 + e^x)$ activation functions. On top, a transductive SVM (TSVM) was learned on values of $\mathbf{z}$ inferred with $q_\phi(\mathbf{z} \mid \mathbf{x})$. For model M2 we also used 50-dimensional $\mathbf{z}$. In each experiment, the MLPs were constructed with one hidden layer, each with 500 hidden units and softplus activation functions. In the case of SVHN and NORB, we found it helpful to pre-process the data with PCA. This makes the model one level deeper, and still optimizes a lower bound on the likelihood of the unprocessed data.
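The softplus activation $\log(1 + e^x)$ overflows for large $x$ if evaluated literally. A numerically stable form, shown here as a small sketch rather than the implementation used in the experiments, rewrites it as $\max(x, 0) + \log(1 + e^{-|x|})$.

```python
import math


def softplus(x):
    """Numerically stable log(1 + e^x) = max(x, 0) + log1p(exp(-|x|))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


if __name__ == "__main__":
    print(softplus(0.0))     # log(2)
    print(softplus(1000.0))  # no overflow for large inputs
```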

Table 1 shows classification results. We compare to a broad range of existing solutions in semi-supervised learning, in particular to classification using nearest neighbours (NN), support vector machines on the labelled set (SVM), the transductive SVM (TSVM), and contractive auto-encoders (CAE). Some of the best results currently are obtained by the manifold tangent classifier (MTC) (Rifai et al., 2011) and the AtlasRBF method (Pitelis et al., 2014). Unlike the other models in this comparison, our models are fully probabilistic but have a cost in the same order as these alternatives.

Results: The latent-feature discriminative model (M1) performs better than other models based on simple embeddings of the data, demonstrating the effectiveness of the latent space in providing robust features that allow for easier classification. By combining these features with a classification mechanism directly in the same model, as in the conditional generative model (M2), we are able to get similar results without a separate TSVM classifier.

However, by far the best results were obtained using the stack of models M1 and M2. This combined model provides accurate test-set predictions across all conditions, and easily outperforms the previously best methods. We also tested this deep generative model for supervised learning with all available labels, and obtain a test-set performance of 0.96%, which is among the best published results for this permutation-invariant MNIST classification task.

4.2 Conditional Generation

The conditional generative model can be used to explore the underlying structure of the data, which we demonstrate through two forms of analogical reasoning. Firstly, we demonstrate style and content separation by fixing the class label $y$, and then varying the latent variables $\mathbf{z}$ over a range of values. Figure 1 shows this for three MNIST classes, using a trained model with a two-dimensional latent variable that is varied over a range from -5 to 5. In all cases, we see that nearby regions of latent space correspond to similar writing styles, independent of the class; the left region represents upright writing styles, while the right side represents slanted styles.

As a second approach, we use a test image and pass it through the inference network to infer a value of the latent variables corresponding to that image. We then fix the latent variables $\mathbf{z}$ to this

(a) Handwriting styles for MNIST obtained by fixing the class label and varying the 2D latent variable z

(b) MNIST analogies (c) SVHN analogies

Figure 1: (a) Visualisation of handwriting styles learned by the model with 2D z-space. (b,c) Analogical reasoning with generative semi-supervised models using a high-dimensional z-space. The leftmost columns show images from the test set. The other columns show analogical fantasies of x by the generative model, where the latent variable z of each row is set to the value inferred from the test-set image on the left by the inference network. Each column corresponds to a class label y.

Table 2: Semi-supervised classification on the SVHN dataset with 1000 labels.

| KNN | TSVM | M1+KNN | M1+TSVM | M1+M2 |
|---|---|---|---|---|
| 77.93 (± 0.08) | 66.55 (± 0.10) | 65.63 (± 0.15) | 54.33 (± 0.11) | 36.02 (± 0.10) |

Table 3: Semi-supervised classification on the NORB dataset with 1000 labels.

| KNN | TSVM | M1+KNN | M1+TSVM |
|---|---|---|---|
| 78.71 (± 0.02) | 26.00 (± 0.06) | 65.39 (± 0.09) | 18.79 (± 0.05) |

value, vary the class label $y$, and simulate images from the generative model corresponding to that combination of $\mathbf{z}$ and $y$. This again demonstrates the disentanglement of style from class. Figure 1 shows these analogical fantasies for the MNIST and SVHN datasets (Netzer et al., 2011). The SVHN data set is a far more complex data set than MNIST, but the model is able to fix the style of the house number and vary the digit that appears in that style well. These generations represent the best current performance in simulation from generative models on these data sets.

The model used in this way also provides an alternative model to the stochastic feed-forward networks (SFNN) described by Tang and Salakhutdinov (2013). The performance of our model significantly improves on SFNN, since instead of an inefficient Monte Carlo EM algorithm relying on importance sampling, we are able to perform efficient joint inference that is easy to scale.

4.3 Image Classification

We demonstrate the performance of image classification on the SVHN and NORB image data sets. Since no comparative results in the semi-supervised setting exist, we perform nearest-neighbour and TSVM classification with RBF kernels and compare performance on features generated by our latent-feature discriminative model to the original features. The results are presented in tables 2 and 3, and we again demonstrate the effectiveness of our approach for semi-supervised classification.


4.4 Optimization details

The parameters were initialized by sampling randomly from $\mathcal{N}(0, 0.001^2 \mathbf{I})$, except for the bias parameters which were initialized as 0. The objectives were optimized using minibatch gradient ascent until convergence, using a variant of RMSProp with momentum and initialization bias correction, a constant learning rate of 0.0003, first moment decay (momentum) of 0.1, and second moment decay of 0.001. For MNIST experiments, minibatches for training were generated by treating normalised pixel intensities of the images as Bernoulli probabilities and sampling binary images from this distribution. In the M2 model, a weight decay was used corresponding to a prior of $(\theta, \phi) \sim \mathcal{N}(0, \mathbf{I})$.
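The binarisation step described above amounts to drawing each pixel from a Bernoulli distribution whose probability is the normalised intensity. A minimal sketch, assuming intensities already scaled to $[0, 1]$ and stored as a flat list (the function name is hypothetical):

```python
import random


def sample_binary_image(intensities, rng):
    """Sample each pixel as Bernoulli(p), with p the normalised intensity in [0, 1]."""
    return [1 if rng.random() < p else 0 for p in intensities]


if __name__ == "__main__":
    rng = random.Random(0)
    image = [0.0, 1.0, 0.5, 0.25]  # toy 4-pixel image with normalised intensities
    print(sample_binary_image(image, rng))
```

Because a fresh sample is drawn each time a minibatch is formed, the model sees a different binary version of the same image across epochs.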

5 Discussion and Conclusion

The approximate inference methods introduced here can be easily extended to the model's parameters, harnessing the full power of variational learning. Such an extension also provides a principled ground for performing model selection. Efficient model selection is particularly important when the amount of available data is not large, such as in semi-supervised learning.

For image classification tasks, one area of interest is to combine such methods with convolutional neural networks, which form the gold standard for current supervised classification methods. Since all the components of our model are parametrised by neural networks, we can readily exploit convolutional or more general locally-connected architectures, and this forms a promising avenue for future exploration.

A limitation of the models we have presented is that their computational cost scales linearly in the number of classes in the data sets. Having to re-evaluate the generative likelihood for each class during training is an expensive operation. The number of evaluations could potentially be reduced by using a truncation of the posterior mass. For instance, we could combine our method with the truncation algorithm suggested by Pal et al. (2005), or use mechanisms such as error-correcting output codes (Dietterich and Bakiri, 1995). The extension of our model to multi-label classification problems, which is essential for image tagging, is also possible, but requires similar approximations to reduce the number of likelihood evaluations per class.
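The linear cost comes from marginalising over the class label for every unlabelled point, and the sketch below makes the evaluation count explicit. It is an illustration only: `labelled_bound` is a hypothetical callable standing in for the per-class likelihood bound, and the `top_k` option is a simple truncation with renormalisation chosen for this example, not the algorithm of Pal et al. (2005).

```python
import numpy as np

def unlabelled_bound(q_y, labelled_bound, top_k=None):
    """Bound for one unlabelled x: E_q[per-class bound] + entropy of q.

    q_y:            array of shape (K,), classifier probabilities q(y|x).
    labelled_bound: hypothetical callable y -> per-class bound; each call
                    is one expensive generative likelihood evaluation.
    top_k:          if set, keep only the top_k most probable classes and
                    renormalise q over them (illustrative truncation).
    Returns (bound, number_of_likelihood_evaluations).
    """
    q_y = np.asarray(q_y, dtype=np.float64)
    classes = np.arange(q_y.size)
    q = q_y
    if top_k is not None and top_k < q_y.size:
        classes = np.argsort(q_y)[::-1][:top_k]
        q = q_y[classes] / q_y[classes].sum()
    expected = sum(q_i * labelled_bound(int(y)) for q_i, y in zip(q, classes))
    nonzero = q > 0
    entropy = -np.sum(q[nonzero] * np.log(q[nonzero]))
    return expected + entropy, len(classes)

# Toy per-class bounds for K = 3 classes (made-up numbers for illustration).
toy_bounds = [-1.0, -2.0, -3.0]
q = np.array([0.7, 0.2, 0.1])
full, n_full = unlabelled_bound(q, lambda y: toy_bounds[y])
trunc, n_trunc = unlabelled_bound(q, lambda y: toy_bounds[y], top_k=2)
print(n_full, n_trunc)  # 3 2
```

Because the renormalised q is still a distribution over y, the truncated value remains a valid bound of the same form; whether it is accurate enough for training is a separate question that this sketch does not address.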

We have developed new models for semi-supervised learning that allow us to improve the quality of prediction by exploiting information in the data density using generative models. We have developed an efficient variational optimisation algorithm for approximate Bayesian inference in these models and demonstrated that they are amongst the most competitive models currently available for semi-supervised learning. We hope that these results stimulate the development of even more powerful semi-supervised classification methods based on generative models, of which there remains much scope.

Acknowledgements. We are grateful for feedback from the reviewers. We would also like to thank the SURFFoundation for the use of the Dutch national e-infrastructure for a significant part of the experiments.


References

Adams, R. P. and Ghahramani, Z. (2009). Archipelago: nonparametric Bayesian semi-supervised learning. In Proceedings of the International Conference on Machine Learning (ICML).

Blum, A., Lafferty, J., Rwebangira, M. R., and Reddy, R. (2004). Semi-supervised learning using randomized mincuts. In Proceedings of the International Conference on Machine Learning (ICML).

Dayan, P. (2000). Helmholtz machines and wake-sleep learning. Handbook of Brain Theory and Neural Networks. MIT Press, Cambridge, MA, 44(0).

Dietterich, T. G. and Bakiri, G. (1995). Solving multiclass learning problems via error-correcting output codes. arXiv preprint cs/9501101.

Duchi, J., Hazan, E., and Singer, Y. (2010). Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12:2121–2159.

Fergus, R., Weiss, Y., and Torralba, A. (2009). Semi-supervised learning in gigantic image collections. In Advances in Neural Information Processing Systems (NIPS).

Joachims, T. (1999). Transductive inference for text classification using support vector machines. In Proceedings of the International Conference on Machine Learning (ICML), volume 99, pages 200–209.

Kemp, C., Griffiths, T. L., Stromsten, S., and Tenenbaum, J. B. (2003). Semi-supervised learning with trees. In Advances in Neural Information Processing Systems (NIPS).

Kingma, D. P. and Welling, M. (2014). Auto-encoding variational Bayes. In Proceedings of the International Conference on Learning Representations (ICLR).

Li, P., Ying, Y., and Campbell, C. (2009). A variational approach to semi-supervised clustering. In Proceedings of the European Symposium on Artificial Neural Networks (ESANN), pages 11–16.

Liang, P. (2005). Semi-supervised learning for natural language. PhD thesis, Massachusetts Institute of Technology.

Liu, Y. and Kirchhoff, K. (2013). Graph-based semi-supervised learning for phone and segment classification. In Proceedings of Interspeech.

Netzer, Y., Wang, T., Coates, A., Bissacco, A., Wu, B., and Ng, A. Y. (2011). Reading digits in natural images with unsupervised feature learning. In NIPS workshop on deep learning and unsupervised feature learning.

Pal, C., Sutton, C., and McCallum, A. (2005). Fast inference and learning with sparse belief propagation. In Advances in Neural Information Processing Systems (NIPS).

Pitelis, N., Russell, C., and Agapito, L. (2014). Semi-supervised learning using an unsupervised atlas. In Proceedings of the European Conference on Machine Learning (ECML), volume LNCS 8725, pages 565–580.

Ranzato, M. and Szummer, M. (2008). Semi-supervised learning of compact document representations with deep networks. In Proceedings of the 25th International Conference on Machine Learning (ICML), pages 792–799.

Rezende, D. J., Mohamed, S., and Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the International Conference on Machine Learning (ICML), volume 32 of JMLR W&CP.

Rifai, S., Dauphin, Y., Vincent, P., Bengio, Y., and Muller, X. (2011). The manifold tangent classifier. In Advances in Neural Information Processing Systems (NIPS), pages 2294–2302.

Rosenberg, C., Hebert, M., and Schneiderman, H. (2005). Semi-supervised self-training of object detection models. In Proceedings of the Seventh IEEE Workshops on Application of Computer Vision (WACV/MOTION'05).

Shi, M. and Zhang, B. (2011). Semi-supervised learning improves gene expression-based prediction of cancer recurrence. Bioinformatics, 27(21):3017–3023.

Stuhlmuller, A., Taylor, J., and Goodman, N. (2013). Learning stochastic inverses. In Advances in Neural Information Processing Systems (NIPS), pages 3048–3056.

Tang, Y. and Salakhutdinov, R. (2013). Learning stochastic feedforward neural networks. In Advances in Neural Information Processing Systems (NIPS), pages 530–538.

Wang, Y., Haffari, G., Wang, S., and Mori, G. (2009). A rate distortion approach for semi-supervised conditional random fields. In Advances in Neural Information Processing Systems (NIPS), pages 2008–2016.

Weston, J., Ratle, F., Mobahi, H., and Collobert, R. (2012). Deep learning via semi-supervised embedding. In Neural Networks: Tricks of the Trade, pages 639–655. Springer.

Zhu, X. (2006). Semi-supervised learning literature survey. Technical report, Computer Science, University of Wisconsin-Madison.

Zhu, X., Ghahramani, Z., Lafferty, J., et al. (2003). Semi-supervised learning using Gaussian fields and harmonic functions. In Proceedings of the International Conference on Machine Learning (ICML), volume 3, pages 912–919.
