Diffusion Alignment Beyond KL: Variance Minimisation as Effective Policy Optimiser
Diffusion alignment adapts pretrained diffusion models to sample from reward-tilted distributions along the denoising trajectory. This process naturally admits a Sequential Monte Carlo (SMC) interpretation, where the denoising model acts as a proposal and reward guidance induces importance weights. Motivated by this view, we introduce Variance Minimisation Policy Optimisation (VMPO), which formulates diffusion alignment as minimising the variance of log importance weights rather than directly optimising a Kullback-Leibler (KL) based objective. We prove that the variance objective is minimised by the reward-tilted target distribution and that, under on-policy sampling, its gradient coincides with that of standard KL-based alignment. This perspective offers a common lens for understanding diffusion alignment. Under different choices of potential functions and variance minimisation strategies, VMPO recovers various existing methods, while also suggesting new design directions beyond KL.
💡 Research Summary
The paper tackles the problem of diffusion alignment, i.e., adapting a pretrained diffusion model so that its samples are biased toward high‑reward regions of the data space. Existing works cast this as a KL‑regularised reinforcement‑learning problem, where the denoising policy is updated to minimise the KL divergence between the current policy and a reward‑tilted target distribution. While effective, the KL perspective does not directly address the sampling efficiency of the underlying Sequential Monte Carlo (SMC) process that naturally arises when the denoising model is viewed as a proposal and the reward guidance as inducing importance weights.
The authors reinterpret diffusion alignment through the lens of SMC. For each timestep t, the proposal is pθ(x_{t‑1}|x_t) and the unnormalised importance weight is w_t = p_ref(x_{t‑1}|x_t)·exp(r(x_{t‑1})/β) / pθ(x_{t‑1}|x_t). The target distribution is p_tilt(x_{t‑1}|x_t) ∝ p_ref·exp(r/β). Instead of directly minimising KL(pθ‖p_tilt), they propose to minimise the variance of the log‑importance weights, defining a loss L_h^Var(t;θ)=½Var_h(log w_t), where h is any reference distribution sharing the support of pθ and p_ref (typically set to pθ for on‑policy updates).
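The link between weight variance and particle degeneracy can be illustrated with a toy 1‑D sketch (these modelling choices are illustrative assumptions, not the paper's setup): p_ref = N(0, 1), pθ = N(μθ, 1), a linear reward r(x) = x with β = 1, so the tilted target p_ref·exp(r/β) is N(1, 1). When the proposal matches the tilted target, the log‑weights are constant and the effective sample size (ESS) equals the number of particles; a mismatched proposal inflates the weight variance and collapses the ESS.

```python
import numpy as np

def ess(log_w):
    """Effective sample size from unnormalised log importance weights."""
    lw = log_w - log_w.max()          # stabilise before exponentiating
    w = np.exp(lw)
    w /= w.sum()
    return 1.0 / np.sum(w ** 2)

def log_weight(x, mu_theta, beta=1.0):
    """log w = log p_ref(x) + r(x)/beta - log p_theta(x) for the toy model
    p_ref = N(0, 1), p_theta = N(mu_theta, 1), reward r(x) = x.
    Gaussian normalising constants cancel since both have unit variance."""
    log_p_ref = -0.5 * x ** 2
    log_p_theta = -0.5 * (x - mu_theta) ** 2
    return log_p_ref + x / beta - log_p_theta

rng = np.random.default_rng(0)
n = 10_000

# Proposal equal to the tilted target N(1, 1): log-weights are constant,
# so the normalised weights are uniform and ESS = n (no degeneracy).
x_good = rng.normal(1.0, 1.0, n)
ess_matched = ess(log_weight(x_good, 1.0))

# Mismatched proposal N(0, 1): positive weight variance, ESS drops sharply.
x_bad = rng.normal(0.0, 1.0, n)
ess_mismatched = ess(log_weight(x_bad, 0.0))

print(ess_matched, ess_mismatched)
```

In this toy model the matched proposal gives log w ≡ 1/2, a constant, which is exactly the zero‑variance optimum that VMPO's objective targets.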
Proposition 1 shows that the global optimum of this variance loss is exactly p_tilt, and that the gradient of L_h^Var with respect to θ evaluated at h=pθ coincides with the gradient of the KL objective. Consequently, under on‑policy sampling, the variance‑minimisation objective is mathematically equivalent to the standard KL‑regularised alignment, but it offers a more direct interpretation in terms of sampling efficiency: minimising variance reduces particle degeneracy in the SMC view.
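The gradient identity in Proposition 1 follows from a short score‑function calculation. A sketch (not the paper's full proof), using log w_t = log p_ref + r/β − log pθ, so that ∇θ log w_t = −∇θ log pθ, and holding h fixed during differentiation:

```latex
\begin{align*}
\nabla_\theta\, \tfrac{1}{2}\operatorname{Var}_h(\log w_t)
  &= \mathbb{E}_h\!\big[(\log w_t - \mathbb{E}_h[\log w_t])\,\nabla_\theta \log w_t\big]
   = -\,\mathbb{E}_h\!\big[(\log w_t - b)\,\nabla_\theta \log p_\theta\big],
   \qquad b = \mathbb{E}_h[\log w_t],\\[4pt]
\nabla_\theta\, \operatorname{KL}(p_\theta \,\|\, p_{\mathrm{tilt}})
  &= \mathbb{E}_{p_\theta}\!\big[(\log p_\theta - \log p_{\mathrm{tilt}})\,\nabla_\theta \log p_\theta\big]
   = -\,\mathbb{E}_{p_\theta}\!\big[(\log w_t - \log Z)\,\nabla_\theta \log p_\theta\big],
\end{align*}
```

where Z is the normalising constant of p_tilt. Because the score has zero mean, E_{pθ}[∇θ log pθ] = 0, the constants b and log Z drop out, and at h = pθ the two gradients coincide term by term.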
Two practical estimators are introduced. The first uses straightforward Monte‑Carlo samples to compute the empirical variance; despite the bias introduced by the quadratic term, its gradient remains an unbiased estimator of –∇θ J_KL. The second amortises the expectation E_h
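The first estimator can be checked numerically in a toy 1‑D setting (illustrative assumptions, not the paper's experiments): p_ref = N(0, 1), pθ = N(μ, 1), reward r(x) = x with β = 1, so p_tilt = N(1, 1) and the exact KL gradient with respect to μ is (μ − 1). Treating the on‑policy samples as fixed, the gradient of the empirical ½·Var(log w) matches the score‑function (REINFORCE) gradient of the KL objective up to Monte‑Carlo noise.

```python
import numpy as np

# Toy model: p_ref = N(0, 1), p_theta = N(mu, 1), r(x) = x, beta = 1,
# hence p_tilt = N(1, 1) and d/dmu KL(p_theta || p_tilt) = mu - 1 exactly.
rng = np.random.default_rng(0)
mu, n = 0.5, 200_000
x = rng.normal(mu, 1.0, n)                       # on-policy samples, held fixed

log_w = -0.5 * x**2 + x + 0.5 * (x - mu)**2      # log p_ref + r - log p_theta
score = x - mu                                    # d/dmu log p_theta(x)

# Gradient of the empirical 1/2 Var(log w); d(log w)/dmu = -score.
g_var = -np.mean((log_w - log_w.mean()) * score)

# REINFORCE gradient of KL(p_theta || p_tilt), no baseline.
g_kl = -np.mean(log_w * score)

print(g_var, g_kl, mu - 1.0)   # all three agree up to Monte-Carlo noise
```

The centring term log_w.mean() plays the role of a baseline; since the score has zero mean, it changes only the variance of the estimator, not its expectation, which is why the variance gradient is an unbiased estimate of the KL gradient.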