Diffusion-State Policy Optimization for Masked Diffusion Language Models
Masked diffusion language models generate by iteratively filling masked tokens over multiple denoising steps, so learning only from a terminal reward on the final completion yields coarse credit assignment over intermediate decisions. We propose DiSPO (Diffusion-State Policy Optimization), a plug-in credit-assignment layer that directly optimizes intermediate filling decisions. At selected intermediate masked states, DiSPO branches by resampling fillings for the currently masked positions from rollout-cached logits, scores the resulting completions, and updates only the newly filled tokens, without additional multi-step diffusion rollouts. We formalize a fixed-state objective for branched completions and derive a policy-gradient estimator that can be combined with terminal-feedback policy optimization using the same rollouts. On LLaDA-8B-Instruct, DiSPO consistently improves over the terminal-feedback diffu-GRPO baseline on math and planning benchmarks under matched rollout compute and optimizer steps. Our code will be available at https://daioba.github.io/dispo.
💡 Research Summary
This paper addresses a fundamental limitation of policy optimization for masked diffusion language models (MDLMs): the coarse credit assignment that arises when learning solely from a scalar reward evaluated on the final completion. In standard terminal‑feedback approaches such as diffu‑GRPO, the reward is broadcast uniformly across all decoding decisions, which obscures which intermediate mask‑filling actions actually contributed to success, especially in tasks that require multi‑step reasoning or planning.
The authors propose Diffusion‑State Policy Optimization (DiSPO), a plug‑in layer that introduces fine‑grained credit assignment at selected intermediate diffusion states. At a given denoising step t, the partially‑masked sequence xₖ,ₜ defines a state sₖ,ₜ = (prompt q, xₖ,ₜ). The set of currently masked positions Mₖ,ₜ constitutes the action space; the model’s per‑position logits define a factorized policy πθ(·|sₖ,ₜ). DiSPO branches from this state by resampling Z alternative fillings aₖ,ₜ,₁…aₖ,ₜ,𝑍 using the logits that were already cached during the original rollout. Each sampled action is turned into a deterministic completion oₖ,ₜ,𝑧 = FILL(xₖ,ₜ, aₖ,ₜ,𝑧), which replaces only the masked tokens while leaving the rest of the sequence untouched.
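The branching step described above can be illustrated with a minimal sketch. It is not the authors' implementation: the `MASK` sentinel, the function names, the dictionary layout of the cached logits, and the toy three-token vocabulary are all assumptions made for illustration. It only shows the idea of drawing Z fillings from already-cached per-position logits and applying a deterministic FILL that leaves unmasked tokens untouched.

```python
import math
import random

MASK = -1  # hypothetical sentinel id for a masked position


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_actions(x, cached_logits, num_branches, rng):
    """Draw num_branches fillings for the masked positions of x.

    cached_logits maps each masked position to its vocabulary logits,
    assumed to have been saved during the original rollout, so no new
    forward pass is needed here.
    """
    masked = [i for i, tok in enumerate(x) if tok == MASK]
    actions = []
    for _ in range(num_branches):
        action = {}
        for i in masked:
            # Factorized policy: each masked position is sampled independently.
            probs = softmax(cached_logits[i])
            action[i] = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        actions.append(action)
    return actions


def fill(x, action):
    """Deterministic FILL: overwrite masked positions only."""
    return [action[i] if tok == MASK else tok for i, tok in enumerate(x)]


# Toy state: positions 1 and 3 are masked, vocabulary size is 3.
x_t = [2, MASK, 0, MASK]
logits = {1: [0.0, 2.0, 0.5], 3: [1.0, 0.0, 3.0]}
branches = [fill(x_t, a) for a in sample_actions(x_t, logits, 4, random.Random(0))]
```

Each entry of `branches` is a full completion that differs from `x_t` only at the previously masked positions, which is what lets every branch be scored without another multi-step rollout.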
All completions are scored with the same terminal reward function R(q, o), yielding rewards Rₖ,ₜ,𝑧. A within‑group baseline R̄ₖ,ₜ = (1/Z)∑𝑧 Rₖ,ₜ,𝑧 is subtracted to form advantages Aₖ,ₜ,𝑧 = Rₖ,ₜ,𝑧 – R̄ₖ,ₜ. The policy gradient for the intermediate decision is then estimated with a standard likelihood‑ratio term: ρₖ,ₜ,𝑧(θ) = exp
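As a small worked example of the within-group baseline, the sketch below computes advantages for the Z branches of one intermediate state. The function name and the reward values are hypothetical; only the mean-subtraction rule comes from the summary.

```python
def group_advantages(rewards):
    """Subtract the within-group mean reward from each branch reward."""
    baseline = sum(rewards) / len(rewards)
    return [r - baseline for r in rewards]


# Four branches from one state, with made-up binary rewards.
advantages = group_advantages([1.0, 0.0, 0.0, 1.0])
# The baseline is 0.5, so advantages == [0.5, -0.5, -0.5, 0.5].
```

Because the baseline is the group mean, the advantages of a group always sum to zero, and a state whose branches all receive the same reward contributes zero advantage for every branch.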