Deep Amortized Inference for Probabilistic Programs


Authors: Daniel Ritchie, Paul Horsfall, Noah D. Goodman

Daniel Ritchie, Paul Horsfall, Noah D. Goodman
Stanford University

Abstract

Probabilistic programming languages (PPLs) are a powerful modeling tool, able to represent any computable probability distribution. Unfortunately, probabilistic program inference is often intractable, and existing PPLs mostly rely on expensive, approximate sampling-based methods. To alleviate this problem, one could try to learn from past inferences, so that future inferences run faster. This strategy is known as amortized inference; it has recently been applied to Bayesian networks [28, 22] and deep generative models [20, 15, 24]. This paper proposes a system for amortized inference in PPLs. In our system, amortization comes in the form of a parameterized guide program. Guide programs have similar structure to the original program, but can have richer data flow, including neural network components. These networks can be optimized so that the guide approximately samples from the posterior distribution defined by the original program. We present a flexible interface for defining guide programs and a stochastic gradient-based scheme for optimizing guide parameters, as well as some preliminary results on automatically deriving guide programs. We explore in detail the common machine learning pattern in which a 'local' model is specified by 'global' random values and used to generate independent observed data points; this gives rise to amortized local inference supporting global model learning.

1 Introduction

Probabilistic models provide a framework for describing abstract prior knowledge and using it to reason under uncertainty. Probabilistic programs are a powerful tool for probabilistic modeling. A probabilistic programming language (PPL) is a deterministic programming language augmented with random sampling and Bayesian conditioning operators.
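To make the sampling-plus-conditioning idea concrete, here is a minimal sketch in plain Python (not WebPPL, and not this paper's system): a generative program with two random choices, where conditioning on an observation is implemented by the simplest possible semantics, rejection sampling. The model and its 2/3 posterior probability are a standard textbook example, not taken from the paper.

```python
import random

def model():
    # Random sampling: two independent fair coin flips.
    a = random.random() < 0.5
    b = random.random() < 0.5
    # Conditioning: reject executions that violate the observation
    # "at least one coin came up heads".
    if not (a or b):
        return None  # condition failed; this execution is discarded
    return a  # query: did the first coin come up heads?

def rejection_query(n=100000):
    random.seed(0)
    samples = [model() for _ in range(n)]
    accepted = [s for s in samples if s is not None]
    return sum(accepted) / len(accepted)

est = rejection_query()
print(est)  # ≈ 2/3, since P(a | a or b) = 0.5 / 0.75
```

Rejection sampling makes the conditioning semantics transparent, but it illustrates why inference is expensive: as observations become less probable under the prior, nearly all executions are rejected.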
Performing inference on these programs then involves reasoning about the space of executions which satisfy some constraints, such as observed values. A universal PPL, one built on a Turing-complete language, can represent any computable probability distribution, including open-world models, Bayesian non-parametrics, and stochastic recursion [6, 19, 33].

If we consider a probabilistic program to define a distribution p(x, y), where x are (latent) intermediate variables and y are (observed) output data, then sampling from this distribution is easy: just run the program forward. However, computing the posterior distribution p(x | y) is hard, involving an intractable integral. Typically, PPLs provide means to approximate the posterior using Monte Carlo methods (e.g. MCMC, SMC), dynamic programming, or analytic computation.

These inference methods are expensive because they (approximately) solve an intractable integral from scratch on every separate invocation. But many inference problems have shared structure: it is reasonable to expect that computing p(x | y_1) should give us some information about how to compute p(x | y_2). In fact, there is reason to believe that this is how people are able to perform certain inferences, such as visual perception, so quickly: we have perceived the world many times before, and can leverage that accumulated knowledge when presented with a new perception task [3]. This idea of using the results of previous inferences, or precomputation in general, to make later inferences more efficient is called amortized inference [3, 28].

Learning a generative model from many data points is a particularly important task that leads to many related inferences. One wishes to update global beliefs about the true generative model from individual data points (or batches of data points).
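The asymmetry between forward sampling and posterior inference can be seen in a toy conjugate-Gaussian model (our own illustration, not an example from the paper): running the program forward is one line per random choice, while even this simple posterior requires an integral, here approximated by self-normalized importance sampling with the prior as proposal.

```python
import random, math

def model():
    # Forward sampling is easy: just run the generative program.
    x = random.gauss(0.0, 1.0)   # latent x ~ Normal(0, 1)
    y = random.gauss(x, 0.5)     # observed y ~ Normal(x, 0.5)
    return x, y

def posterior_mean_importance(y_obs, n=200000):
    # Posterior inference is the hard direction: approximate p(x | y)
    # by importance sampling, weighting prior draws by the likelihood.
    def lik(x):  # unnormalized Normal(y_obs; x, 0.5) density
        return math.exp(-0.5 * ((y_obs - x) / 0.5) ** 2)
    num = den = 0.0
    for _ in range(n):
        x = random.gauss(0.0, 1.0)
        w = lik(x)
        num += w * x
        den += w
    return num / den

random.seed(0)
y_obs = 1.0
approx = posterior_mean_importance(y_obs)
# This model is conjugate, so the exact posterior mean is available:
# E[x | y] = y * prior_var / (prior_var + noise_var) = y * 1 / 1.25.
exact = y_obs * 1.0 / (1.0 + 0.25)
print(approx, exact)
```

Note that the sampler repeats all of this work from scratch for every new y_obs; nothing learned from inferring p(x | y_1) carries over to p(x | y_2), which is exactly the waste amortized inference targets.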
While many algorithms are possible for this task, they all require some form of 'parsing' for each data point: doing posterior inference in the current generative model to guess values of local latent variables given each observation. Because this local parsing inference is needed many times, it is a good candidate for amortization. It is plausible that learning to do local inference via amortization would support faster and better global learning, which gives more useful local inferences, leading to a virtuous cycle.

This paper proposes a system for amortized inference in PPLs, and applies it to model learning. Instead of computing p(x | y) from scratch for each y, our system instead constructs a program q(x | y) which takes y as input and, when run forward, produces samples distributed approximately according to the true posterior p(x | y). We call q a guide program, following terminology introduced in previous work [11]. The system can spend time up-front constructing a good approximation q so that at inference time, sampling from q is both fast and accurate.

There is a huge space of possible programs q one might consider for this task. Rather than posing the search for q as a general program induction problem (as was done in previous work [11]), we restrict q to have the same control flow as the original program p, but a different data flow. That is, q samples the same random choices as p and in the same order, but the data flowing into those choices comes from a different computation. In our system, we represent this computation using neural networks. This design choice reduces the search for q to the much simpler continuous problem of optimizing the weights for these networks, which can be done using stochastic gradient descent.
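The following sketch illustrates the guide-program idea on the same conjugate-Gaussian toy model, in plain Python rather than the paper's WebPPL system. The "network" mapping y to the guide's parameters is deliberately reduced to a linear map q(x | y) = Normal(a*y + b, s) (this family happens to contain the exact posterior), and the parameters a, b, log_s are trained by stochastic gradient ascent on a reparameterized ELBO estimate. All names and the learning-rate schedule are our own choices, not the paper's.

```python
import random, math

# Hypothetical linear "recognition network": q(x | y) = Normal(a*y + b, s).
# True model: x ~ Normal(0, 1), y ~ Normal(x, 0.5), so the exact posterior
# is Normal(0.8 * y, sqrt(0.2)); training should recover a ≈ 0.8, b ≈ 0.
a, b, log_s = 0.0, 0.0, 0.0
random.seed(1)

for step in range(20000):
    lr = 0.01 if step < 10000 else 0.001  # crude step-size decay
    # Draw a training "observation" by running the model forward.
    x_true = random.gauss(0.0, 1.0)
    y = random.gauss(x_true, 0.5)
    # Reparameterized sample from the guide: x = mu(y) + s * eps.
    s = math.exp(log_s)
    eps = random.gauss(0.0, 1.0)
    x = a * y + b + s * eps
    # d/dx [log p(x) + log p(y | x)] for this Gaussian model.
    dlogp_dx = -x + (y - x) / 0.25
    # Stochastic ELBO gradients; the guide's entropy term contributes
    # the analytic +1 to the log_s gradient.
    g_mu = dlogp_dx
    g_log_s = dlogp_dx * s * eps + 1.0
    a += lr * g_mu * y
    b += lr * g_mu
    log_s += lr * g_log_s

print(round(a, 2), round(b, 2), round(math.exp(log_s) ** 2, 2))
```

After training, sampling an approximate posterior draw for any new y is a single forward pass through the guide, which is the amortization payoff: the up-front optimization cost is shared across all future inferences.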
Our system's interface for specifying guide programs is flexible enough to subsume several popular recent approaches to variational inference, including those that perform both inference and model learning. To facilitate this common pattern we introduce the mapData construct which represents the boundary between global "model" variables and variables local to the data points. Our system leverages the independence between data points implied by mapData to enable mini-batches of data and variance reduction of gradient estimates.

We evaluate our proof-of-concept system on a variety of Bayesian networks, topic models, and deep generative models. Our system has been implemented as an extension to the WebPPL probabilistic programming language [7]. Its source code can be found in the WebPPL repository, with additional helper code at https://github.com/probmods/webppl-daipp.

2 Background

2.1 Probabilistic Programming Basics

For our purposes, a probabilistic program defines a generative model p(x, y) of latent variables x and data y. The model factors as:

p(x, y) = p(y | x) ∏_i p(x_i | x