TABES: Trajectory-Aware Backward-on-Entropy Steering for Masked Diffusion Models
Masked Diffusion Models (MDMs) have emerged as a promising non-autoregressive paradigm for generative tasks, offering parallel decoding and bidirectional context utilization. However, current sampling methods rely on simple confidence-based heuristics that ignore the long-term impact of local decisions, leading to trajectory lock-in where early hallucinations cascade into global incoherence. While search-based methods mitigate this, they incur prohibitive computational costs ($O(K)$ forward passes per step). In this work, we propose Backward-on-Entropy (BoE) Steering, a gradient-guided inference framework that approximates infinite-horizon lookahead via a single backward pass. We formally derive the Token Influence Score (TIS) from a first-order expansion of the trajectory cost functional, proving that the gradient of future entropy with respect to input embeddings serves as an optimal control signal for minimizing uncertainty. To ensure scalability, we introduce \texttt{ActiveQueryAttention}, a sparse adjoint primitive that exploits the structure of the masking objective to reduce backward pass complexity. BoE achieves a superior Pareto frontier for inference-time scaling compared to existing unmasking methods, demonstrating that gradient-guided steering offers a mathematically principled and efficient path to robust non-autoregressive generation. We will release the code.
💡 Research Summary
Masked Diffusion Models (MDMs) have emerged as a powerful non‑autoregressive alternative for sequence generation, offering parallel decoding and bidirectional context. However, the prevailing sampling strategies (greedy confidence-, margin-, or entropy‑based heuristics) select the next tokens to unmask based solely on their current certainty. This myopic view leads to “trajectory lock‑in”: early unmasking of easy tokens yields little information gain, while structurally crucial tokens (e.g., variables in a reasoning chain, entity mentions in a story) remain masked until late, forcing the model into a globally inconsistent basin from which it is hard to recover. Existing remedies such as LookUM simulate multiple future trajectories, but they multiply the forward computation by a factor of K, eroding the latency advantage of MDMs.
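To make the myopic baseline concrete, here is a minimal numpy sketch of the entropy-based unmasking heuristic the paper critiques: at each step it unmasks the budgeted number of masked positions whose current predictive distributions are most certain, with no regard for future impact. The function names are ours, not from the paper.

```python
import numpy as np

def token_entropy(p, eps=1e-12):
    """Shannon entropy of categorical distributions along the last axis."""
    return -np.sum(p * np.log(p + eps), axis=-1)

def entropy_unmask_step(probs, masked, budget):
    """Greedy entropy heuristic: pick the `budget` masked positions whose
    predictive distributions are most certain (lowest entropy) right now.
    probs:  (L, V) per-position distributions
    masked: (L,) bool, True where the token is still masked
    budget: number of positions to unmask this step (assumed <= masked.sum())
    """
    h = token_entropy(probs)          # (L,) per-position entropy
    h = np.where(masked, h, np.inf)   # only masked positions are candidates
    order = np.argsort(h)             # most-certain first
    return order[:budget]             # indices to unmask this step
```

Because the score looks only one step at a time, a structurally pivotal but currently uncertain token is always deferred, which is exactly the lock-in failure mode described above.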
The authors recast MDM sampling as an entropy‑regularized optimal control problem. At decoding step t, the total masked entropy Hₜ(xₜ) = ∑_{i∈Mₜ} H(pᵢₜ) is taken as the state cost, where Mₜ is the set of masked positions and pᵢₜ is the model's predictive distribution at position i. The control action is the set Uₜ ⊆ Mₜ of masked positions to unmask (subject to a budget bₜ). The objective is to choose the Uₜ that minimizes the expected masked entropy at the next step, i.e., Uₜ* = argmin_{Uₜ} E[Hₜ₊₁(xₜ₊₁) | xₜ, Uₜ].
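The gradient signal behind this objective can be illustrated on a toy model. The sketch below (our construction, not the paper's implementation) uses a per-token linear-softmax "model", computes the state cost Hₜ = ∑_{i∈Mₜ} H(pᵢ), and returns the norm of its analytic gradient with respect to each input embedding as a TIS-like per-token influence magnitude. A real MDM would obtain this gradient with a single autograd backward pass through the full bidirectional network; the per-token toy model here keeps the math checkable by hand.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def token_influence_scores(E, W, masked):
    """TIS-like signal for a toy per-token model with logits z_i = W @ e_i.
    E: (L, D) input embeddings; W: (V, D) projection; masked: (L,) bool.
    Returns ||dH_t/de_i|| per token, where H_t sums entropy over masked
    positions. Note the simplification: in this toy model token i only
    influences position i, whereas a bidirectional MDM couples all positions."""
    Z = E @ W.T                        # (L, V) logits
    P = softmax(Z)
    logP = np.log(P + 1e-12)
    H = -(P * logP).sum(axis=-1)       # per-position entropy H(p_i)
    dH_dZ = -P * (logP + H[:, None])   # analytic dH(p_i)/dz_i for softmax
    dH_dZ[~masked] = 0.0               # unmasked positions do not enter H_t
    G = dH_dZ @ W                      # chain rule: dH_t/dE, shape (L, D)
    return np.linalg.norm(G, axis=-1)  # per-token influence magnitude
```

Tokens with a large gradient norm are those whose embeddings most strongly steer the remaining uncertainty, which is the property the control formulation asks the sampler to exploit.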