Online Posterior Sampling with a Diffusion Prior
Posterior sampling in contextual bandits with a Gaussian prior can be implemented exactly or approximately using the Laplace approximation. The Gaussian prior is computationally efficient but it cannot describe complex distributions. In this work, we propose approximate posterior sampling algorithms for contextual bandits with a diffusion model prior. The key idea is to sample from a chain of approximate conditional posteriors, one for each stage of the reverse diffusion process, which are obtained by the Laplace approximation. Our approximations are motivated by posterior sampling with a Gaussian prior, and inherit its simplicity and efficiency. They are asymptotically consistent and perform well empirically on a variety of contextual bandit problems.
💡 Research Summary
This paper tackles the problem of posterior sampling in contextual multi‑armed bandits, where the traditional approach relies on a multivariate Gaussian prior. While Gaussian priors enable closed‑form updates or efficient Laplace approximations, they are fundamentally limited in expressive power and cannot capture multimodal or highly non‑linear structures that may be present in realistic parameter spaces. To overcome this limitation, the authors propose to replace the Gaussian prior with a diffusion model prior—a generative model that learns complex distributions by progressively adding noise to data (the forward diffusion) and then learning to reverse this process (the reverse diffusion).
The core contribution is an algorithmic framework called Laplace Diffusion Posterior Sampling (LaplaceDPS). The idea is to treat the diffusion model's reverse process as a Markov chain of latent variables (S_T, S_{T-1}, \dots, S_0), where (S_0) corresponds to the model parameter (\theta). Given a history of observations (h = \{(\phi_\ell, y_\ell)\}_{\ell=1}^N), the posterior (p(\theta \mid h)) can be expressed as the joint distribution of the chain conditioned on (h). By the Markov property, this joint factorizes into an initial distribution (p(S_T \mid h)) and a product of conditional distributions (p(S_{t-1} \mid S_t, h)) for each stage (t).
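In the summary's notation, this chain factorization can be written out as follows (a restatement of the text above, with (\theta = S_0)):

```latex
p(\theta \mid h)
  = \int p(S_T \mid h) \prod_{t=1}^{T} p(S_{t-1} \mid S_t, h)\, \mathrm{d}S_{1:T},
  \qquad \theta = S_0 .
```

Sampling the chain stage by stage, from (S_T) down to (S_0), therefore yields a draw from (p(\theta \mid h)).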
Direct computation of these conditionals is intractable because they involve integrals over the clean sample (S_0). The authors introduce a key approximation (Equation 6) that leverages the forward diffusion relationship (S_t = \sqrt{\bar\alpha_t}\, S_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon_t). By treating (S_0) as approximately (S_t / \sqrt{\bar\alpha_t}) (i.e., using the mean of the forward diffusion), the integral collapses and the likelihood becomes a function of the diffused variable (S_t). This approximation becomes increasingly accurate as the reverse process approaches the clean sample (small (t)).
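Schematically, the approximation replaces the intractable integral over the clean sample with a point evaluation at the forward-diffusion mean (a sketch of the idea in the summary's notation, not the paper's exact Equation 6):

```latex
p(h \mid S_t)
  = \int p(h \mid S_0 = s)\, p(S_0 = s \mid S_t)\, \mathrm{d}s
  \;\approx\; p\!\left(h \;\middle|\; S_0 = S_t / \sqrt{\bar\alpha_t}\right).
```

Since (\bar\alpha_t \to 1) as (t \to 0), the point evaluation matches the integral ever more closely near the end of the reverse process.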
With this approximation, the conditional posterior at each stage becomes a product of two Gaussians: one coming from the pre‑trained diffusion model (the “prior” term) and the other from the observed data (the “evidence” term). For linear reward models, the authors derive closed‑form expressions (Theorem 2) showing that both (p(S_T\mid h)) and each (p(S_{t-1}\mid S_t, h)) are Gaussian with analytically computable means and covariances that blend the diffusion prior’s parameters ((\mu_t, \Sigma_t)) with the empirical sufficient statistics (\bar\theta, \bar\Sigma) obtained from the data.
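The backbone of these closed-form updates is the standard identity for the product of two Gaussian densities: precisions add, and the combined mean is a precision-weighted average. A minimal NumPy sketch of this fusion step (the function name `gaussian_product` and the toy numbers are illustrative, not from the paper):

```python
import numpy as np

def gaussian_product(m1, C1, m2, C2):
    """Combine two Gaussian factors N(m1, C1) * N(m2, C2) into a single
    Gaussian N(m, C) up to normalization: precisions add, and the mean
    is the precision-weighted average of the two means."""
    P1, P2 = np.linalg.inv(C1), np.linalg.inv(C2)
    C = np.linalg.inv(P1 + P2)
    m = C @ (P1 @ m1 + P2 @ m2)
    return m, C

# Toy example in 2-D: a "prior" factor and a data-driven "evidence" factor.
m_prior, C_prior = np.zeros(2), np.eye(2)
m_data, C_data = np.array([1.0, -1.0]), 0.5 * np.eye(2)
m, C = gaussian_product(m_prior, C_prior, m_data, C_data)
```

In Theorem 2's setting, one factor plays the role of the diffusion prior term ((\mu_t, \Sigma_t)) and the other the evidence term built from (\bar\theta, \bar\Sigma); the same two-Gaussian fusion produces the stage-wise posterior mean and covariance.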
For generalized linear models (GLMs) where the reward function is non‑linear (e.g., logistic regression), the same structure holds after applying a Laplace approximation to the GLM likelihood. The MAP estimate (\hat\theta) and its Hessian are obtained via Iteratively Reweighted Least Squares (IRLS), and the resulting Gaussian approximation again combines with the diffusion prior in the same way (Theorem 4).
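For the logistic case, the Laplace step can be sketched with a plain Newton/IRLS loop that returns the MAP estimate and the Hessian at the MAP (the approximate posterior precision). This is a generic textbook implementation, assuming a Gaussian prior with precision `prior_prec`; the function name is illustrative:

```python
import numpy as np

def irls_logistic(Phi, y, prior_prec, n_iter=20):
    """Laplace approximation for Bayesian logistic regression.
    Returns the MAP estimate theta_hat and the Hessian of the negative
    log-posterior at theta_hat (used as the Gaussian approximation's
    precision matrix)."""
    d = Phi.shape[1]
    theta = np.zeros(d)
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-Phi @ theta))       # predicted means
        W = p * (1.0 - p)                             # GLM variance weights
        H = Phi.T @ (W[:, None] * Phi) + prior_prec   # Hessian (precision)
        g = Phi.T @ (p - y) + prior_prec @ theta      # gradient
        theta = theta - np.linalg.solve(H, g)         # Newton step
    return theta, H

# Toy example: 1-D logistic data with a unit-precision Gaussian prior.
Phi = np.array([[1.0], [1.0], [-1.0], [-1.0]])
y = np.array([1.0, 1.0, 0.0, 0.0])
theta_map, H_map = irls_logistic(Phi, y, np.eye(1))
```

The resulting Gaussian N(theta_map, inv(H_map)) then plays the role of the evidence term, fused with the diffusion prior exactly as in the linear case (Theorem 4).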
Theoretical analysis (Theorem 3) proves that the proposed approximations are asymptotically consistent: as the number of observations (N) grows, the conditional posteriors concentrate around the true parameter (\theta^\ast). The proof handles the dependence across the (T) stages of the chain and relies on the scaling property of the diffusion process. This consistency is notable because many prior diffusion‑based posterior samplers (e.g., score‑based methods) lack such guarantees.
Algorithmically, LaplaceDPS proceeds as follows: (1) sample (S_T) from a Gaussian that mixes the diffusion prior's marginal (N(0,I)) with the data‑driven Gaussian (N(\sqrt{\bar\alpha_T}\,\bar\theta, \bar\alpha_T \bar\Sigma)); (2) for each stage (t = T,\dots,1), sample (S_{t-1}) from a Gaussian whose mean is a weighted sum of the diffusion model's prediction (\mu_t(S_t)) and the scaled empirical mean (\bar\theta), and whose covariance similarly blends the diffusion covariance (\Sigma_t) with the scaled empirical covariance (\bar\Sigma). The final sample (S_0) is returned as a draw from the posterior over (\theta). The computational cost scales linearly with the number of diffusion steps (T); in practice, setting (T) between 10 and 20 yields a good trade‑off between accuracy and runtime (the authors report sub‑50 ms latency per round).
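The two steps above can be sketched in NumPy. This is a schematic sampler following the summary's description, not the paper's exact Theorem 2 weights: the blending is done by generic precision-weighted fusion, `mu_fn` is a stand-in for a trained diffusion model's reverse-process mean, and all names (`laplace_dps`, `fuse`) and the toy noise schedule are illustrative assumptions:

```python
import numpy as np

def fuse(m1, C1, m2, C2):
    """Precision-weighted fusion of two Gaussian factors (schematic)."""
    P1, P2 = np.linalg.inv(C1), np.linalg.inv(C2)
    C = np.linalg.inv(P1 + P2)
    return C @ (P1 @ m1 + P2 @ m2), C

def laplace_dps(mu_fn, Sigmas, alpha_bars, theta_bar, Sigma_bar, rng):
    """Schematic LaplaceDPS sampling loop.

    mu_fn(t, s)   -- reverse-process mean at stage t (diffusion model stand-in)
    Sigmas[t]     -- reverse-process covariance at stage t
    alpha_bars[t] -- cumulative noise schedule \bar{alpha}_t, alpha_bars[0] = 1
    theta_bar, Sigma_bar -- empirical statistics computed from the data
    """
    T = len(alpha_bars) - 1
    d = theta_bar.shape[0]
    # Stage T: mix the diffusion marginal N(0, I) with the scaled
    # data-driven Gaussian N(sqrt(abar_T) * theta_bar, abar_T * Sigma_bar).
    m, C = fuse(np.zeros(d), np.eye(d),
                np.sqrt(alpha_bars[T]) * theta_bar,
                alpha_bars[T] * Sigma_bar)
    s = rng.multivariate_normal(m, C)
    # Stages T..1: blend the model's prediction with scaled data statistics.
    for t in range(T, 0, -1):
        m, C = fuse(mu_fn(t, s), Sigmas[t],
                    np.sqrt(alpha_bars[t - 1]) * theta_bar,
                    alpha_bars[t - 1] * Sigma_bar)
        s = rng.multivariate_normal(m, C)
    return s  # a draw of S_0, i.e., of theta

# Toy run with a scaling stand-in for the denoiser (NOT a trained model).
rng = np.random.default_rng(0)
T, d = 5, 2
alpha_bars = np.linspace(1.0, 0.1, T + 1)
Sigmas = {t: 0.1 * np.eye(d) for t in range(1, T + 1)}
theta_bar, Sigma_bar = np.ones(d), 0.5 * np.eye(d)
mu_fn = lambda t, s: s * np.sqrt(alpha_bars[t - 1] / alpha_bars[t])
sample = laplace_dps(mu_fn, Sigmas, alpha_bars, theta_bar, Sigma_bar, rng)
```

The per-round cost is dominated by the (T) fusion-and-sample steps, which is the linear-in-(T) scaling noted above.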
Empirical evaluation covers several contextual bandit benchmarks, including linear, logistic, and mixed‑feature settings. The authors compare LaplaceDPS against (i) standard Thompson Sampling with a Gaussian prior, (ii) a score‑based diffusion posterior sampler that directly uses the likelihood score, and (iii) an oracle that knows the true prior. Results show that LaplaceDPS consistently achieves higher cumulative reward—often 5–15 % improvement over the Gaussian‑prior baseline—and maintains stable performance even when the true prior is multimodal, a scenario where the Gaussian baseline fails to explore effectively. Moreover, the method remains computationally feasible for online deployment.
In summary, the paper introduces a principled way to incorporate expressive diffusion model priors into Bayesian posterior sampling for contextual bandits. By marrying the diffusion reverse process with Laplace approximations of the likelihood, the authors obtain a tractable, asymptotically sound algorithm that outperforms traditional Gaussian‑prior methods on a range of tasks. The work opens avenues for richer priors in reinforcement learning, Bayesian optimization, and other sequential decision‑making problems where uncertainty quantification is crucial.