Deep Amortized Inference for Probabilistic Programs
Authors: Daniel Ritchie, Paul Horsfall, Noah D. Goodman (Stanford University)
Abstract

Probabilistic programming languages (PPLs) are a powerful modeling tool, able to represent any computable probability distribution. Unfortunately, probabilistic program inference is often intractable, and existing PPLs mostly rely on expensive, approximate sampling-based methods. To alleviate this problem, one could try to learn from past inferences, so that future inferences run faster. This strategy is known as amortized inference; it has recently been applied to Bayesian networks [28, 22] and deep generative models [20, 15, 24]. This paper proposes a system for amortized inference in PPLs. In our system, amortization comes in the form of a parameterized guide program. Guide programs have similar structure to the original program, but can have richer data flow, including neural network components. These networks can be optimized so that the guide approximately samples from the posterior distribution defined by the original program. We present a flexible interface for defining guide programs and a stochastic gradient-based scheme for optimizing guide parameters, as well as some preliminary results on automatically deriving guide programs. We explore in detail the common machine learning pattern in which a 'local' model is specified by 'global' random values and used to generate independent observed data points; this gives rise to amortized local inference supporting global model learning.

1 Introduction

Probabilistic models provide a framework for describing abstract prior knowledge and using it to reason under uncertainty. Probabilistic programs are a powerful tool for probabilistic modeling. A probabilistic programming language (PPL) is a deterministic programming language augmented with random sampling and Bayesian conditioning operators.
Performing inference on these programs then involves reasoning about the space of executions which satisfy some constraints, such as observed values. A universal PPL, one built on a Turing-complete language, can represent any computable probability distribution, including open-world models, Bayesian non-parametrics, and stochastic recursion [6, 19, 33].

If we consider a probabilistic program to define a distribution p(x, y), where x are (latent) intermediate variables and y are (observed) output data, then sampling from this distribution is easy: just run the program forward. However, computing the posterior distribution p(x | y) is hard, involving an intractable integral. Typically, PPLs provide means to approximate the posterior using Monte Carlo methods (e.g. MCMC, SMC), dynamic programming, or analytic computation. These inference methods are expensive because they (approximately) solve an intractable integral from scratch on every separate invocation. But many inference problems have shared structure: it is reasonable to expect that computing p(x | y_1) should give us some information about how to compute p(x | y_2). In fact, there is reason to believe that this is how people are able to perform certain inferences, such as visual perception, so quickly: we have perceived the world many times before, and can leverage that accumulated knowledge when presented with a new perception task [3]. This idea of using the results of previous inferences, or precomputation in general, to make later inferences more efficient is called amortized inference [3, 28].

Learning a generative model from many data points is a particularly important task that leads to many related inferences. One wishes to update global beliefs about the true generative model from individual data points (or batches of data points).
While many algorithms are possible for this task, they all require some form of 'parsing' for each data point: doing posterior inference in the current generative model to guess values of local latent variables given each observation. Because this local parsing inference is needed many times, it is a good candidate for amortization. It is plausible that learning to do local inference via amortization would support faster and better global learning, which gives more useful local inferences, leading to a virtuous cycle.

This paper proposes a system for amortized inference in PPLs, and applies it to model learning. Instead of computing p(x | y) from scratch for each y, our system instead constructs a program q(x | y) which takes y as input and, when run forward, produces samples distributed approximately according to the true posterior p(x | y). We call q a guide program, following terminology introduced in previous work [11]. The system can spend time up-front constructing a good approximation q so that at inference time, sampling from q is both fast and accurate.

There is a huge space of possible programs q one might consider for this task. Rather than posing the search for q as a general program induction problem (as was done in previous work [11]), we restrict q to have the same control flow as the original program p, but a different data flow. That is, q samples the same random choices as p and in the same order, but the data flowing into those choices comes from a different computation. In our system, we represent this computation using neural networks. This design choice reduces the search for q to the much simpler continuous problem of optimizing the weights for these networks, which can be done using stochastic gradient descent.
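To make the model/guide relationship concrete, here is a minimal, hypothetical sketch in plain JavaScript (not the system's actual WebPPL API): the model p samples a latent x and then an observation y; the guide q samples the same choice x, in the same order, but computes that choice's distribution parameters from the observed y. The names gaussian, model, guide, and phi are all illustrative, and a simple linear function stands in for the neural network.

```javascript
// Standard normal sampling via the Box-Muller transform.
function gaussian(mu, sigma) {
  const u1 = 1 - Math.random(); // in (0, 1], avoids log(0)
  const u2 = Math.random();
  return mu + sigma * Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
}

// Model: run forward to sample (x, y) from p.
function model() {
  const x = gaussian(0, 1);   // latent variable
  const y = gaussian(x, 0.5); // observed output
  return { x: x, y: y };
}

// Guide parameters phi (these would be trained by stochastic gradient descent).
const phi = { w: 0.8, b: 0.0, logSigma: Math.log(0.4) };

// Guide: takes y as input; the data flow differs from the model, but the
// same random choice (x) is sampled, in the same order.
function guide(y) {
  const mu = phi.w * y + phi.b; // stand-in for a neural network
  return gaussian(mu, Math.exp(phi.logSigma));
}
```

The key point of the restriction is visible here: the search over guides reduces to tuning the continuous values in phi, rather than searching over program structures.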
Our system's interface for specifying guide programs is flexible enough to subsume several popular recent approaches to variational inference, including those that perform both inference and model learning. To facilitate this common pattern we introduce the mapData construct, which represents the boundary between global "model" variables and variables local to the data points. Our system leverages the independence between data points implied by mapData to enable mini-batches of data and variance reduction of gradient estimates. We evaluate our proof-of-concept system on a variety of Bayesian networks, topic models, and deep generative models.

Our system has been implemented as an extension to the WebPPL probabilistic programming language [7]. Its source code can be found in the WebPPL repository, with additional helper code at https://github.com/probmods/webppl-daipp.

2 Background

2.1 Probabilistic Programming Basics

For our purposes, a probabilistic program defines a generative model p(x, y) of latent variables x and data y. The model factors as:

p(x, y) = p(y | x) ∏_i p(x_i | x_{<i})

[…]

    ... sigma: softplus(param({name: ...})) ...

where parameter bounding transforms such as softplus are applied based on bounds metadata provided with each primitive distribution type. We use reparameterizable guides for continuous distributions (see Appendix A). Since this process declares new optimizable parameters automatically, we must automatically generate names for these parameters. Our system names parameters according to where they are declared in the program execution trace, using the same naming technique as is used for random choices in probabilistic programming MCMC engines [31]. Since the names of these parameters are tied to the structure of the program, they cannot be re-used by other programs (as in the 'Further Optimization' example of Section 5.3).
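The trace-based naming scheme can be illustrated with a hypothetical sketch (not the system's actual implementation): a parameter declared at the same structural address always resolves to the same underlying value, so repeated executions of the same program point share one parameter. The names param, address, and paramStore are illustrative.

```javascript
// A global store mapping parameter names to values.
const paramStore = new Map();

// Declare (or look up) an optimizable parameter by name.
function param(name, initialValue) {
  if (!paramStore.has(name)) paramStore.set(name, initialValue);
  return paramStore.get(name);
}

// A structural address: the path of call sites leading to the declaration.
function address(callPath) {
  return callPath.join("/");
}

// Two executions reaching the same program point share one parameter...
const mu1 = param(address(["model", "sample:mu"]), 0.1);
const mu2 = param(address(["model", "sample:mu"]), 0.9); // init ignored: reused
// ...while a different program point gets a fresh one.
const sigma = param(address(["model", "sample:sigma"]), 1.0);
```

Because the address is derived from the program's structure, the same scheme also explains why these parameters cannot be shared across structurally different programs.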
7.2 Beyond Mean Field: Automatic Factored Guides with Recurrent Networks

In Section 6.2, we experimented with a factored guide program for the QMR-DT model. We think that this general style of guide, predicting each random choice in sequence, conditional on the hidden state of a recurrent neural network, might be generalized to an automatic guide for any program, as any probabilistic program can be decomposed into a sequence of random choices.

In our QMR-DT experiments, we used a separate neural network (with separate parameters) to predict each latent variable (i.e. random choice). For complex models and large data sets, this approach would lead to a computationally infeasible explosion in the number of parameters. Furthermore, it is likely that the prediction computations for many random choices in the program are related. For example, in the QMR-DT program, latent causes that share many dependent effects may be well predicted by the same or very similar networks.

Given these insights, we imagine a universally-applicable guide that uses a single prediction network for all random choices, but to which each random choice provides an additional identifying input. These IDs should be elements in a vector space, such that more 'similar' random choices have IDs which are close to one another for some distance metric in the vector space. One possible way to obtain such IDs would be to learn an embedding of the program-structural addresses of each random choice [31]. These might be learned in an end-to-end fashion by making them learnable parameter vectors in the overall variational optimization (i.e. letting closeness in the embedding space be an emergent property of optimizing our overall objective).

8 Conclusion

In this paper, we presented a system for amortized inference in probabilistic programs.
Amortization is achieved through parameterized guide programs which mirror the structure of the original program but can be trained to approximately sample from the posterior. We introduced an interface for specifying guide programs which is flexible enough to reproduce state-of-the-art variational inference methods. We also demonstrated how this interface supports model learning in addition to amortized inference. We developed and proved the correctness of an optimization method for training guide programs, and we evaluated its ability to optimize guides for Bayesian networks, topic models, and deep generative models.

8.1 Future Work

There are many exciting directions of future work to pursue in improving amortized inference for probabilistic programs. The system we have presented in this paper provides a platform from which to explore these and other possibilities:

More modeling paradigms: In this paper, we focused on the common machine learning modeling paradigm in which a global generative model generates many IID data points. There are many other modeling paradigms to consider. For example, time series data is common in machine learning applications. Just as we developed mapData to facilitate efficient inference in IID data models, we might develop an analogous data processing function for time series data (i.e. foldData). Using neural guides with such a setup would permit amortized inference in models such as Deep Kalman Filters [16]. In computer vision and computer graphics, a common paradigm for generative image models is to factor image generation into multiple steps and condition each step on the partially-generated image thus far [29, 25]. Such 'yield-so-far' models should also be possible to implement in our system.

Better gradient estimators: While the variance reduction strategies employed by our optimizer make inference with discrete variables tractable, it is still noticeably less efficient than with purely continuous models.
Fortunately, there are ongoing efforts to develop better, general-purpose discrete estimators for stochastic gradients [10, 21]. It should be possible to adapt these methods for probabilistic programs.

Automatic guides: As discussed in Section 7, we believe that automatically deriving guide programs using recurrent neural networks may soon be possible. Recent enhancements to recurrent networks may be necessary to make this a reality. For example, the external memory of the Neural Turing Machine may be better at capturing certain long-range posterior dependencies [8]. We might also draw inspiration from the Neural Programmer-Interpreter [23], whose stack of recurrent networks which communicate via arguments might better capture the posterior dataflow of arbitrary programs.

Other learning objectives: In this paper, we focused on optimizing the ELBo. If we flip the direction of KL divergence in Equation 2, the resulting functional is an upper bound on the log marginal likelihood of the data: an 'Evidence Upper Bound,' or EUBo. Computing the EUBo and its gradient requires samples from the true posterior and is thus unusable in many applications, where the entire goal of amortized inference is to find a way to tractably generate such samples. However, some applications can benefit from it, if the goal is to speed up an existing tractable inference algorithm (e.g. SMC [25]), or if posterior execution traces are available through some other means (e.g. input examples from the user). There may also be less extreme ways to exploit this idea for learning. For example, in a mapData-style program, we might interleave normal ELBo updates with steps that hallucinate data from the posterior predictive (using a guide for global model parameters) and train the local guide to correctly parse these 'dreamed-up' examples. Such a scheme bears resemblance to the wake-sleep algorithm [12].
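To see why the flipped divergence yields an upper bound, a one-line check (in LaTeX notation, using the same p, q, y as above) suffices; note that the expectation is now taken under the true posterior rather than under the guide:

```latex
\[
\mathrm{EUBo}(y)
  = \mathbb{E}_{p(x \mid y)}\left[ \log p(x, y) - \log q(x \mid y) \right]
  = \log p(y) + \mathrm{KL}\left( p(x \mid y) \,\|\, q(x \mid y) \right)
  \geq \log p(y)
\]
```

Non-negativity of the KL term gives the bound, and the expectation under p(x | y) makes explicit why computing the EUBo requires posterior samples.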
Control flow: While our system's one-to-one mapping between random choices in the guide and in the target program makes the definition and analysis of guides simple, there are scenarios in which more flexibility is useful. In some cases, one may want to insert random choices into the guide which do not occur in the target program (e.g. using a compound distribution, such as a mixture distribution, as a guide). And for models in which there is a natural hierarchy between the latent variables and the observed variables, having the guide run 'backwards' from the observed variables to the top-most latents has been shown to be useful [28, 22, 20]. It is worth exploring how to support these (and possibly even more general) control flow deviations in a general-purpose probabilistic programming inference system.

Acknowledgments

This material is based on research sponsored by DARPA under agreement number FA8750-14-2-0009. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright notation thereon. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of DARPA or the U.S. Government.

References

[1] Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation. In EMNLP 2014.
[2] M. C. Fu. Gradient Estimation. Handbooks in Operations Research and Management Science, 13, 2006.
[3] Sam Gershman and Noah D. Goodman. Amortized Inference in Probabilistic Reasoning. In Proceedings of the Thirty-Sixth Annual Conference of the Cognitive Science Society, 2014.
[4] P. Glasserman. Monte Carlo Methods in Financial Engineering. Springer Science & Business, 2003.
[5] P. W. Glynn.
Likelihood ratio gradient estimation for stochastic systems. Communications of the ACM, 33(10), 1990.
[6] Noah D. Goodman, Vikash K. Mansinghka, Daniel M. Roy, Keith Bonawitz, and Joshua B. Tenenbaum. Church: a language for generative models. In UAI 2008.
[7] Noah D. Goodman and Andreas Stuhlmüller. The Design and Implementation of Probabilistic Programming Languages. http://dippl.org, 2014. Accessed: 2015-12-23.
[8] Alex Graves, Greg Wayne, and Ivo Danihelka. Neural Turing Machines. CoRR, arXiv:1410.5401, 2014.
[9] E. Greensmith, P. L. Bartlett, and J. Baxter. Variance reduction techniques for gradient estimates in reinforcement learning. The Journal of Machine Learning Research, 5, 2004.
[10] Shixiang Gu, Sergey Levine, Ilya Sutskever, and Andriy Mnih. MuProp: Unbiased Backpropagation for Stochastic Neural Networks. In ICLR 2016.
[11] Georges Harik and Noam Shazeer. Variational Program Inference. CoRR, 2010.
[12] G. E. Hinton, P. Dayan, B. J. Frey, and R. M. Neal. The "wake-sleep" algorithm for unsupervised neural networks. Science, 268, 1995.
[13] M. Jordan, Z. Ghahramani, T. Jaakkola, and L. Saul. Introduction to variational methods for graphical models. Machine Learning, 37, 1999.
[14] Diederik P. Kingma and Jimmy Ba. Adam: A Method for Stochastic Optimization. In ICLR 2015.
[15] Diederik P. Kingma and Max Welling. Auto-Encoding Variational Bayes. In ICLR 2014.
[16] Rahul G. Krishnan, Uri Shalit, and David Sontag. Deep Kalman Filters. CoRR, arXiv:1511.05121, 2015.
[17] Alp Kucukelbir, Rajesh Ranganath, Andrew Gelman, and David M. Blei. Automatic Variational Inference in Stan. In NIPS 2015.
[18] J. Manning, R. Ranganath, K. Norman, and D. Blei. Black Box Variational Inference. In AISTATS 2014.
[19] Vikash K. Mansinghka, Daniel Selsam, and Yura N. Perov. Venture: a higher-order probabilistic programming platform with programmable inference. CoRR, arXiv:1404.0099, 2014.
[20] Andriy Mnih and Karol Gregor.
Neural Variational Inference and Learning in Belief Networks. In ICML 2014.
[21] Andriy Mnih and Danilo J. Rezende. Variational inference for Monte Carlo objectives. In ICML 2016.
[22] B. Paige and F. Wood. Inference Networks for Sequential Monte Carlo in Graphical Models. In ICML 2016.
[23] Scott Reed and Nando de Freitas. Neural Programmer-Interpreters. In ICLR 2016.
[24] Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic Backpropagation and Approximate Inference in Deep Generative Models. In ICML 2014.
[25] Daniel Ritchie, Anna Thomas, Pat Hanrahan, and Noah D. Goodman. Neurally-Guided Procedural Models: Amortized Inference for Procedural Graphics Programs using Neural Networks. In NIPS 2016.
[26] John Schulman, Nicolas Heess, Theophane Weber, and Pieter Abbeel. Gradient Estimation Using Stochastic Computation Graphs. In NIPS 2015.
[27] M. Shwe, B. Middleton, D. Heckerman, M. Henrion, E. Horvitz, H. Lehmann, and G. Cooper. Probabilistic diagnosis using a reformulation of the INTERNIST-1/QMR knowledge base. I. The probabilistic model and inference algorithms. Methods of Information in Medicine, 30.
[28] Andreas Stuhlmüller, Jessica Taylor, and Noah D. Goodman. Learning Stochastic Inverses. In NIPS 2013.
[29] Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex Graves, and Koray Kavukcuoglu. Conditional Image Generation with PixelCNN Decoders. CoRR, arXiv:1606.05328, 2016.
[30] Ronald J. Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8, 1992.
[31] David Wingate, Andreas Stuhlmüller, and Noah D. Goodman. Lightweight Implementations of Probabilistic Programming Languages Via Transformational Compilation. In AISTATS 2011.
[32] David Wingate and Theophane Weber. Automated Variational Inference in Probabilistic Programming. In NIPS 2012 Workshop on Probabilistic Programming.
[33] F. Wood, J. W.
van de Meent, and V. Mansinghka. A New Approach to Probabilistic Programming Inference. In AISTATS 2014.

A Appendix: Reparameterizations

Examples of primitive random choice distributions that can be reparameterized via a location-scale transform:

Distribution                    ε ∼ r(·)          g(ε)
Gaussian(μ, σ)                  Gaussian(0, 1)    μ + σ·ε
LogitNormal(μ, σ)               Gaussian(0, 1)    sigmoid(μ + σ·ε)
LogisticNormal(μ, σ)            Gaussian(0, 1)    simplex(μ + σ·ε)
InverseSoftplusNormal(μ, σ)     Gaussian(0, 1)    softplus(μ + σ·ε)
Exponential(λ)                  Uniform(0, 1)     −log(ε)/λ
Cauchy(x0, γ)                   Uniform(0, 1)     x0 + γ·tan(π·(ε − 0.5))

Examples of primitive distributions that do not have a location-scale transform but can be guided by a reparameterizable approximating distribution:

Distribution    Guide Distribution
Uniform         LogitNormal
Beta            LogitNormal
Gamma           InverseSoftplusNormal
Dirichlet       LogisticNormal

B Appendix: Gradient Estimator Derivations & Correctness Proofs

B.1 Derivation of Unified Gradient Estimator (Equation 5)

∇_φ L(y) = ∇_φ E_r[log p(g(ε), y) − log q(g(ε) | y)]
         = ∇_φ ∫ r(ε | y) (log p(g(ε), y) − log q(g(ε) | y)) dε
         = ∫ ∇_φ r(ε | y) (log p(g(ε), y) − log q(g(ε) | y)) + r(ε | y) ∇_φ (log p(g(ε), y) − log q(g(ε) | y)) dε
         = ∫ r(ε | y) ∇_φ log r(ε | y) (log p(g(ε), y) − log q(g(ε) | y)) + r(ε | y) ∇_φ (log p(g(ε), y) − log q(g(ε) | y)) dε   (7)
         = E_r[∇_φ log r(ε | y) (log p(g(ε), y) − log q(g(ε) | y)) + ∇_φ (log p(g(ε), y) − log q(g(ε) | y))]

Line 7 makes use of the identity ∇f(x) = f(x) ∇ log f(x).

B.2 Zero Expectation Identities

In what follows, we will make frequent use of the following:

Lemma 1. If f(x) is a probability distribution, then: E_f[∇ log f(x)] = 0

Proof.
E_f[∇ log f(x)] = ∫_x f(x) ∇ log f(x) = ∫_x ∇f(x) = ∇ ∫_x f(x) = ∇1 = 0

Lemma 2.
For a discrete random choice i and a function f(…)

[…]

        // Sum over topic assignments z.
        var prob = sum(mapN(function (z) {
          var zScore = Discrete({ps: topicDist}).score(z);
          var wgivenzScore = Discrete({ps: topics[z]}).score(word);
          return Math.exp(zScore + wgivenzScore);
        }, numTopics));

        factor(Math.log(prob) * count);
      }
    });
  });

  return topics;
};

Word-level guide:

var model = function (corpus, vocabSize, numTopics, alpha, eta) {

  var numHid = 50;
  var embedSize = 50;
  var embedNet = nn.mlp(vocabSize, [{nOut: embedSize, activation: nn.tanh}], 'embedNet');

  var net = nn.mlp(embedSize + numTopics, [
    {nOut: numHid, activation: nn.tanh},
    {nOut: numTopics}
  ], 'net');

  var wordAndTopicDistToParams = function (word, topicDist) {
    var embedding = nneval(embedNet, oneOfK(word, vocabSize));
    var out = nneval(net, T.concat(embedding, T.sub(topicDist, 1)));
    return {ps: softplus(tensorToVector(out))};
  };

  var topics = repeat(numTopics, function () {
    return sample(Dirichlet({alpha: eta}));
  });

  mapData({data: corpus}, function (doc) {

    var topicDist = sample(Dirichlet({alpha: alpha}));

    mapData({data: doc}, function (word) {
      var z = sample(Discrete({ps: topicDist}), {
        guide: Discrete(wordAndTopicDistToParams(word, topicDist))
      });
      var topic = topics[z];
      observe(Discrete({ps: topic}), word);
    });

  });

  return topics;
};

Document-level guide:

var nets = cache(function (numHid, vocabSize, numTopics) {
  var init = nn.constantparams([numHid], 'init');

  var ru = makeRNN(numHid, vocabSize, 'ru');

  var outputHidden = nn.mlp(numHid, [
    {nOut: numHid, activation: nn.tanh}
  ], 'outputHidden');

  var outputMu = nn.mlp(numHid, [
    {nOut: numTopics - 1}
  ], 'outputMu');

  var outputSigma = nn.mlp(numHid, [
    {nOut: numTopics - 1}
  ], 'outputSigma');

  return {
    init: init,
    ru: ru,
    outputHidden: outputHidden,
    outputMu: outputMu,
    outputSigma: outputSigma
  };
});

var model = function (data, vocabSize, numTopics, alpha, eta) {
  var corpus = data.documentsAsCounts;
  var numHid = 20;
  // Renamed from 'nets' to avoid shadowing the cached constructor above.
  var guideNets = nets(numHid, vocabSize, numTopics);

  var guideParams = function (topics, doc) {
    var initialState = nneval(guideNets.init);
    var state = reduce(function (x, prevState) {
      return nneval(guideNets.ru, [prevState, x]);
    }, initialState, topics.concat(normalize(Vector(doc))));
    var hidden = nneval(guideNets.outputHidden, state);
    var mu = tensorToVector(nneval(guideNets.outputMu, hidden));
    var sigma = tensorToVector(softplus(nneval(guideNets.outputSigma, hidden)));
    var params = {mu: mu, sigma: sigma};
    return params;
  };

  var topics = repeat(numTopics, function () {
    return sample(Dirichlet({alpha: eta}));
  });

  mapData({data: corpus}, function (doc) {

    var topicDist = sample(Dirichlet({alpha: alpha}), {
      guide: LogisticNormal(guideParams(topics, doc))
    });

    mapData({data: countsToIndices(doc)}, function (word) {
      var z = sample(Discrete({ps: topicDist}));
      var topic = topics[z];
      observe(Discrete({ps: topic}), word);
    });

  });

  return topics;
};

C.4 Variational Autoencoder

In this example and the one that follows, nnevalModel evaluates a neural network while also placing an improper uniform prior over the network parameters. This allows neural networks to be used as part of learnable models.
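The behavior of nnevalModel can be sketched hypothetically in plain JavaScript (this is not the actual WebPPL-daipp implementation): evaluate the network while scoring its parameters under an improper uniform prior, so the parameters act as latent model variables whose values are learned rather than fixed. The names makeLinearNet, improperUniformLogProb, and the returned record shape are all illustrative.

```javascript
// A toy "network": a parameter list plus a pure evaluation function.
function makeLinearNet(w, b) {
  return { params: [w, b], run: (x, ps) => ps[0] * x + ps[1] };
}

// Improper uniform prior: every parameter value has log-density 0,
// so it contributes nothing to the model's score.
function improperUniformLogProb(value) {
  return 0;
}

// Evaluate the network AND register its parameters as model variables.
function nnevalModel(net, x) {
  const priorScore = net.params
    .map(improperUniformLogProb)
    .reduce((a, b) => a + b, 0);
  return { output: net.run(x, net.params), score: priorScore };
}

const net = makeLinearNet(2, 1);
const result = nnevalModel(net, 3); // output: 2*3 + 1 = 7, score: 0
```

Because the prior is flat, the model score is unaffected by the parameter values; optimization of those values is driven entirely by the likelihood terms, which is what lets the decoder networks below be learned as part of the model.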
// Data
var data = loadData('mnist.json');
var dataDim = 28*28;
var hiddenDim = 500;
var latentDim = 20;

// Encoder
var encodeNet = nn.mlp(dataDim, [
  {nOut: hiddenDim, activation: nn.tanh}
], 'encodeNet');
var muNet = nn.linear(hiddenDim, latentDim, 'muNet');
var sigmaNet = nn.linear(hiddenDim, latentDim, 'sigmaNet');
var encode = function (image) {
  var h = nneval(encodeNet, image);
  return {
    mu: nneval(muNet, h),
    sigma: softplus(nneval(sigmaNet, h))
  };
};

// Decoder
var decodeNet = nn.mlp(latentDim, [
  {nOut: hiddenDim, activation: nn.tanh},
  {nOut: dataDim, activation: nn.sigmoid}
], 'decodeNet');
var decode = function (latent) {
  return nnevalModel(decodeNet, latent);
};

// Training model
var model = function () {
  mapData({data: data, batchSize: 100}, function (image) {
    // Sample latent code (guided by encoder)
    var latent = sample(TensorGaussian({mu: 0, sigma: 1, dims: [latentDim, 1]}), {
      guide: DiagCovGaussian(encode(image))
    });

    // Decode latent code, observe binary image
    var probs = decode(latent);
    observe(MultivariateBernoulli({ps: probs}), image);
  });
}

C.5 Sigmoid Belief Network

// Data
var data = loadData('mnist.json');
var dataDim = 28*28;
var latentDim = 200;

// Encoder
var encodeNet = nn.mlp(dataDim, [
  {nOut: latentDim, activation: nn.sigmoid}
], 'encodeNet');
var encode = function (image) {
  return nneval(encodeNet, image);
};

// Decoder
var decodeNet = nn.mlp(latentDim, [
  {nOut: dataDim, activation: nn.sigmoid}
], 'decodeNet');
var decode = function (latent) {
  return nnevalModel(decodeNet, latent);
};

// Training model
var priorProbs = Vector(repeat(latentDim, function () { return 0.5; }));
var model = function () {
  mapData({data: data, batchSize: 100}, function (image) {
    // Sample latent code (guided by encoder)
    var latent = sample(MultivariateBernoulli({ps:
priorProbs}), {
      guide: MultivariateBernoulli({ps: encode(image)})
    });

    // Decode latent code, observe binary image
    var probs = decode(latent);
    observe(MultivariateBernoulli({ps: probs}), image);
  });
}