Structured Inference Networks for Nonlinear State Space Models


Authors: Rahul G. Krishnan, Uri Shalit, David Sontag

Courant Institute of Mathematical Sciences, New York University
{rahul, shalit, dsontag}@cs.nyu.edu

Abstract
Gaussian state space models have been used for decades as generative models of sequential data. They admit an intuitive probabilistic interpretation, have a simple functional form, and enjoy widespread adoption. We introduce a unified algorithm to efficiently learn a broad class of linear and non-linear state space models, including variants where the emission and transition distributions are modeled by deep neural networks. Our learning algorithm simultaneously learns a compiled inference network and the generative model, leveraging a structured variational approximation parameterized by recurrent neural networks to mimic the posterior distribution. We apply the learning algorithm to both synthetic and real-world datasets, demonstrating its scalability and versatility. We find that using the structured approximation to the posterior results in models with significantly higher held-out likelihood.

1 Introduction
Models of sequence data such as hidden Markov models (HMMs) and recurrent neural networks (RNNs) are widely used in machine translation, speech recognition, and computational biology. Linear and non-linear Gaussian state space models (GSSMs, Fig. 1) are used in applications including robotic planning and missile tracking. However, despite huge progress over the last decade, efficient learning of non-linear models from complex high dimensional time-series remains a major challenge. Our paper proposes a unified learning algorithm for a broad class of GSSMs, and we introduce an inference procedure that scales easily to high dimensional data, compiling approximate (and where feasible, exact) inference into the parameters of a neural network.
In engineering and control, the parametric form of the GSSM is often known, typically with a few specific parameters that need to be fit to data. The most commonly used approaches for these types of learning and inference problems are often computationally demanding, e.g. the dual extended Kalman filter (Wan and Nelson 1996), expectation maximization (Briegel and Tresp 1999; Ghahramani and Roweis 1999) or particle filters (Schön, Wills, and Ninness 2011). Our compiled inference algorithm can easily deal with high dimensions both in the observed and the latent spaces, without compromising the quality of inference and learning.

Copyright © 2017, Association for the Advancement of Artificial Intelligence (www.aaai.org). All rights reserved.

When the parametric form of the model is unknown, we propose learning deep Markov models (DMMs), a class of generative models where classic linear emission and transition distributions are replaced with complex multi-layer perceptrons (MLPs). These are GSSMs that retain the Markovian structure of HMMs, but leverage the representational power of deep neural networks to model complex high dimensional data. If one augments a DMM such as the one presented in Fig. 1 with edges from the observations x_t to the latent states of the following time step z_{t+1}, then the DMM can be seen to be similar to, though more restrictive than, stochastic RNNs (Bayer and Osendorfer 2014) and variational RNNs (Chung et al. 2015).

Our learning algorithm performs stochastic gradient ascent on a variational lower bound of the likelihood. Instead of introducing variational parameters for each data point, we compile the inference procedure at the same time as learning the generative model. This idea was originally used in the wake-sleep algorithm for unsupervised learning (Hinton et al. 1995), and has since led to state-of-the-art results for unsupervised learning of deep generative models (Kingma and Welling 2014; Mnih and Gregor 2014; Rezende, Mohamed, and Wierstra 2014).

Specifically, we introduce a new family of structured inference networks, parameterized by recurrent neural networks, and evaluate their effectiveness in three scenarios: (1) when the generative model is known and fixed, (2) in parameter estimation when the functional form of the model is known, and (3) for learning deep Markov models. By looking at the structure of the true posterior, we show both theoretically and empirically that inference for a latent state should be performed using information from its future, as opposed to recent work which performed inference using only information from the past (Chung et al. 2015; Gan et al. 2015; Gregor et al. 2015), and that a structured variational approximation outperforms mean-field based approximations. Our approach may easily be adapted to learning more general generative models, for example models with edges from observations to latent states.

Finally, we learn a DMM on a polyphonic music dataset and on a dataset of electronic health records (a complex high dimensional setting with missing data). We use the model learned on health records to ask queries such as "what would have happened to patients had they not received treatment", and show that our model correctly identifies the way certain medications affect a patient's health.¹

[Figure 1: graphical models of the HMM and the DMM over z_1, z_2, ... and x_1, x_2, ..., and of an RNN over h_1, h_2, ... and x_1, x_2, ...]
Figure 1: Generative Models of Sequential Data: (Top Left) Hidden Markov Model (HMM), (Top Right) Deep Markov Model (DMM), where the marked elements denote the neural networks used in DMMs for the emission and transition functions. (Bottom) Recurrent Neural Network (RNN), where ♦ denotes a deterministic intermediate representation.

¹ Code for learning DMMs and reproducing our results may be found at: github.com/clinicalml/structuredinference

Related Work: Learning GSSMs with MLPs for the transition distribution was considered by Raiko and Tornio (2009). They approximate the posterior with non-linear dynamic factor analysis (Valpola and Karhunen 2002), which scales quadratically with the observed dimension and is impractical for large-scale learning.

Recent work has considered variational learning of time-series data using structured inference or recognition networks. Archer et al. propose using a Gaussian approximation to the posterior distribution with a block-tridiagonal inverse covariance. Johnson et al. use a conditional random field as the inference network for time-series models. Concurrent to our own work, Fraccaro et al. also learn sequential generative models using structured inference networks parameterized by recurrent neural networks.

Bayer and Osendorfer, and Fabius and van Amersfoort, create a stochastic variant of RNNs by making the hidden state of the RNN at every time step a function of independently sampled latent variables. Chung et al. apply a similar model to speech data, sharing parameters between the RNNs for the generative model and the inference network. Gan et al. learn a model with discrete random variables, using a structured inference network that only considers information from the past, similar to the models of Chung et al. and Gregor et al. In contrast to these works, we use information from the future within a structured inference network, which we show to be preferable both theoretically and practically.
Additionally, we systematically evaluate the impact of the different variational approximations on learning.

Watter et al. construct a first-order Markov model using inference networks. However, their learning algorithm is based on data tuples over consecutive time steps. This makes the strong assumption that the posterior distribution can be recovered based on observations at the current and next time-step. As we show, for generative models like the one in Fig. 1, the posterior distribution at any time step is a function of all future (and past) observations.

2 Background
Gaussian State Space Models: We consider both inference and learning in a class of latent variable models given by:

$$z_t \sim \mathcal{N}\big(G_\alpha(z_{t-1}, \Delta_t),\, S_\beta(z_{t-1}, \Delta_t)\big) \qquad \text{(Transition)} \quad (1)$$
$$x_t \sim \Pi\big(F_\kappa(z_t)\big) \qquad \text{(Emission)} \quad (2)$$

We denote by $z_t$ a vector-valued latent variable and by $x_t$ a vector-valued observation. A sequence of such latent variables and observations is denoted $\vec z, \vec x$ respectively. We assume that the distribution of the latent states is a multivariate Gaussian with a mean and covariance which are differentiable functions of the previous latent state and $\Delta_t$ (the time elapsed between $t-1$ and $t$). The multivariate observations $x_t$ are distributed according to a distribution $\Pi$ (e.g., independent Bernoullis if the data is binary) whose parameters are a function of the corresponding latent state $z_t$. Collectively, we denote by $\theta = \{\alpha, \beta, \kappa\}$ the parameters of the generative model.

Eq. 1 subsumes a large family of linear and non-linear Gaussian state space models. For example, by setting $G_\alpha(z_{t-1}) = G_t z_{t-1}$, $S_\beta = \Sigma_t$, $F_\kappa = F_t z_t$, where $G_t$, $\Sigma_t$ and $F_t$ are matrices, we obtain linear state space models. The functional forms and initial parameters for $G_\alpha, S_\beta, F_\kappa$ may be pre-specified.
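To make Eqs. 1 and 2 concrete, here is a minimal sketch that samples one sequence from a scalar, time-invariant linear GSSM ($G_\alpha(z) = G z$, $S_\beta = \sigma^2$, $F_\kappa(z) = F z$). Choosing a Gaussian for $\Pi$, dropping $\Delta_t$, the starting state z0, and the function name are illustrative assumptions rather than anything specified in the paper; the parameter values loosely follow the one-dimensional model of Fig. 3 without its 0.05 offset.

```python
import random


def sample_linear_gssm(T, G, sigma2, F, obs_var, z0=0.0, seed=0):
    """Sample (z_1..z_T, x_1..x_T) from a scalar linear GSSM.

    Transition (Eq. 1): z_t ~ N(G * z_{t-1}, sigma2)
    Emission   (Eq. 2): x_t ~ N(F * z_t, obs_var)   # Gaussian Pi is an assumption
    """
    rng = random.Random(seed)
    zs, xs = [], []
    z_prev = z0  # hypothetical initial state z_0
    for _ in range(T):
        # random.gauss takes a standard deviation, so convert from variances
        z = rng.gauss(G * z_prev, sigma2 ** 0.5)
        x = rng.gauss(F * z, obs_var ** 0.5)
        zs.append(z)
        xs.append(x)
        z_prev = z
    return zs, xs


zs, xs = sample_linear_gssm(T=25, G=1.0, sigma2=10.0, F=0.5, obs_var=20.0)
print(len(zs), len(xs))  # 25 25
```

Setting both variances to zero makes the sampler deterministic, which is a quick way to check that the transition and emission maps are wired correctly.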
Variational Learning: Using recent advances in variational inference we optimize a variational lower bound on the data log-likelihood. The key technical innovation is the introduction of an inference network or recognition network (Hinton et al. 1995; Kingma and Welling 2014; Mnih and Gregor 2014; Rezende, Mohamed, and Wierstra 2014), a neural network which approximates the intractable posterior. This is a parametric conditional distribution that is optimized to perform inference. Throughout this paper we will use θ to denote the parameters of the generative model, and φ to denote the parameters of the inference network.

For the remainder of this section, we consider learning in a Bayesian network whose joint distribution factorizes as $p(x, z) = p_\theta(z)\, p_\theta(x \mid z)$. The posterior distribution $p_\theta(z \mid x)$ is typically intractable. Using the well-known variational principle, we posit an approximate posterior distribution $q_\phi(z \mid x)$ to obtain the following lower bound on the marginal likelihood:

$$\log p_\theta(x) \geq \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z)\big), \quad (3)$$

where the inequality is by Jensen's inequality. Kingma and Welling; Rezende, Mohamed, and Wierstra use a neural net (with parameters φ) to parameterize $q_\phi$. The challenge in the resulting optimization problem is that the lower bound in Eq. 3 includes an expectation w.r.t. $q_\phi$, which implicitly depends on the network parameters φ. When using a Gaussian variational approximation $q_\phi(z \mid x) = \mathcal{N}(\mu_\phi(x), \Sigma_\phi(x))$, where $\mu_\phi(x), \Sigma_\phi(x)$ are parametric functions of the observation x, this difficulty is overcome by using stochastic backpropagation: a simple transformation, $z = \mu_\phi(x) + \Sigma_\phi(x)^{1/2}\epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, allows one to obtain unbiased Monte Carlo estimates of the gradients of $\mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)]$ with respect to φ. The KL term in Eq. 3 can be estimated similarly since it is also an expectation.
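As a minimal sketch of stochastic backpropagation for a scalar Gaussian q, the code below uses the toy integrand f(z) = z², chosen only because its exact gradient is known, and estimates the gradient of $\mathbb{E}_q[f(z)]$ with respect to (μ, σ) by writing z = μ + σε. It also checks a Monte Carlo estimate of a Gaussian KL term against its closed form. All function names and test values are hypothetical, and plain Python loops stand in for an autodiff framework.

```python
import math
import random


def reparam_grads(mu, sigma, n_samples=100_000, seed=0):
    """Stochastic-backpropagation estimate of d/d(mu, sigma) E_{z~N(mu, sigma^2)}[z^2].

    With z = mu + sigma * eps, eps ~ N(0, 1), the gradient moves inside the
    expectation: df/dmu = f'(z) = 2z and df/dsigma = f'(z) * eps = 2z * eps.
    Exact values for comparison: E[z^2] = mu^2 + sigma^2, so (2*mu, 2*sigma).
    """
    rng = random.Random(seed)
    g_mu = g_sigma = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps  # differentiable in (mu, sigma) for a fixed eps
        g_mu += 2.0 * z
        g_sigma += 2.0 * z * eps
    return g_mu / n_samples, g_sigma / n_samples


def log_normal(z, mu, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (z - mu) ** 2 / var)


def kl_gauss(mu_q, var_q, mu_p, var_p):
    """Closed-form KL(N(mu_q, var_q) || N(mu_p, var_p)) for scalars.
    For diagonal covariances, sum this over dimensions."""
    return 0.5 * (math.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)


def kl_gauss_mc(mu_q, var_q, mu_p, var_p, n_samples=100_000, seed=0):
    """Monte Carlo estimate of the same KL, treating it as E_q[log q - log p]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(mu_q, math.sqrt(var_q))
        total += log_normal(z, mu_q, var_q) - log_normal(z, mu_p, var_p)
    return total / n_samples


print(reparam_grads(1.5, 0.5))  # close to (3.0, 1.0)
print(kl_gauss(0.5, 0.8, 0.0, 1.0), kl_gauss_mc(0.5, 0.8, 0.0, 1.0))
```

The Monte Carlo KL agrees with the closed form only up to sampling noise, which is the variance that an analytic KL avoids.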
When the prior $p_\theta(z)$ is Normally distributed, the KL and its gradients may be obtained analytically.

3 A Factorized Variational Lower Bound
We leverage stochastic backpropagation to learn generative models given by Eq. 1, corresponding to the graphical model in Fig. 1. Our insight is that for the purpose of inference, we can use the Markov properties of the generative model to guide us in deriving a structured approximation to the posterior. Specifically, the posterior factorizes as:

$$p(\vec z \mid \vec x) = p(z_1 \mid \vec x) \prod_{t=2}^{T} p(z_t \mid z_{t-1}, x_t, \ldots, x_T). \quad (4)$$

To see this, use the independence statements implied by the graphical model in Fig. 1 to note that $p(\vec z \mid \vec x)$, the true posterior, factorizes as:

$$p(\vec z \mid \vec x) = p(z_1 \mid \vec x) \prod_{t=2}^{T} p(z_t \mid z_{t-1}, \vec x).$$

Now, we notice that $z_t \perp\!\!\!\perp x_1, \ldots, x_{t-1} \mid z_{t-1}$, yielding the desired result. The significance of Eq. 4 is that it yields insight into the structure of the exact posterior for the class of models laid out in Fig. 1.

We directly mimic the structure of the posterior with the following factorization of the variational approximation:

$$q_\phi(\vec z \mid \vec x) = q_\phi(z_1 \mid x_1, \ldots, x_T) \prod_{t=2}^{T} q_\phi(z_t \mid z_{t-1}, x_t, \ldots, x_T) \quad (5)$$
$$\text{s.t. } q_\phi(z_t \mid z_{t-1}, x_t, \ldots, x_T) = \mathcal{N}\big(\mu_\phi(z_{t-1}, x_t, \ldots, x_T),\, \Sigma_\phi(z_{t-1}, x_t, \ldots, x_T)\big),$$

where $\mu_\phi$ and $\Sigma_\phi$ are functions parameterized by neural nets. Although $q_\phi$ has the option to condition on all information across time, Eq. 4 suggests that in fact it suffices to condition on information from the future and the previous latent state. The previous latent state serves as a summary statistic for information from the past.
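The factorization in Eq. 5 implies an ancestral sampling order: draw $z_1$ given all observations, then each $z_t$ given $z_{t-1}$ and only $x_t, \ldots, x_T$. The sketch below shows that order for a scalar latent state and accumulates $\log q(\vec z \mid \vec x)$ along the way. The functions toy_mean and toy_var are hypothetical stand-ins for the neural nets $\mu_\phi$ and $\Sigma_\phi$; starting from a previous state of 0 mirrors the convention $\hat z_0 = \vec 0$ used for the inference networks in Section 4.

```python
import math
import random


def sample_structured_q(x, mean_fn, var_fn, seed=0):
    """Ancestral sampling from the factorization in Eq. 5 (scalar z_t):

        q(z | x) = q(z_1 | x_1..x_T) * prod_{t=2}^{T} q(z_t | z_{t-1}, x_t..x_T),

    with each factor Gaussian. Returns the samples and log q(z | x).
    """
    rng = random.Random(seed)
    z_prev = 0.0  # z_0 = 0, so the first factor effectively sees only x_1..x_T
    samples, log_q = [], 0.0
    for t in range(len(x)):
        future = x[t:]  # step t conditions only on x_t..x_T and z_{t-1}
        mu, var = mean_fn(z_prev, future), var_fn(z_prev, future)
        z = mu + math.sqrt(var) * rng.gauss(0.0, 1.0)
        log_q += -0.5 * (math.log(2.0 * math.pi * var) + (z - mu) ** 2 / var)
        samples.append(z)
        z_prev = z
    return samples, log_q


def toy_mean(z_prev, future):
    # Hypothetical stand-in for mu_phi: blend the previous state with future data.
    return 0.5 * z_prev + 0.5 * sum(future) / len(future)


def toy_var(z_prev, future):
    # Hypothetical stand-in for Sigma_phi: a constant variance.
    return 0.1


zs, log_q = sample_structured_q([1.0, 2.0, 3.0], toy_mean, toy_var)
```

Because each step depends on the previous sample, the loop cannot be vectorized over time; this sequential structure is what the structured inference networks in Section 4 reproduce with an RNN over the future observations.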
Exact Inference: We can match the factorization of the true posterior using the inference network, but using a Gaussian variational approximation for the approximate posterior over each latent variable (as we do) limits the expressivity of the inferential model, except for the case of linear dynamical systems where the posterior distribution is Normally distributed. However, one could augment our proposed inference network with recent innovations that improve the variational approximation to allow for multi-modality (Rezende and Mohamed 2015; Tran, Ranganath, and Blei 2016). Such modifications could yield black-box methods for exact inference in time-series models, which we leave for future work.

Deriving a Variational Lower Bound: For a generative model (with parameters θ) and an inference network (with parameters φ), we are interested in $\max_\theta \log p_\theta(\vec x)$. For ease of exposition, we instantiate the derivation of the variational bound for a single data point $\vec x$, though we learn θ, φ from a corpus. The lower bound in Eq. 3 has an analytic form of the KL term only for the simplest of transition models $G_\alpha, S_\beta$ between $z_{t-1}$ and $z_t$ (Eq. 1). One could estimate the gradient of the KL term by sampling from the variational model, but that results in high-variance estimates and gradients. We use a different factorization of the KL term (obtained by using the prior distribution over latent variables), leading to the variational lower bound we use as our objective function:

$$\mathcal{L}(\vec x; (\theta, \phi)) = \sum_{t=1}^{T} \mathbb{E}_{q_\phi(z_t \mid \vec x)}\big[\log p_\theta(x_t \mid z_t)\big] - \mathrm{KL}\big(q_\phi(z_1 \mid \vec x) \,\|\, p_\theta(z_1)\big) - \sum_{t=2}^{T} \mathbb{E}_{q_\phi(z_{t-1} \mid \vec x)}\Big[\mathrm{KL}\big(q_\phi(z_t \mid z_{t-1}, \vec x) \,\|\, p_\theta(z_t \mid z_{t-1})\big)\Big]. \quad (6)$$

The key point is that the resulting objective function has more stable analytic gradients. Without the factorization of the KL divergence in Eq. 6, we would have to estimate $\mathrm{KL}(q(\vec z \mid \vec x) \,\|\, p(\vec z))$ via Monte Carlo sampling, since it has no analytic form. In contrast, in Eq. 6 the individual KL terms do have analytic forms. A detailed derivation of the bound and the factorization of the KL divergence is given in the supplemental material.

Learning with Gradient Descent: The objective in Eq. 6 is differentiable in the parameters of the model (θ, φ). If the generative model θ is fixed, we perform gradient ascent of Eq. 6 in φ. Otherwise, we perform gradient ascent in both φ and θ. We use stochastic backpropagation (Kingma and Welling 2014; Rezende, Mohamed, and Wierstra 2014) for estimating the gradient w.r.t. φ. Note that the expectations are only taken with respect to the variables $z_{t-1}, z_t$, which are the sufficient statistics of the Markov model. For the KL terms in Eq. 6, we use the fact that the prior $p_\theta(z_t \mid z_{t-1})$ and the variational approximation to the posterior $q_\phi(z_t \mid z_{t-1}, \vec x)$ are both Normally distributed, and hence their KL divergence may be computed analytically.

Algorithm 1 Learning a DMM with stochastic gradient descent: We use a single sample from the recognition network during learning to evaluate expectations in the bound. We aggregate gradients across mini-batches.
Inputs: Dataset D
  Inference Model: $q_\phi(\vec z \mid \vec x)$
  Generative Model: $p_\theta(\vec x \mid \vec z)$, $p_\theta(\vec z)$
while notConverged() do
  1. Sample datapoint: $\vec x \sim D$
  2. Estimate posterior parameters (evaluate $\mu_\phi$, $\Sigma_\phi$)
  3. Sample $\hat{\vec z} \sim q_\phi(\vec z \mid \vec x)$
  4. Estimate conditional likelihood $p_\theta(\vec x \mid \hat{\vec z})$ and KL
  5. Evaluate $\mathcal{L}(\vec x; (\theta, \phi))$
  6. Estimate MC approx. to $\nabla_\theta \mathcal{L}$
  7. Estimate MC approx. to $\nabla_\phi \mathcal{L}$ (use stochastic backpropagation to move gradients with respect to $q_\phi$ inside the expectation)
  8. Update θ, φ using Adam (Kingma and Ba 2015)
end while

Table 1: Inference Networks: BRNN refers to a Bidirectional RNN and comb.fxn is shorthand for combiner function.

Inference Network | Variational Approximation for z_t | Implemented With
MF-LR | q(z_t | x_1, ..., x_T) | BRNN
MF-L | q(z_t | x_1, ..., x_t) | RNN
ST-L | q(z_t | z_{t-1}, x_1, ..., x_t) | RNN & comb.fxn
DKS | q(z_t | z_{t-1}, x_t, ..., x_T) | RNN & comb.fxn
ST-LR | q(z_t | z_{t-1}, x_1, ..., x_T) | BRNN & comb.fxn

Algorithm 1 depicts an overview of the learning algorithm. We outline the algorithm for a mini-batch of size one, but in practice gradients are averaged across stochastically sampled mini-batches of the training set. We take a gradient step in θ and φ, typically with an adaptive learning rate such as Adam (Kingma and Ba 2015).

4 Structured Inference Networks
We now detail how we construct the variational approximation $q_\phi$, and specifically how we model the mean and diagonal covariance functions μ and Σ using recurrent neural networks (RNNs). Since our implementation only models the diagonal of the covariance matrix (the vector-valued variances), we denote this as σ² rather than Σ. This parameterization cannot in general be expected to be equal to $p_\theta(\vec z \mid \vec x)$, but in many cases is a reasonable approximation. We use RNNs due to their ability to scale well to large datasets.

Table 1 details the different choices for inference networks that we evaluate. The Deep Kalman Smoother DKS corresponds exactly to the functional form suggested by Eq. 4, and is our proposed variational approximation. The DKS smoothes information from the past ($z_{t-1}$) and future ($x_t, \ldots, x_T$) to form the approximate posterior distribution.

We also evaluate other possibilities for the variational models (inference networks) $q_\phi$: two are mean-field models (denoted MF) and two are structured models (denoted ST).
These are distinguished by whether they use information from the past (denoted L, for left), the future (denoted R, for right), or both (denoted LR). See Fig. 2 for an illustration of two of these methods. Each conditions on a different subset of the observations to summarize information in the input sequence $\vec x$. DKS corresponds to ST-R.

The hidden states of the RNN parameterize the variational distribution, which go through what we call the "combiner function". We obtain the mean $\mu_t$ and diagonal covariance $\sigma^2_t$ for the approximate posterior at each time-step in a manner akin to Gaussian belief propagation. Specifically, we interpret the hidden states of the forward and backward RNNs as parameterizing the mean and variance of two Gaussian-distributed "messages" summarizing the observations from the past and the future, respectively. We then multiply these two Gaussians, which yields a precision-weighted average of the means (each mean is weighted by the other message's variance). All operations should be understood to be performed element-wise on the corresponding vectors. $h^{\text{left}}_t$ and $h^{\text{right}}_t$ are the hidden states of the RNNs that run from the past and the future respectively (see Fig. 2).

Combiner Function for Mean Field Approximations: For the MF-LR inference network, the mean $\mu_t$ and diagonal variances $\sigma^2_t$ of the variational distribution $q_\phi(z_t \mid \vec x)$ are predicted using the output of the RNN (not conditioned on $z_{t-1}$) as follows, where $\mathrm{softplus}(x) = \log(1 + \exp(x))$:

$$\mu_r = W^{\text{right}}_{\mu_r} h^{\text{right}}_t + b^{\text{right}}_{\mu_r}; \quad \sigma^2_r = \mathrm{softplus}\big(W^{\text{right}}_{\sigma^2_r} h^{\text{right}}_t + b^{\text{right}}_{\sigma^2_r}\big)$$
$$\mu_l = W^{\text{left}}_{\mu_l} h^{\text{left}}_t + b^{\text{left}}_{\mu_l}; \quad \sigma^2_l = \mathrm{softplus}\big(W^{\text{left}}_{\sigma^2_l} h^{\text{left}}_t + b^{\text{left}}_{\sigma^2_l}\big)$$
$$\mu_t = \frac{\mu_r \sigma^2_l + \mu_l \sigma^2_r}{\sigma^2_r + \sigma^2_l}; \quad \sigma^2_t = \frac{\sigma^2_r \sigma^2_l}{\sigma^2_r + \sigma^2_l}$$

[Figure 2: a BRNN over x_1, x_2, x_3 with forward states h^left_1..h^left_3 and backward states h^right_1..h^right_3, a combiner function (a) producing (μ_t, Σ_t), and samples ẑ_1, ẑ_2, ẑ_3 starting from ẑ_0 = 0]
Figure 2: Structured Inference Networks: MF-LR and ST-LR variational approximations for a sequence of length 3, using a bi-directional recurrent neural net (BRNN). The BRNN takes as input the sequence $(x_1, \ldots, x_3)$, and through a series of non-linearities denoted by the blue arrows it forms a sequence of hidden states summarizing information from the left and right ($h^{\text{left}}_t$ and $h^{\text{right}}_t$) respectively. Then through a further sequence of non-linearities which we call the "combiner function" (marked (a) above), and denoted by the red arrows, it outputs two vectors μ and Σ, parameterizing the mean and diagonal covariance of $q_\phi(z_t \mid z_{t-1}, \vec x)$ of Eq. 5. Samples $\hat z_t$ are drawn from $q_\phi(z_t \mid z_{t-1}, \vec x)$, as indicated by the black dashed arrows. For the structured variational model ST-LR, the samples $\hat z_t$ are fed into the computation of $\mu_{t+1}$ and $\Sigma_{t+1}$, as indicated by the red arrows with the label (a). The mean-field model does not have these arrows, and therefore computes $q_\phi(z_t \mid \vec x)$. We use $\hat z_0 = \vec 0$. The inference network for DKS (ST-R) is structured like that of ST-LR except without the RNN from the past.

Combiner Function for Structured Approximations: The combiner functions for the structured approximations are implemented as:

(For ST-LR) $h_{\text{combined}} = \tfrac{1}{3}\big(\tanh(W z_{t-1} + b) + h^{\text{left}}_t + h^{\text{right}}_t\big)$
(For DKS) $h_{\text{combined}} = \tfrac{1}{2}\big(\tanh(W z_{t-1} + b) + h^{\text{right}}_t\big)$
(Posterior Means and Covariances) $\mu_t = W_\mu h_{\text{combined}} + b_\mu$; $\sigma^2_t = \mathrm{softplus}(W_{\sigma^2} h_{\text{combined}} + b_{\sigma^2})$

The combiner function uses the tanh non-linearity from $z_{t-1}$ to approximate the transition function (alternatively, one could share parameters with the generative model), and here we use a simple weighting between the components.

Relationship to Related Work: Archer et al.; Gao et al. use $q(\vec z \mid \vec x) = \prod_t q(z_t \mid z_{t-1}, \vec x)$ where $q(z_t \mid z_{t-1}, \vec x) = \mathcal{N}(\mu(x_t), \Sigma(z_{t-1}, x_t, x_{t-1}))$.
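Returning to the combiner functions defined earlier in this section, the sketch below writes both families with plain Python lists so that each element-wise operation is visible. The function names, the dictionary layout of the parameters, and the 1-D parameter values are hypothetical; in the paper $h^{\text{left}}_t$ and $h^{\text{right}}_t$ come from trained RNNs, whereas here they are passed in as fixed vectors.

```python
import math


def softplus(a):
    # Numerically stable log(1 + exp(a)).
    return max(a, 0.0) + math.log1p(math.exp(-abs(a)))


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def mf_lr_combiner(mu_l, var_l, mu_r, var_r):
    """MF-LR: element-wise product of the left and right Gaussian messages.

    mu_t = (mu_r * var_l + mu_l * var_r) / (var_r + var_l)
    var_t = var_r * var_l / (var_r + var_l)
    """
    mu = [(mr * vl + ml * vr) / (vr + vl)
          for ml, vl, mr, vr in zip(mu_l, var_l, mu_r, var_r)]
    var = [vr * vl / (vr + vl) for vl, vr in zip(var_l, var_r)]
    return mu, var


def structured_combiner(z_prev, h_right, p, h_left=None):
    """DKS (ST-R) combiner; pass h_left as well to get the ST-LR variant."""
    trans = [math.tanh(a + b) for a, b in zip(matvec(p["W"], z_prev), p["b"])]
    if h_left is None:  # DKS: average of tanh(W z_{t-1} + b) and h_right
        h_comb = [(a + r) / 2.0 for a, r in zip(trans, h_right)]
    else:  # ST-LR: average of three terms
        h_comb = [(a + l + r) / 3.0 for a, l, r in zip(trans, h_left, h_right)]
    mu = [a + b for a, b in zip(matvec(p["W_mu"], h_comb), p["b_mu"])]
    var = [softplus(a + b) for a, b in zip(matvec(p["W_var"], h_comb), p["b_var"])]
    return mu, var


# Hypothetical 1-D parameters, chosen only to make the arithmetic easy to follow.
params = {"W": [[0.0]], "b": [0.0], "W_mu": [[1.0]], "b_mu": [0.0],
          "W_var": [[0.0]], "b_var": [0.0]}
print(mf_lr_combiner([0.0], [1.0], [3.0], [3.0]))  # ([0.75], [0.75])
print(structured_combiner([5.0], [0.4], params))   # mean 0.2, variance log(2)
```

In the MF-LR example the left message has variance 1 and the right message variance 3, so the combined mean 0.75 sits closer to the more certain left mean of 0, and the combined variance is smaller than either input.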
The key difference between the parameterization of Archer et al. and Gao et al. and our approach is that it conditions the posterior means only on $x_t$, and therefore does not account for the information from the future relevant to the approximate posterior distribution for $z_t$.

Johnson et al. interleave predicting the local variational parameters of the graphical model (using an inference network) with steps of message passing inference. A key difference between our approach and theirs is that we rely on the structured inference network to predict the optimal local variational parameters directly. In contrast, in Johnson et al., any suboptimalities in the initial local variational parameters may be overcome by the subsequent steps of optimization, albeit at additional computational cost.

Chung et al. propose the Variational RNN (VRNN), in which Gaussian noise is introduced at each time-step of an RNN. Chung et al. use an inference network that shares parameters with the generative model and only uses information from the past. If one views the noise variables and the hidden state of the RNN at time-step t together as $z_t$, then a factorization similar to Eq. 6 can be shown to hold, although the KL term would no longer have an analytic form since $p_\theta(z_t \mid z_{t-1}, x_{t-1})$ would not be Normally distributed. Nonetheless, our same structured inference networks (i.e. using an RNN to summarize observations from the future) could be used to improve the tightness of the variational lower bound, and our empirical results suggest that this would result in better learned models.

5 Deep Markov Models
Following Raiko et al. (2006), we apply the ideas of deep learning to non-linear continuous state space models. When the transition and emission functions have an unknown functional form, we parameterize $G_\alpha, S_\beta, F_\kappa$ from Eq. 1 with deep neural networks. See Fig. 1 (right) for an illustration of the graphical model.
Emission Function: We parameterize the emission function $F_\kappa$ using a two-layer MLP (multi-layer perceptron), $\mathrm{MLP}(x, \mathrm{NL}_1, \mathrm{NL}_2) = \mathrm{NL}_2\big(W_2\, \mathrm{NL}_1(W_1 x + b_1) + b_2\big)$, where NL denotes non-linearities such as ReLU, sigmoid, or tanh units applied element-wise to the input vector. For modeling binary data, $F_\kappa(z_t) = \mathrm{sigmoid}\big(W_{\text{emission}}\, \mathrm{MLP}(z_t, \mathrm{ReLU}, \mathrm{ReLU}) + b_{\text{emission}}\big)$ parameterizes the mean probabilities of independent Bernoullis.

Gated Transition Function: We parameterize the transition function from $z_t$ to $z_{t+1}$ using a gated transition function inspired by Gated Recurrent Units (Chung et al. 2014), instead of an MLP. Gated recurrent units (GRUs) are a neural architecture that parameterizes the recurrence equation in the RNN with gating units to control the flow of information from one hidden state to the next, conditioned on the observation. Unlike in GRUs, in the DMM the transition function is not conditioned on any of the observations: all the information must be encoded in the completely stochastic latent state. To achieve this goal, we create a Gated Transition Function. We would like the model to have the flexibility to choose a linear transition for some dimensions while having non-linear transitions for the others. We adopt the following parameterization, where $\mathbb{I}$ denotes the identity function and ⊙ denotes element-wise multiplication:

$$g_t = \mathrm{MLP}(z_{t-1}, \mathrm{ReLU}, \mathrm{sigmoid}) \qquad \text{(Gating Unit)}$$
$$h_t = \mathrm{MLP}(z_{t-1}, \mathrm{ReLU}, \mathbb{I}) \qquad \text{(Proposed mean)}$$

(Transition: mean $G_\alpha$ and variance $S_\beta$)

$$\mu_t(z_{t-1}) = (1 - g_t) \odot (W_{\mu_p} z_{t-1} + b_{\mu_p}) + g_t \odot h_t$$
$$\sigma^2_t(z_{t-1}) = \mathrm{softplus}\big(W_{\sigma^2_p}\, \mathrm{ReLU}(h_t) + b_{\sigma^2_p}\big)$$

Note that the mean and covariance functions both use $h_t$. In our experiments, we initialize $W_{\mu_p}$ to the identity matrix and $b_{\mu_p}$ to 0. The parameters of the emission and transition functions form the set θ.
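The gated transition can be sketched directly from the equations above. The version below uses plain Python lists and a single hidden layer per MLP; the dictionary layout and the helper example_params, with its 1-D weights, are hypothetical choices made so that the effect of the gate is easy to see. Following the text, $W_{\mu_p}$ is set to the identity and $b_{\mu_p}$ to 0.

```python
import math


def relu(v):
    return [max(0.0, a) for a in v]


def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-a)) for a in v]


def identity(v):
    return list(v)


def softplus(v):
    return [max(a, 0.0) + math.log1p(math.exp(-abs(a))) for a in v]


def affine(W, b, v):
    return [sum(w * x for w, x in zip(row, v)) + bi for row, bi in zip(W, b)]


def mlp(x, p, nl1, nl2):
    """MLP(x, NL1, NL2) = NL2(W2 NL1(W1 x + b1) + b2)."""
    return nl2(affine(p["W2"], p["b2"], nl1(affine(p["W1"], p["b1"], x))))


def gated_transition(z_prev, p):
    """Mean and diagonal variance of p(z_t | z_{t-1}) for the gated transition."""
    g = mlp(z_prev, p["gate"], relu, sigmoid)        # gating unit g_t
    h = mlp(z_prev, p["proposal"], relu, identity)   # proposed mean h_t
    lin = affine(p["W_mu_p"], p["b_mu_p"], z_prev)   # linear path
    mu = [(1.0 - gi) * li + gi * hi for gi, li, hi in zip(g, lin, h)]
    var = softplus(affine(p["W_var_p"], p["b_var_p"], relu(h)))
    return mu, var


def example_params(gate_bias):
    """Hypothetical 1-D parameters; W_mu_p = identity and b_mu_p = 0 as in the text."""
    return {
        "gate": {"W1": [[0.0]], "b1": [0.0], "W2": [[0.0]], "b2": [gate_bias]},
        "proposal": {"W1": [[1.0]], "b1": [0.0], "W2": [[2.0]], "b2": [0.0]},
        "W_mu_p": [[1.0]], "b_mu_p": [0.0],
        "W_var_p": [[0.0]], "b_var_p": [0.0],
    }


# A strongly negative gate bias keeps the identity linear path: mu is about z_{t-1}.
print(gated_transition([0.7], example_params(-50.0)))
# A strongly positive gate bias switches to the non-linear proposal h_t = 2 * relu(z).
print(gated_transition([0.7], example_params(50.0)))
```

With the identity initialization, a closed gate reduces the transition to $z_t \approx z_{t-1}$ plus noise, which is one reading of why the text initializes $W_{\mu_p}$ this way.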
6 Evaluation
Our models and learning algorithm are implemented in Theano (Theano Development Team 2016). We use Adam (Kingma and Ba 2015) with a learning rate of 0.0008 to train the DMM. Our code is available at github.com/clinicalml/structuredinference.

Datasets: We evaluate on three datasets.

Synthetic: We consider simple linear and non-linear GSSMs. To train the inference networks we use N = 5000 datapoints of length T = 25. We consider both one and two dimensional systems for inference and parameter estimation. We compare our results using the training value of the variational bound $\mathcal{L}(\vec x; (\theta, \phi))$ (Eq. 6) and the

$$\mathrm{RMSE} = \sqrt{\frac{1}{N}\frac{1}{T}\sum_{i=1}^{N}\sum_{t=1}^{T}\big[\mu_\phi(x_{i,t}) - z^*_{i,t}\big]^2},$$

where $z^*$ corresponds to the true underlying z's that generated the data.

Polyphonic Music: We train DMMs on polyphonic music data (Boulanger-Lewandowski, Bengio, and Vincent 2012). An instance in the sequence comprises an 88-dimensional binary vector corresponding to the notes of a piano. We learn for 2000 epochs and report results based on early stopping using the validation set. We report held-out negative log-likelihood (NLL) in the format "a (b) {c}". Here a is an importance sampling based estimate of the NLL (details in supplementary material), and $b = \frac{1}{\sum_{i=1}^{N} T_i}\sum_{i=1}^{N} -\mathcal{L}(\vec x_i; \theta, \phi)$, where $T_i$ is the length of sequence i. This is an upper bound on the NLL, which facilitates comparison to RNNs. TSBN (Gan et al. 2015) (in their code) report $c = \frac{1}{N}\sum_{i=1}^{N} \frac{1}{T_i}\big(-\mathcal{L}(\vec x_i; \theta, \phi)\big)$; we compute this to facilitate comparison with their work. As in Kaae Sønderby et al. (2016), we found that annealing the KL divergence in the variational bound $\mathcal{L}(\vec x; (\theta, \phi))$ from 0 to 1 over 5000 parameter updates gave better results.

Electronic Health Records (EHRs): The dataset comprises 5000 diabetic patients using data from a major health insurance provider.
The observations of interest are: A1c level (hemoglobin A1c, a protein for which a high level indicates that the patient is diabetic) and glucose (blood sugar). We bin glucose into quantiles and A1c into clinically meaningful bins. The observations also include age, gender and ICD-9 diagnosis codes for co-morbidities of diabetes such as congestive heart failure, chronic kidney disease and obesity. There are 48 binary observations for a patient at every time-step. We group each patient's data (over 4 years) into three month intervals, yielding a sequence of length 18.

6.1 Synthetic Data
Compiling Exact Inference: We seek to understand whether inference networks can accurately compile exact posterior inference into the network parameters φ for linear GSSMs when exact inference is feasible. For this experiment we optimize Eq. 6 over φ, while θ is fixed to a synthetic distribution given by a one-dimensional GSSM. We compare results obtained by the various approximations we propose to those obtained by an implementation of Kalman smoothing (Duckworth 2016), which performs exact inference. Fig. 3 (top and middle) depicts our results. The proposed DKS (i.e., ST-R) and ST-LR outperform the mean-field based variational method MF-L that only looks at information from the past. MF-LR, however, is often able to catch up when it comes to RMSE, highlighting the role that information from the future plays when performing posterior inference, as is evident in the posterior factorization in Eq. 4. Both DKS and ST-LR converge to the RMSE of the exact Smoothed KF, and moreover their lower bound on the likelihood becomes tight.

Approximate Inference and Parameter Estimation: Here, we experiment with applying the inference networks to synthetic non-linear generative models as well as using DKS for learning a subset of parameters within a fixed generative model.
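For reference, the exact baseline in the Compiling Exact Inference experiment is Kalman smoothing. The sketch below is an independent, textbook Kalman filter followed by a Rauch-Tung-Striebel backward pass for the scalar model of Fig. 3, $z_t \sim \mathcal{N}(z_{t-1} + 0.05, 10)$, $x_t \sim \mathcal{N}(0.5 z_t, 20)$. It is not the Duckworth (2016) implementation used in the paper, and the prior $z_1 \sim \mathcal{N}(0, 10)$ (m0, p0) is an assumption because the text does not state it.

```python
def kalman_smoother_1d(xs, a=1.0, c=0.05, q=10.0, h=0.5, r=20.0, m0=0.0, p0=10.0):
    """Exact posterior marginals p(z_t | x_1..x_T) for a scalar linear GSSM

        z_t ~ N(a * z_{t-1} + c, q),  x_t ~ N(h * z_t, r),  z_1 ~ N(m0, p0).

    Forward Kalman filter, then a Rauch-Tung-Striebel backward pass.
    Returns (smoothed means, smoothed variances).
    """
    T = len(xs)
    m_pred, p_pred = [0.0] * T, [0.0] * T  # p(z_t | x_1..x_{t-1})
    m_filt, p_filt = [0.0] * T, [0.0] * T  # p(z_t | x_1..x_t)
    for t, x in enumerate(xs):
        if t == 0:
            m_pred[t], p_pred[t] = m0, p0
        else:
            m_pred[t] = a * m_filt[t - 1] + c
            p_pred[t] = a * a * p_filt[t - 1] + q
        k = p_pred[t] * h / (h * h * p_pred[t] + r)  # Kalman gain
        m_filt[t] = m_pred[t] + k * (x - h * m_pred[t])
        p_filt[t] = (1.0 - k * h) * p_pred[t]
    m_s, p_s = m_filt[:], p_filt[:]  # at t = T, smoothed equals filtered
    for t in range(T - 2, -1, -1):
        j = p_filt[t] * a / p_pred[t + 1]  # smoother gain
        m_s[t] = m_filt[t] + j * (m_s[t + 1] - m_pred[t + 1])
        p_s[t] = p_filt[t] + j * j * (p_s[t + 1] - p_pred[t + 1])
    return m_s, p_s


ms, ps = kalman_smoother_1d([2.0, 1.0, -3.0, 0.5])
```

The backward pass is where information from $x_{t+1}, \ldots, x_T$ reaches $z_t$; the smoothed variances can only shrink relative to the filtered ones, which mirrors why the future-conditioned networks (DKS, ST-LR) can match this baseline while MF-L cannot.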
On synthetic non-linear datasets (see supplemental material) we similarly find that the structured variational approximations match the performance of inference using a smoothed Unscented Kalman Filter (Wan, Van Der Merwe, and others 2000) on held-out data. Finally, Fig. 4 illustrates a toy instance where we successfully perform parameter estimation in a synthetic, two-dimensional, non-linear GSSM.

6.2 Polyphonic Music

Mean-Field vs Structured Inference Networks: Table 2 shows the results of learning a DMM on the polyphonic music dataset using MF-LR, ST-L, DKS and ST-LR. ST-L is a structured variational approximation that only considers information from the past and, up to implementation details, is comparable to the one used in (Gregor et al. 2015). Comparing the negative log-likelihoods of the learned models, we see that the looseness in the variational bound (which we first observed in the synthetic setting in Fig. 3, top right) significantly affects the ability to learn. ST-LR and DKS substantially outperform MF-LR and ST-L. This adds credence to the idea that by taking into consideration the factorization of the posterior, one can perform better inference and, consequently, better learning in real-world, high-dimensional settings. Note that the DKS network has half the parameters of the ST-LR and MF-LR networks.

A Generalization of the DMM: To display the efficacy of our inference algorithm on model variants beyond first-order Markov models, we further augment the DMM with edges from $x_{t-1}$ to $z_t$ and from $x_{t-1}$ to $x_t$. We refer to the resulting generative model as DMM-Augmented (Aug.). Augmenting the DMM with additional edges realizes a richer class of generative models. We show that DKS can be used as is for inference on a more complex generative model than the DMM, while making gains in held-out likelihood. All following experiments use DKS for posterior inference.

[Figure 3 panels: Train RMSE and Train Upper Bound vs. epochs for ST-LR, MF-LR, ST-L, ST-R, MF-L and KF [Exact]; latent space ($z$) and observations ($x$) for two sequences, (1) and (2), comparing KF and ST-R.]

Figure 3: Synthetic Evaluation: (Top & Middle) Compiled inference for a fixed linear GSSM: $z_t \sim \mathcal{N}(z_{t-1} + 0.05, 10)$, $x_t \sim \mathcal{N}(0.5 z_t, 20)$. The training set comprised N = 5000 one-dimensional observations of sequence length T = 25. (Top left) RMSE with respect to the true $z^*$ that generated the data. (Top right) Variational bound during training. The results on held-out data are very similar (see supplementary material). (Bottom) Visualizing inference in two sequences (denoted (1) and (2)); left panels show the latent space of variables $z$, right panels show the observations $x$. Observations are generated by applying the emission function to the posterior shown in Latent Space. Shading denotes standard deviations.

[Figure 4 panels: estimates of $\alpha$ (true $\alpha^* = 0.5$) and $\beta$ (true $\beta^* = -0.1$) vs. epochs.]

Figure 4: Parameter Estimation: Learning parameters $\alpha, \beta$ in a two-dimensional non-linear GSSM with N = 5000, T = 25:
$$\vec z_t \sim \mathcal{N}\left(\left[0.2 z^0_{t-1} + \tanh(\alpha z^1_{t-1});\; 0.2 z^1_{t-1} + \sin(\beta z^0_{t-1})\right], 1.0\right), \qquad \vec x_t \sim \mathcal{N}(0.5 \vec z_t, 0.1),$$
where $\vec z$ denotes a vector, $[\,]$ denotes concatenation and the superscript denotes indexing.
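The KF [Exact] baseline in Fig. 3 is Kalman smoothing (the experiments use the library of Duckworth 2016). For the one-dimensional linear GSSM in the Figure 3 caption, smoothing amounts to a forward Kalman filter followed by a backward Rauch-Tung-Striebel (RTS) pass. The plain-Python sketch below is not the library code used in the experiments; it assumes that the second argument of N(·,·) is a variance and uses a hypothetical prior z_1 ∼ N(0, 1), since the caption does not state one.

```python
import math
import random

def kalman_smoother(xs, b=0.05, q=10.0, h=0.5, r=20.0, m0=0.0, p0=1.0):
    """Posterior marginals p(z_t | x_1..x_T) for the 1-D linear GSSM
    z_t = z_{t-1} + b + w_t, w_t ~ N(0, q);  x_t = h * z_t + v_t, v_t ~ N(0, r).
    q and r are variances; (m0, p0) is a hypothetical prior on z_1.
    Returns lists of smoothed means and variances."""
    m_pred, p_pred, m_filt, p_filt = [], [], [], []
    for t, x in enumerate(xs):
        # Predict: prior on z_1, otherwise propagate the previous filtered estimate.
        mp = m0 if t == 0 else m_filt[-1] + b
        pp = p0 if t == 0 else p_filt[-1] + q
        # Update: condition on x_t through the Kalman gain.
        k = pp * h / (h * h * pp + r)
        m_pred.append(mp)
        p_pred.append(pp)
        m_filt.append(mp + k * (x - h * mp))
        p_filt.append((1.0 - k * h) * pp)
    # Backward Rauch-Tung-Striebel pass (the transition coefficient is 1).
    m_s, p_s = m_filt[:], p_filt[:]
    for t in range(len(xs) - 2, -1, -1):
        c = p_filt[t] / p_pred[t + 1]
        m_s[t] = m_filt[t] + c * (m_s[t + 1] - m_pred[t + 1])
        p_s[t] = p_filt[t] + c * c * (p_s[t + 1] - p_pred[t + 1])
    return m_s, p_s

# Simulate one sequence of length T = 25 from the Fig. 3 model and smooth it.
rng = random.Random(0)
zs, xs = [], []
for t in range(25):
    z = rng.gauss(0.0, 1.0) if t == 0 else rng.gauss(zs[-1] + 0.05, math.sqrt(10.0))
    zs.append(z)
    xs.append(rng.gauss(0.5 * z, math.sqrt(20.0)))
means, variances = kalman_smoother(xs)
rmse = math.sqrt(sum((m - z) ** 2 for m, z in zip(means, zs)) / len(zs))
print(f"RMSE of smoothed means w.r.t. true z: {rmse:.3f}")
```

An inference network that has compiled exact inference should match the RMSE of these smoothed means, which is the behaviour Fig. 3 reports for DKS and ST-LR.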
The baselines we compare to in Table 3 also have more complex generative models than the DMM. STORN has edges from $x_{t-1}$ to $z_t$ given by the recurrence update, and TSBN has edges from $x_{t-1}$ to $z_t$ as well as from $x_{t-1}$ to $x_t$.

Table 2: Comparing Inference Networks: Test negative log-likelihood on polyphonic music of different inference networks trained on a DMM with a fixed structure (lower is better). The numbers inside parentheses are the variational bound.

Inference Network | JSB | Nottingham | Piano | Musedata
DKS (i.e., ST-R) | 6.605 (7.033) | 3.136 (3.327) | 8.471 (8.584) | 7.280 (7.136)
ST-L | 7.020 (7.519) | 3.446 (3.657) | 9.375 (9.498) | 8.301 (8.495)
ST-LR | 6.632 (7.078) | 3.251 (3.449) | 8.406 (8.529) | 7.127 (7.268)
MF-LR | 6.701 (7.101) | 3.273 (3.441) | 9.188 (9.297) | 8.760 (8.877)

Table 3: Evaluation against Baselines: Test negative log-likelihood (lower is better) on the Polyphonic Music Generation dataset. Table Legend: RNN (Boulanger-Lewandowski, Bengio, and Vincent 2012), LV-RNN (Gu, Ghahramani, and Turner 2015), STORN (Bayer and Osendorfer 2014), TSBN, HMSBN (Gan et al. 2015).

Methods | JSB | Nottingham | Piano | Musedata
DMM | 6.388 (6.926) {6.856} | 2.770 (2.964) {2.954} | 7.835 (7.980) {8.246} | 6.831 (6.989) {6.203}
DMM-Aug. | 6.288 (6.773) {6.692} | 2.679 (2.856) {2.872} | 7.591 (7.721) {8.025} | 6.356 (6.476) {5.766}
HMSBN | (8.0473) {7.9970} | (5.2354) {5.1231} | (9.563) {9.786} | (9.741) {8.9012}
STORN | 6.91 | 2.85 | 7.13 | 6.16
RNN | 8.71 | 4.46 | 8.37 | 8.13
TSBN | {7.48} | {3.67} | {7.98} | {6.81}
LV-RNN | 3.99 | 2.72 | 7.61 | 6.89

HMSBN shares the same structural properties as the DMM, but is learned using a simpler inference network. In Table 3, as we increase the complexity of the generative model, we obtain better results across all datasets. The DMM outperforms both the RNN and HMSBN everywhere, outperforms STORN on JSB and Nottingham, and outperforms TSBN on all datasets except Piano.
Compared to LV-RNN (which optimizes the inclusive KL-divergence), DMM-Aug obtains better results on all datasets except JSB. This showcases the ability of our flexible, structured inference network to learn powerful generative models that compare favourably to other state-of-the-art models. We provide audio files for samples from the learned DMM models in the code repository.

6.3 EHR Patient Data

Learning models from large observational health datasets is a promising approach to advancing precision medicine and could be used, for example, to understand which medications work best, for whom. In this section, we show how a DMM may be used for precisely such an application. Working with EHR data poses some technical challenges: EHR data are noisy, high-dimensional and difficult to characterize easily. Patient data is rarely contiguous over large parts of the dataset and is often missing (not at random). We learn a DMM on the data, showing how to handle these technical challenges, and use it for model-based counterfactual prediction.

Graphical Model: Fig. 5 represents the generative model we use when T = 4. The model captures the idea of an underlying time-evolving latent state for a patient ($z_t$) that is solely responsible for the diagnosis codes and lab values ($x_t$) we observe. In addition, the patient state is modulated by drugs ($u_t$) prescribed by the doctor. We may assume that the drugs prescribed at any point in time depend on the patient's entire medical history, though in practice the dotted edges in the Bayesian network never need to be modeled, since $x_t$ and $u_t$ are always assumed to be observed. A natural line of follow-up work would be to consider learning when $u_t$ is missing or latent. We make use of the time-varying (binary) drug prescriptions $u_t$ for each patient by augmenting the DMM with an additional edge at every time step.
Specifically, the DMM's transition function is now $z_t \sim \mathcal{N}(G_\alpha(z_{t-1}, u_{t-1}), S_\beta(z_{t-1}, u_{t-1}))$ (cf. Eq. 1). In our data, each $u_t$ is an indicator vector of eight anti-diabetic drugs including Metformin and Insulin, where Metformin is the most commonly prescribed first-line anti-diabetic drug.

[Figure 5: Bayesian network over latent states $z_1, \dots, z_4$, medications $u_1, u_2, u_3$ and observations $x_1, \dots, x_4$.]

Figure 5: DMM for Medical Data: The DMM (from Fig. 1) is augmented with external actions $u_t$ representing medications prescribed to the patient. $z_t$ is the latent state of the patient. $x_t$ are the observations that we model. Since both $u_t$ and $x_t$ are always assumed observed, the conditional distribution $p(u_t \mid x_1, \dots, x_{t-1})$ may be ignored during learning.

Emission & Transition Function: The choice of emission and transition function for such data is not well understood. In Fig. 6 (right), we experiment with variants of DMMs and find that using MLPs (rather than linear functions) in the emission and transition functions yields the best generative models in terms of held-out likelihood. In these experiments, the hidden dimension was set to 200 for the emission and transition functions. We used an RNN size of 400 and a latent dimension of size 50. We use DKS as our inference network for learning.

Learning with Missing Data: In the EHR dataset, a subset of the observations (such as A1C and glucose values, which are commonly used to assess blood-sugar levels in diabetics) is frequently missing. We marginalize them out during learning, which is straightforward within the probabilistic semantics of our Bayesian network. The sub-network of the original graph we are concerned with is the emission function, since missingness affects our ability to evaluate $\log p(x_t \mid z_t)$ (the first term in Eq. 6). The missing random variables are leaves in the Bayesian sub-network (comprised of the emission function).
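Because the missing variables are leaves of the emission sub-network, summing one of them out removes only its own factor from $p(x_t \mid z_t)$. For binary observations this can be checked numerically. The sketch below uses made-up Bernoulli probabilities (0.3 and 0.8) for two conditionally independent observations $m_t$ and $o_t$ at a fixed $z_t$; the helper names are illustrative.

```python
import math

def bernoulli_logpmf(x, p):
    """log p(x) for a Bernoulli(p) variable with x in {0, 1}."""
    return math.log(p) if x == 1 else math.log1p(-p)

def masked_log_lik(obs, probs, mask):
    """log p(x_t | z_t) for conditionally independent binary observations,
    summing only the dimensions whose mask entry is 1 (observed)."""
    return sum(bernoulli_logpmf(x, p) for x, p, keep in zip(obs, probs, mask) if keep)

# Hypothetical emission probabilities p(m_t = 1 | z_t) and p(o_t = 1 | z_t).
p_m, p_o = 0.3, 0.8
o_t = 1

# Explicit marginalization over the missing m_t: log sum_m p(m | z) p(o | z).
explicit = math.log(sum(
    math.exp(bernoulli_logpmf(m, p_m) + bernoulli_logpmf(o_t, p_o)) for m in (0, 1)
))
# Masking: the placeholder value stored for the missing m_t is never read.
masked = masked_log_lik([0, o_t], [p_m, p_o], [0, 1])
print(explicit, masked)  # both should be log p(o_t | z_t) = log(0.8) up to rounding
```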
Consider a simple example of modeling two observations at time t, namely $m_t$ and $o_t$. The log-likelihood of the data $(m_t, o_t)$ conditioned on the latent variable $z_t$ decomposes as $\log p(m_t, o_t \mid z_t) = \log p(m_t \mid z_t) + \log p(o_t \mid z_t)$, since the random variables are conditionally independent given their parent. If $m_t$ is missing and marginalized out while $o_t$ is observed, then our log-likelihood is
$$\log \int_{m_t} p(m_t, o_t \mid z_t) = \log \int_{m_t} p(m_t \mid z_t)\, p(o_t \mid z_t) = \log p(o_t \mid z_t),$$
since $\int_{m_t} p(m_t \mid z_t) = 1$; i.e., we effectively ignore the missing observations when estimating the log-likelihood of the data.

[Figure 6 panels: proportion of patients with high A1C and with high glucose vs. time, with and without medication; validation upper bound vs. epochs for T-[L]-E-[L], T-[NL]-E-[L], T-[L]-E-[NL] and T-[NL]-E-[NL].]

Figure 6: (Left Two Plots) Estimating Counterfactuals with DMM: The x-axis denotes the number of 3-month intervals after prescription of Metformin. The y-axis denotes the proportion of patients (out of a test set of size 800) who, after their first prescription of Metformin, experienced a high level of A1C. In each pair of bars at every time step, the left-aligned bar (green) represents the population that received diabetes medication, while the right-aligned bar (red) represents the population that did not receive diabetes medication. (Rightmost Plot) Upper bound on negative log-likelihood for different DMMs trained on the medical data. (T) denotes "transition", (E) denotes "emission", (L) denotes "linear" and (NL) denotes "non-linear".

The Effect of Anti-Diabetic Medications: Since our cohort comprises diabetic patients, we ask a counterfactual question: what would have happened to a patient had anti-diabetic drugs not been prescribed?
Specifically, we are interested in the patient's blood-sugar level as measured by the widely used A1C blood test. We perform inference using held-out patient data leading up to the time $k$ of the first prescription of Metformin. From the posterior mean, we perform ancestral sampling, tracking two latent trajectories: (1) the factual, where we sample new latent states conditioned on the medication $u_t$ the patient actually received, and (2) the counterfactual, where we sample conditioned on not receiving any drugs for all remaining time steps (i.e., $u_t$ set to the zero vector for $t \ge k$). We reconstruct the patient observations $x_k, \dots, x_T$, threshold the predicted A1C levels into high and low, and visualize the proportion of synthetic patients with high A1C levels in both scenarios. This is an example of performing do-calculus (Pearl 2009) in order to estimate model-based counterfactual effects.

The results are shown in Fig. 6. We see that the model learns that, on average, patients who were prescribed anti-diabetic medication had better-controlled A1C levels than patients who did not receive any medication. Although this is an aggregate effect, it is interesting because it coincides with our intuition yet was confirmed by the model in an entirely unsupervised manner. Note that in our dataset most diabetic patients are indeed prescribed anti-diabetic medications, which makes the counterfactual prediction harder. The ability of this model to answer such queries opens up possibilities for building personalized neural models of healthcare. Samples from the learned generative model and implementation details may be found in the supplement.

7 Discussion

We introduce a general algorithm for scalable learning in a rich family of latent variable models for time-series data.
The underlying methodological principle we propose is to build the inference network to mimic the posterior distribution (under the generative model). The space complexity of our learning algorithm depends neither on the sequence length T nor on the training set size N, offering massive savings compared to classical variational inference methods.

Here we propose and evaluate building variational inference networks to mimic the structure of the true posterior distribution. Other structured variational approximations are also possible. For example, one could instead use an RNN from the past, conditioned on a summary statistic of the future, during learning and inference.

Since we use RNNs only in the inference network, it should be possible to continue to increase their capacity and condition on different modalities that might be relevant to approximate posterior inference without worrying about overfitting the data. Furthermore, this makes it easy to model data in the presence of missing values, since the semantics of the DMM make it straightforward to marginalize out unobserved data. In contrast, in a (stochastic) RNN (bottom in Fig. 1) it is much more difficult to marginalize out unobserved data, due to the dependence of the intermediate hidden states on the previous input. Indeed, this allowed us to develop a principled application of the learning algorithm to modeling longitudinal patient data from EHRs and inferring treatment effects.

Acknowledgements

The Tesla K40s used for this research were donated by the NVIDIA Corporation. The authors gratefully acknowledge support by the DARPA Probabilistic Programming for Advancing Machine Learning (PPAML) Program under AFRL prime contract no. FA8750-14-C-0005, ONR #N00014-13-1-0646, a NSF CAREER award #1350965, and Independence Blue Cross. We thank David Albers, Kyunghyun Cho, Yacine Jernite, Eduardo Sontag and anonymous reviewers for their valuable feedback and comments.
References

Archer, E.; Park, I. M.; Buesing, L.; Cunningham, J.; and Paninski, L. 2015. Black box variational inference for state space models. arXiv preprint arXiv:1511.07367.
Bayer, J., and Osendorfer, C. 2014. Learning stochastic recurrent networks. arXiv preprint.
Boulanger-Lewandowski, N.; Bengio, Y.; and Vincent, P. 2012. Modeling temporal dependencies in high-dimensional sequences: Application to polyphonic music generation and transcription. In ICML 2012.
Briegel, T., and Tresp, V. 1999. Fisher scoring and a mixture of modes approach for approximate inference and learning in nonlinear state space models. In NIPS 1999.
Chung, J.; Gulcehre, C.; Cho, K.; and Bengio, Y. 2014. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555.
Chung, J.; Kastner, K.; Dinh, L.; Goel, K.; Courville, A.; and Bengio, Y. 2015. A recurrent latent variable model for sequential data. In NIPS 2015.
Duckworth, D. 2016. Kalman filter, Kalman smoother, and EM library for Python. https://pykalman.github.io/. Accessed: 2016-02-24.
Fabius, O., and van Amersfoort, J. R. 2014. Variational recurrent auto-encoders.
Fraccaro, M.; Sønderby, S. K.; Paquet, U.; and Winther, O. 2016. Sequential neural models with stochastic layers. In NIPS 2016.
Gan, Z.; Li, C.; Henao, R.; Carlson, D. E.; and Carin, L. 2015. Deep temporal sigmoid belief networks for sequence modeling. In NIPS 2015.
Gao, Y.; Archer, E.; Paninski, L.; and Cunningham, J. P. 2016. Linear dynamical neural population models through nonlinear embeddings. In NIPS 2016.
Ghahramani, Z., and Roweis, S. T. 1999. Learning nonlinear dynamical systems using an EM algorithm. In NIPS 1999.
Gregor, K.; Danihelka, I.; Graves, A.; Rezende, D. J.; and Wierstra, D. 2015. DRAW: A recurrent neural network for image generation. In ICML 2015.
Gu, S.; Ghahramani, Z.; and Turner, R. E. 2015. Neural adaptive sequential Monte Carlo. In NIPS 2015.
Hinton, G. E.; Dayan, P.; Frey, B. J.; and Neal, R. M. 1995. The "wake-sleep" algorithm for unsupervised neural networks. Science 268.
Johnson, M. J.; Duvenaud, D.; Wiltschko, A. B.; Datta, S. R.; and Adams, R. P. 2016. Structured VAEs: Composing probabilistic graphical models and variational autoencoders. In NIPS 2016.
Kaae Sønderby, C.; Raiko, T.; Maaløe, L.; Kaae Sønderby, S.; and Winther, O. 2016. How to Train Deep Variational Autoencoders and Probabilistic Ladder Networks. ArXiv e-prints.
Kingma, D., and Ba, J. 2015. Adam: A method for stochastic optimization. In ICLR 2015.
Kingma, D. P., and Welling, M. 2014. Auto-encoding variational Bayes. In ICLR 2014.
Larochelle, H., and Murray, I. 2011. The neural autoregressive distribution estimator. In AISTATS 2011.
Mnih, A., and Gregor, K. 2014. Neural variational inference and learning in belief networks. In ICML 2014.
Pearl, J. 2009. Causality. Cambridge University Press.
Raiko, T., and Tornio, M. 2009. Variational Bayesian learning of nonlinear hidden state-space models for model predictive control. Neurocomputing 72(16):3704–3712.
Raiko, T.; Tornio, M.; Honkela, A.; and Karhunen, J. 2006. State inference in variational Bayesian nonlinear state-space models. In International Conference on ICA and Signal Separation 2006.
Rezende, D. J., and Mohamed, S. 2015. Variational inference with normalizing flows. In ICML 2015.
Rezende, D. J.; Mohamed, S.; and Wierstra, D. 2014. Stochastic backpropagation and approximate inference in deep generative models. In ICML 2014.
Schön, T. B.; Wills, A.; and Ninness, B. 2011. System identification of nonlinear state-space models. Automatica 47(1):39–49.
Theano Development Team. 2016. Theano: A Python framework for fast computation of mathematical expressions. abs/1605.02688.
Tran, D.; Ranganath, R.; and Blei, D. M. 2016. The variational Gaussian process. In ICLR 2016.
Valpola, H., and Karhunen, J. 2002. An unsupervised ensemble learning method for nonlinear dynamic state-space models. Neural Computation 14(11):2647–2692.
Wan, E. A., and Nelson, A. T. 1996. Dual Kalman filtering methods for nonlinear prediction, smoothing and estimation. In NIPS 1996.
Wan, E.; Van Der Merwe, R.; et al. 2000. The unscented Kalman filter for nonlinear estimation. In AS-SPCC 2000.
Watter, M.; Springenberg, J. T.; Boedecker, J.; and Riedmiller, M. 2015. Embed to control: A locally linear latent dynamics model for control from raw images. In NIPS 2015.

Appendix

A Lower Bound on the Likelihood of Data

We can derive the bound on the likelihood $\mathcal{L}(\vec x; (\theta, \phi))$ as follows:
$$
\begin{aligned}
\log p_\theta(\vec x) &\ge \int_{\vec z} q_\phi(\vec z \mid \vec x) \log \frac{p_\theta(\vec z)\, p_\theta(\vec x \mid \vec z)}{q_\phi(\vec z \mid \vec x)}\, d\vec z \\
&= \mathbb{E}_{q_\phi(\vec z \mid \vec x)}\left[\log p_\theta(\vec x \mid \vec z)\right] - \mathrm{KL}\left(q_\phi(\vec z \mid \vec x) \,\|\, p_\theta(\vec z)\right) \\
&= \sum_{t=1}^{T} \mathbb{E}_{q_\phi(z_t \mid \vec x)}\left[\log p_\theta(x_t \mid z_t)\right] - \mathrm{KL}\left(q_\phi(\vec z \mid \vec x) \,\|\, p_\theta(\vec z)\right) \quad \text{(using } x_t \perp\!\!\!\perp x_{\neg t} \mid z_t\text{)} \qquad (7) \\
&= \mathcal{L}(\vec x; (\theta, \phi))
\end{aligned}
$$
In the following we omit the dependence of $q$ on $\vec x$, and omit the subscript $\phi$. We can show that the KL divergence between the approximation to the posterior and the prior simplifies as:
$$
\begin{aligned}
\mathrm{KL}(q(z_1, \dots, z_T) \,\|\, p(z_1, \dots, z_T))
&= \int_{z_1} \cdots \int_{z_T} q(z_1) \cdots q(z_T \mid z_{T-1}) \log \frac{q(z_1) \cdots q(z_T \mid z_{T-1})}{p(z_1, \dots, z_T)} \quad \text{(factorization of the variational distribution)} \\
&= \int_{z_1} \cdots \int_{z_T} q(z_1) \cdots q(z_T \mid z_{T-1}) \log \frac{q(z_1) \cdots q(z_T \mid z_{T-1})}{p(z_1)\, p(z_2 \mid z_1) \cdots p(z_T \mid z_{T-1})} \quad \text{(factorization of the prior)} \\
&= \int_{z_1} \cdots \int_{z_T} q(z_1) \cdots q(z_T \mid z_{T-1}) \log \frac{q(z_1)}{p(z_1)} + \sum_{t=2}^{T} \int_{z_1} \cdots \int_{z_T} q(z_1) \cdots q(z_T \mid z_{T-1}) \log \frac{q(z_t \mid z_{t-1})}{p(z_t \mid z_{t-1})} \\
&= \int_{z_1} q(z_1) \log \frac{q(z_1)}{p(z_1)} + \sum_{t=2}^{T} \int_{z_{t-1}} \int_{z_t} q(z_{t-1})\, q(z_t \mid z_{t-1}) \log \frac{q(z_t \mid z_{t-1})}{p(z_t \mid z_{t-1})} \quad \text{(each expectation over } z_{t'} \text{ is constant for } t' \notin \{t, t-1\}\text{)} \\
&= \mathrm{KL}(q(z_1) \,\|\, p(z_1)) + \sum_{t=2}^{T} \mathbb{E}_{q(z_{t-1})}\left[\mathrm{KL}(q(z_t \mid z_{t-1}) \,\|\, p(z_t \mid z_{t-1}))\right] \qquad (8)
\end{aligned}
$$
For evaluating the marginal likelihood on the test set, we can use the following Monte Carlo estimate:
$$p(\vec x) \approx \frac{1}{S} \sum_{s=1}^{S} \frac{p(\vec x \mid \vec z^{(s)})\, p(\vec z^{(s)})}{q(\vec z^{(s)} \mid \vec x)}, \qquad \vec z^{(s)} \sim q(\vec z \mid \vec x) \qquad (9)$$
This may be derived in a manner akin to the one depicted in Appendix E of (Rezende, Mohamed, and Wierstra 2014) or Appendix D of (Kingma and Welling 2014). The log-likelihood on the test set is computed using:
$$\log p(\vec x) \approx \log \frac{1}{S} \sum_{s=1}^{S} \exp \log \left[ \frac{p(\vec x \mid \vec z^{(s)})\, p(\vec z^{(s)})}{q(\vec z^{(s)} \mid \vec x)} \right] \qquad (10)$$
Eq. 10 may be computed in a numerically stable manner using the log-sum-exp trick.

B KL Divergence between Prior and Posterior

Maximum likelihood learning requires us to compute:
$$\mathrm{KL}(q(z_1, \dots, z_T) \,\|\, p(z_1, \dots, z_T)) = \mathrm{KL}(q(z_1) \,\|\, p(z_1)) + \sum_{t=2}^{T} \mathbb{E}_{q(z_{t-1})}\left[\mathrm{KL}(q(z_t \mid z_{t-1}) \,\|\, p(z_t \mid z_{t-1}))\right] \qquad (11)$$
The KL divergence between two multivariate Gaussians $q, p$ of dimension $D$ with respective means and covariances $\mu_q, \Sigma_q, \mu_p, \Sigma_p$ can be written as:
$$\mathrm{KL}(q \,\|\, p) = \frac{1}{2}\Big( \underbrace{\log \frac{|\Sigma_p|}{|\Sigma_q|}}_{(a)} - D + \underbrace{\mathrm{Tr}(\Sigma_p^{-1} \Sigma_q)}_{(b)} + \underbrace{(\mu_p - \mu_q)^\top \Sigma_p^{-1} (\mu_p - \mu_q)}_{(c)} \Big) \qquad (12)$$
The choice of Gaussian $q$ and $p$ is suggestive: using Eqs. 11 and 12, we can derive a closed form for the KL divergence between $q(z_1, \dots, z_T)$ and $p(z_1, \dots, z_T)$. $\mu_q, \Sigma_q$ are the outputs of the variational model.
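Eq. 12 reduces to a sum of per-dimension terms when both covariances are diagonal. The plain-Python sketch below assumes diagonal $\Sigma_q$ and $\Sigma_p$ passed as lists of variances and labels the three pieces (a), (b) and (c); `sequence_kl` adds them over time as in Eq. 11 along a single sampled path, so averaging it over several sampled paths would give a Monte Carlo estimate of the expectations. The function names are illustrative.

```python
import math

def kl_diag_gaussians(mu_q, var_q, mu_p, var_p):
    """KL(q || p) from Eq. 12 for Gaussians with diagonal covariances (given as variances)."""
    D = len(mu_q)
    term_a = sum(math.log(vp) - math.log(vq) for vq, vp in zip(var_q, var_p))   # (a) log |Sp| / |Sq|
    term_b = sum(vq / vp for vq, vp in zip(var_q, var_p))                       # (b) Tr(Sp^-1 Sq)
    term_c = sum((mp - mq) ** 2 / vp for mq, mp, vp in zip(mu_q, mu_p, var_p))  # (c) Mahalanobis term
    return 0.5 * (term_a - D + term_b + term_c)

def sequence_kl(q_params, p_params):
    """Eq. 11 along one sampled path: q_params[t] and p_params[t] are (mean, variance)
    lists for q(z_t | z_{t-1}) and p(z_t | z_{t-1}) evaluated at the same sampled z_{t-1}
    (for the first step, the unconditional q(z_1) and p(z_1))."""
    return sum(kl_diag_gaussians(mq, vq, mp, vp)
               for (mq, vq), (mp, vp) in zip(q_params, p_params))

print(kl_diag_gaussians([0.0], [1.0], [1.0], [2.0]))  # equals 0.5 * log(2) up to rounding
```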
Our functional form for $\mu_p, \Sigma_p$ is based on our generative model and can be summarized as:
$$\mu_{p1} = 0, \qquad \Sigma_{p1} = I, \qquad \mu_{pt} = G(z_{t-1}, u_{t-1}) = G_{t-1}, \qquad \Sigma_{pt} = \Delta\, \mathrm{diag}(\vec\sigma).$$
Here, $\Sigma_{pt}$ is assumed to be a learned diagonal matrix and $\Delta$ a scalar parameter.

Term (a): For $t = 1$, we have:
$$\log \frac{|\Sigma_{p1}|}{|\Sigma_{q1}|} = \log |\Sigma_{p1}| - \log |\Sigma_{q1}| = -\log |\Sigma_{q1}| \qquad (13)$$
For $t > 1$, we have:
$$\log \frac{|\Sigma_{pt}|}{|\Sigma_{qt}|} = \log |\Sigma_{pt}| - \log |\Sigma_{qt}| = D \log \Delta + \log |\mathrm{diag}(\vec\sigma)| - \log |\Sigma_{qt}| \qquad (14)$$

Term (b): For $t = 1$, we have:
$$\mathrm{Tr}(\Sigma_{p1}^{-1} \Sigma_{q1}) = \mathrm{Tr}(\Sigma_{q1}) \qquad (15)$$
For $t > 1$, we have:
$$\mathrm{Tr}(\Sigma_{pt}^{-1} \Sigma_{qt}) = \frac{1}{\Delta} \mathrm{Tr}(\mathrm{diag}(\vec\sigma)^{-1} \Sigma_{qt}) \qquad (16)$$

Term (c): For $t = 1$, we have:
$$(\mu_{p1} - \mu_{q1})^\top \Sigma_{p1}^{-1} (\mu_{p1} - \mu_{q1}) = \|\mu_{q1}\|^2 \qquad (17)$$
For $t > 1$, we have:
$$(\mu_{pt} - \mu_{qt})^\top \Sigma_{pt}^{-1} (\mu_{pt} - \mu_{qt}) = \frac{1}{\Delta} (G_{t-1} - \mu_{qt})^\top \mathrm{diag}(\vec\sigma)^{-1} (G_{t-1} - \mu_{qt}) \qquad (18)$$

Rewriting Eq. 11 using Eqs. 13, 14, 15, 16, 17 and 18, we get:
$$
\begin{aligned}
\mathrm{KL}(q(z_1, \dots, z_T) \,\|\, p(z_1, \dots, z_T)) = \frac{1}{2}\Big( & (T-1) D \log \Delta + (T-1) \log |\mathrm{diag}(\vec\sigma)| - \sum_{t=1}^{T} \log |\Sigma_{qt}| - TD \\
& + \mathrm{Tr}(\Sigma_{q1}) + \frac{1}{\Delta} \sum_{t=2}^{T} \mathrm{Tr}(\mathrm{diag}(\vec\sigma)^{-1} \Sigma_{qt}) + \|\mu_{q1}\|^2 \\
& + \frac{1}{\Delta} \sum_{t=2}^{T} \mathbb{E}_{z_{t-1}}\left[ (G_{t-1} - \mu_{qt})^\top \mathrm{diag}(\vec\sigma)^{-1} (G_{t-1} - \mu_{qt}) \right] \Big)
\end{aligned}
$$

C Polyphonic Music Generation

In the models we trained, the hidden dimension was set to 100 for the emission distribution and 200 for the transition function. We typically used RNN sizes from {400, 600} and a latent dimension of size 100.

Samples: Fig. 7 depicts mean probabilities of samples from the DMM trained on JSB Chorales (Boulanger-Lewandowski, Bengio, and Vincent 2012). MP3 songs corresponding to two different samples from the best DMM model in the main paper, learned on each of the four polyphonic datasets, may be found in the code repository.

Experiments with NADE: We also experimented with Neural Autoregressive Density Estimators (NADE) (Larochelle and Murray 2011) in the emission distribution for DMM-Aug, and denote the resulting model DMM-Aug-NADE.
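For reference, a NADE models a binary vector autoregressively, $p(x) = \prod_d p(x_d \mid x_{<d})$, sharing one hidden layer across all conditionals so that each step only adds one column's contribution to a running pre-activation. The plain-Python sketch below computes $\log p(x)$ for a single binary vector. In a DMM-Aug-NADE emission the biases would presumably be produced from $z_t$ (and $x_{t-1}$); here they are simply passed in, and that conditioning scheme, the function name and the random parameters in the usage example are assumptions, not the configuration used in the experiments.

```python
import itertools
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def nade_log_prob(x, W, V, b, c):
    """log p(x) under NADE for a binary vector x of length D.

    W: H x D input-to-hidden weights, V: D x H hidden-to-output weights,
    b: D output biases, c: H hidden biases."""
    H = len(c)
    a = list(c)  # running pre-activation c + W[:, :d] @ x[:d]
    log_p = 0.0
    for d, x_d in enumerate(x):
        h = [sigmoid(a_k) for a_k in a]
        p_d = sigmoid(b[d] + sum(V[d][k] * h[k] for k in range(H)))  # p(x_d = 1 | x_<d)
        log_p += math.log(p_d) if x_d else math.log1p(-p_d)
        for k in range(H):  # add x_d's contribution for the next conditional
            a[k] += W[k][d] * x_d
    return log_p

# Tiny random model (hypothetical sizes) to check that the probabilities sum to one.
rng = random.Random(0)
D, H = 4, 3
W = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(H)]
V = [[rng.gauss(0, 1) for _ in range(H)] for _ in range(D)]
b = [rng.gauss(0, 1) for _ in range(D)]
c = [rng.gauss(0, 1) for _ in range(H)]
total = sum(math.exp(nade_log_prob(list(x), W, V, b, c))
            for x in itertools.product((0, 1), repeat=D))
print(total)  # should be 1 up to floating-point error
```

Normalization holds by construction because every factor is a valid Bernoulli conditional, which is what lets a NADE emission capture correlations between notes within a time step while keeping the likelihood exact.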
In Table 4, we see that DMM-Aug-NADE performs comparably to the state-of-the-art RNN-NADE on JSB, Nottingham and Piano.

Table 4: Experiments with NADE Emission: Test negative log-likelihood (lower is better) on the Polyphonic Music Generation dataset. Table Legend: RNN-NADE (Boulanger-Lewandowski, Bengio, and Vincent 2012).

Methods | JSB | Nottingham | Piano | Musedata
DMM-Aug.-NADE | 5.118 (5.335) {5.264} | 2.305 (2.347) {2.364} | 7.048 (7.099) {7.361} | 6.049 (6.115) {5.247}
RNN-NADE | 5.19 | 2.31 | 7.05 | 5.60

[Figure 7 panels: (a) Sample 1 and (b) Sample 2; x-axis: time (0 to 200), y-axis: 0 to 88.]

Figure 7: Two samples from the DMM trained on JSB Chorales.

D Experimental Results on Synthetic Data

Experimental Setup: We used an RNN size of 40 in the inference networks used for the synthetic experiments.

Linear SSMs: Fig. 8 (N = 500, T = 25) depicts the performance of inference networks using the same setup as in the main paper, only now using held-out data to evaluate the RMSE and the upper bound. We find that the results echo those on the training set, and that on unseen data points the inference networks, particularly the structured ones, are capable of generalizing compiled inference.

[Figure 8 panels: Validate RMSE and Validate Upper Bound vs. epochs for ST-LR, MF-LR, ST-L, ST-R, MF-L and KF [Exact]; $z_t \sim \mathcal{N}(z_{t-1} + 0.05, 10)$, $x_t \sim \mathcal{N}(0.5 z_t, 20)$.]

Figure 8: Inference in a Linear SSM on Held-out Data: Performance of inference networks on held-out data using a generative model with linear emission and linear transition (same setup as the main paper).

[Figure 9 panels: (a) performance on training data and (b) performance on held-out data; RMSE and upper bound vs. epochs for MF-LR, ST-LR, ST-L, ST-R, MF-L and UKF; $z_t \sim \mathcal{N}(2 \sin(z_{t-1}) + z_{t-1}, 5)$, $x_t \sim \mathcal{N}(0.5 z_t, 5)$.]

Figure 9: Inference in a Non-linear SSM: Performance of inference networks trained with data from a linear emission and non-linear transition SSM.

Non-linear SSMs: Fig. 9 considers learning inference networks on a synthetic non-linear dynamical system (N = 5000, T = 25). We find once again that inference networks that match the posterior realize faster convergence and better training (and validation) performance.

[Figure 10 panels: latent space and observations for data points (1) and (2), comparing UKF and ST-R.]

Figure 10: Inference on Non-linear Synthetic Data: Visualizing inference on training data. Generative model: (a) linear emission and non-linear transition. $z^*$ denotes the latent variable that generated the observation. $x$ denotes the true data. We compare against the results obtained by a smoothed Unscented Kalman Filter (UKF) (Wan, Van Der Merwe, and others 2000). The column denoted "Observations" shows the result of applying the emission function of the respective generative model to the posterior estimates shown in the column "Latent Space". The shaded area surrounding each curve $\mu$ denotes $\mu \pm \sigma$.

Visualizing Inference: In Fig. 10 we visualize the posterior estimates obtained by the inference network. We run posterior inference on the training set 10 times and take the empirical expectation of the posterior means and covariances of each method.
We compare the posterior estimates with those obtained by a smoothed Unscented Kalman Filter (UKF) (Wan, Van Der Merwe, and others 2000).

E Generative Models of Medical Data

In this section, we describe some implementation details and visualize samples from the generative model trained on patient data.

Marginalizing out Missing Data: We describe the method we use to implement the marginalization operation. The main paper notes that marginalizing out observations in the DMM corresponds to ignoring absent observations during learning. We track indicators denoting whether A1C values and glucose values were observed in the data; these are used as markers of missingness. During batch learning, at every time step t, we obtain a matrix $B = \log p(x_t \mid z_t)$ of size batch-size × 48, where 48 is the dimensionality of the observations, comprising the log-likelihoods of every dimension for the patients in the batch. We multiply this elementwise with a mask matrix M. M has the same dimensions as B; in the A1C (respectively glucose) dimensions it has a 1 if the patient's A1C (respectively glucose) value was observed and a 0 otherwise. For dimensions that are never missing, M is always 1.

Sampling a Patient: We visualize samples from the DMM trained on medical data in Fig. 11. The model captures correlations within time steps as well as variations in A1C level and glucose level across time steps. It also captures rare occurrences of comorbidities found amongst diabetic patients.

[Figure 11 panels: two sampled patients over time steps 1 to 18; rows include binned A1C, glucose and age, gender, coverage, and comorbidity diagnosis codes such as diabetes variants, gout, obesity, anemia in chronic kidney disease, obstructive sleep apnea and hypertension.]

Figure 11: Generated Samples: Samples of a patient from the model, including the most important observations. The x-axis denotes time and the y-axis denotes the observations. The intensity of the color denotes its value between zero and one.
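The masking step described under Marginalizing out Missing Data reduces to an elementwise product followed by a sum over dimensions. A minimal plain-Python sketch, assuming the per-dimension log-likelihood matrix B has already been computed; the column positions used for A1C and glucose (the first 8 and the next 5 columns, matching the number of bins shown in Fig. 11) and the function names are hypothetical.

```python
def build_mask(a1c_observed, glucose_observed, dim=48,
               a1c_cols=range(0, 8), glucose_cols=range(8, 13)):
    """Mask M (batch-size x dim): 0 in the A1C / glucose columns of patients whose
    value is missing at this time step, 1 everywhere else (hypothetical column layout)."""
    mask = []
    for a1c_ok, gluc_ok in zip(a1c_observed, glucose_observed):
        row = [1.0] * dim
        if not a1c_ok:
            for j in a1c_cols:
                row[j] = 0.0
        if not gluc_ok:
            for j in glucose_cols:
                row[j] = 0.0
        mask.append(row)
    return mask

def masked_batch_log_lik(B, M):
    """Per-patient sum of the elementwise product B * M, i.e. log p of the observed entries."""
    return [sum(b * m for b, m in zip(b_row, m_row)) for b_row, m_row in zip(B, M)]

# Usage: three patients with a made-up log-likelihood of -1 in every dimension.
B = [[-1.0] * 48 for _ in range(3)]
M = build_mask(a1c_observed=[True, False, False], glucose_observed=[True, True, False])
print(masked_batch_log_lik(B, M))  # [-48.0, -40.0, -35.0]
```

Zeroing a column rather than dropping it keeps every row the same length, so the same batched computation can be used whether or not a value is missing.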
