Just on Time: Token-Level Early Stopping for Diffusion Language Models
Diffusion language models generate text through iterative refinement, a process that is often computationally inefficient because many tokens reach stability long before the final denoising step. We introduce a training-free, token-level early stopping approach that identifies convergence independently at each position. Our method leverages lightweight signals derived from the model’s predictions and local context to dynamically determine when individual tokens can be finalized. This yields adaptive per-token freezing without task-specific fine-tuning, substantially reducing the total number of diffusion steps required. Across diverse benchmarks, spanning mathematical reasoning, general question answering, and scientific understanding, our approach achieves state-of-the-art efficiency gains while preserving generation quality.
💡 Research Summary
Diffusion language models (DLMs) generate text by iteratively denoising a fully masked sequence, allowing parallel token prediction and bidirectional context. However, the reverse‑diffusion process typically requires many refinement steps (T), even though many tokens become stable far earlier than the final step. This inefficiency motivates the paper’s central contribution: JOT (Just on Time), a training‑free, token‑level early‑stopping mechanism that decides, at each diffusion step, whether an individual token’s prediction is confident enough to be finalized.
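The per-token freezing loop described above can be sketched as follows. This is a minimal illustration, not JOT's actual interface: `denoise_step`, the fixed `threshold`, and the freezing rule are placeholder assumptions standing in for the confidence signals detailed below.

```python
MASK = None  # placeholder mask token for this sketch

def generate(denoise_step, seq_len, max_steps, threshold):
    """Token-level early stopping in a reverse-diffusion loop (sketch).

    Each step, any still-masked position whose confidence clears the
    threshold is finalized ("frozen"); the loop exits as soon as every
    position is frozen, which may be far before max_steps.
    """
    tokens = [MASK] * seq_len
    frozen = [False] * seq_len
    steps_used = 0
    for _ in range(max_steps):
        if all(frozen):
            break  # every token converged early: stop denoising
        steps_used += 1
        preds, conf = denoise_step(tokens)  # per-position predictions + confidences
        for i in range(seq_len):
            if not frozen[i] and conf[i] >= threshold:
                tokens[i], frozen[i] = preds[i], True
    return tokens, steps_used
```

The key efficiency property is visible in the loop structure: positions are finalized independently, so the number of full denoising steps actually executed adapts to how quickly tokens stabilize rather than always running to `max_steps`.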
Core ideas
- Confidence metric – For each masked position i, the model outputs logits ℓ_i, which are transformed into a probability distribution π_i via a softmax without temperature scaling. The top‑1 probability p_i1 and the second‑best p_i2 are extracted, and a confidence score r_i = p_i1 / (p_i2 + ε) is computed. A high ratio indicates that the model strongly favors a single token, suggesting convergence.
- Spatial modulation – Tokens adjacent to already‑unmasked positions enjoy richer context, so the required confidence can be lowered. A geometric kernel with decay γ over a window of radius D yields a weight w_i = Σ_{|i−j|≤D} γ^{|i−j|} for each masked token i. Normalizing this weight gives a spatial factor ϕ_i ∈ [0, 1] that relaxes the confidence requirement for tokens surrounded by already‑finalized context.
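The two signals above can be sketched numerically. This is an illustration under stated assumptions, not the paper's implementation: the summary does not spell out the normalization constant or whether the kernel sums only over unmasked neighbors, so both choices here (sum over unmasked j ≠ i, divided by the all-neighbors-unmasked maximum) are assumptions.

```python
import numpy as np

def confidence_ratio(logits, eps=1e-8):
    """Top-1 / top-2 probability ratio r_i = p_i1 / (p_i2 + eps).

    logits: array of shape (num_masked, vocab_size).
    A high ratio means the model strongly favors a single token.
    """
    # Softmax without temperature scaling (shift by max for stability).
    z = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    top2 = np.sort(probs, axis=-1)[:, -2:]  # columns: [p_i2, p_i1]
    return top2[:, 1] / (top2[:, 0] + eps)

def spatial_factor(unmasked, positions, gamma=0.5, D=3):
    """Normalized geometric-kernel weight phi_i in [0, 1] per masked token.

    unmasked: boolean array over the full sequence.
    positions: indices of masked tokens to score.
    Assumption: only unmasked neighbors within radius D contribute.
    """
    w = np.zeros(len(positions))
    for k, i in enumerate(positions):
        lo, hi = max(0, i - D), min(len(unmasked), i + D + 1)
        for j in range(lo, hi):
            if unmasked[j]:
                w[k] += gamma ** abs(i - j)
    # Normalize by the maximum possible weight: every neighbor unmasked.
    w_max = 2 * sum(gamma ** d for d in range(1, D + 1))
    return w / w_max
```

A token flanked by finalized neighbors on both sides gets ϕ_i near 1, so under the scheme described above its confidence ratio r_i needs to clear a lower bar before the token is frozen.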