Tuning the Implicit Regularizer of Masked Diffusion Language Models: Enhancing Generalization via Insights from $k$-Parity
Masked Diffusion Language Models have recently emerged as a powerful generative paradigm, yet their generalization properties remain understudied compared to those of their auto-regressive counterparts. In this work, we investigate these properties within the setting of the $k$-parity problem (computing the XOR sum of $k$ relevant bits), where neural networks typically exhibit grokking – a prolonged plateau of chance-level performance followed by sudden generalization. We theoretically decompose the Masked Diffusion (MD) objective into a Signal regime, which drives feature learning, and a Noise regime, which serves as an implicit regularizer. By training nanoGPT with the MD objective on the $k$-parity problem, we demonstrate that the MD objective fundamentally alters the learning landscape, enabling rapid and simultaneous generalization without grokking. Furthermore, we leverage our theoretical insights to optimize the distribution of the mask probability in the MD objective. Our method significantly improves perplexity for 50M-parameter models and achieves superior results in both pre-training from scratch and supervised fine-tuning. Specifically, we observe performance gains peaking at $8.8\%$ and $5.8\%$, respectively, on 8B-parameter models, confirming the scalability and effectiveness of our framework in large-scale masked diffusion language model regimes.
💡 Research Summary
This paper investigates why Masked Diffusion Language Models (MDLMs) often generalize better than conventional autoregressive models, especially on algorithmic tasks that typically exhibit “grokking” – a long plateau of chance‑level performance followed by sudden generalization. The authors choose the k‑parity problem as a minimal yet challenging testbed. In k‑parity, a secret set of k bits determines the label as the XOR of those bits. The input bits and the label are concatenated into a single sequence, and a one‑layer Transformer (or equivalently a two‑layer MLP after ablation) is trained with a masked diffusion loss.
The loss is defined as the expected mean‑squared error over randomly masked tokens, where the mask probability $t$ is drawn uniformly from an interval.
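The setup above can be sketched in a few lines. The following is a minimal illustration, not the authors' implementation: the secret index set, the `[MASK]` placeholder value, and the per-sample interval endpoints `t_min`/`t_max` are hypothetical choices, and a stand-in `predict` callable replaces the trained model.

```python
import numpy as np

def parity_batch(batch, n, k, rng):
    """Sample random n-bit strings and append the XOR (parity) of a
    secret set of k bits as the final token of each sequence."""
    secret = np.arange(k)  # illustrative choice of the secret index set
    x = rng.integers(0, 2, size=(batch, n))
    y = x[:, secret].sum(axis=1) % 2
    return np.concatenate([x, y[:, None]], axis=1)  # shape (batch, n + 1)

def masked_diffusion_loss(seq, predict, t_min, t_max, rng):
    """Mean-squared error on randomly masked tokens, with the mask
    probability t drawn uniformly from [t_min, t_max] per sample
    (a sketch of the MD objective described in the summary)."""
    batch, length = seq.shape
    t = rng.uniform(t_min, t_max, size=(batch, 1))
    mask = rng.random((batch, length)) < t   # True = token is masked
    masked_in = np.where(mask, -1, seq)      # -1 plays the [MASK] token
    pred = predict(masked_in)                # model's per-token outputs
    err = (pred - seq) ** 2
    return (err * mask).sum() / max(mask.sum(), 1)  # average over masked only
```

With a dummy predictor that always outputs 0.5 on 0/1 tokens, every masked position contributes a squared error of 0.25, so the loss is 0.25 regardless of how many tokens are masked; a real model would instead be trained to drive this loss down by recovering masked bits from the visible ones.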