Taming the Tail: Stable LLM Reinforcement Learning via Dynamic Vocabulary Pruning

Reading time: 11 minutes

📝 Original Info

  • Title: Taming the Tail: Stable LLM Reinforcement Learning via Dynamic Vocabulary Pruning
  • ArXiv ID: 2512.23087
  • Date: 2025-12-28
  • Authors: Yingru Li, Jiawei Xu, Jiacai Liu, Yuxuan Tong, Ziniu Li, Tianle Cai, Ge Zhang, Qian Liu, Baoxiang Wang

📝 Abstract

Reinforcement learning for large language models (LLMs) faces a fundamental tension: high-throughput inference engines and numerically precise training systems produce different probability distributions from the same parameters, creating a training-inference mismatch. We prove this mismatch has an asymmetric effect: the bound on log-probability mismatch scales as (1 - p), where p is the token probability. For high-probability tokens, this bound vanishes, contributing negligibly to sequence-level mismatch. For low-probability tokens in the tail, the bound remains large, and moreover, when sampled, these tokens exhibit systematically biased mismatches that accumulate over sequences, destabilizing gradient estimation. Rather than applying post-hoc corrections, we propose constraining the RL objective to a dynamically-pruned "safe" vocabulary that excludes the extreme tail. By pruning such tokens, we trade large, systematically biased mismatches for a small, bounded optimization bias. Empirically, our method achieves stable training; theoretically, we bound the optimization bias introduced by vocabulary pruning.

📄 Full Content

1 Introduction

Reinforcement learning has emerged as a key technique for training large language models on complex reasoning and multi-turn agentic tasks, where outcome-based rewards provide the primary learning signal. However, applying RL to LLMs at scale faces a critical computational bottleneck: rollout generation. Producing the large number of sample trajectories needed to estimate policy gradients requires high throughput. Modern inference engines (e.g., vLLM [4], SGLang [8]) achieve this through aggressive optimizations including paged attention, low-precision KV-cache (INT8/FP8), and fused CUDA kernels, all designed to maximize tokens per second.

Meanwhile, training systems (e.g., FSDP, Megatron-LM) must prioritize numerical stability and gradient precision, typically operating at higher precision (FP32 or mixed precision with careful accumulation). This creates a training-inference mismatch: the inference policy π_θ^infer used to sample trajectories differs subtly but systematically from the training policy π_θ^train used to compute gradients.

One might propose enforcing identical computations across both systems. However, this defeats the purpose of using a high-speed inference engine, potentially reducing throughput by orders of magnitude. The speed-versus-consistency tradeoff appears fundamental: as inference engines become faster through more aggressive optimization, this gap will only widen. Training instability is therefore not a transient implementation bug but a persistent challenge inherent to the modern LLM-RL stack.

This work takes a principled stance: we view training instability not as a technical bug requiring reactive correction, but as a symptom of a poorly-specified learning objective. Specifically, any objective requiring accurate gradient estimation over the extremely low-probability tail of a 100,000+ token vocabulary is fragile because the log-probability mismatch bound does not shrink for low-probability tokens as it does for high-probability tokens.

We propose a different approach: redesign the learning objective itself to operate only over a dynamically-pruned "safe" vocabulary at each generation step. This achieves stability by excluding the problematic tail from the objective, rather than applying reactive patches or ad-hoc clipping after the fact.

Our work makes three primary contributions:

  1. Rigorous diagnosis: We characterize the mathematical structure of instability in off-policy gradient estimation for LLMs. We prove that vulnerability is asymmetric: the log-probability mismatch bound scales as (1 - p), vanishing for high-probability tokens but remaining significant for low-probability tokens. Moreover, we show that sampled low-probability tokens have systematically biased mismatches that accumulate over sequences (Section 3).

  2. We propose dynamic vocabulary pruning using min-p filtering [6] to define constrained policies. This addresses the source of instability at the objective level rather than through post-hoc corrections (Section 4).

  3. We demonstrate that our method achieves stable training and significant performance improvements on mathematical reasoning tasks (Section 5).

We formalize autoregressive text generation as a Markov Decision Process (MDP).

Definition 2.1 (LLM Generation MDP). The generation process is defined by:

• State: s_t = [x; y_1, . . . , y_{t-1}] is the prompt x concatenated with the tokens generated so far.

• Action: a ∈ V is a token from the vocabulary.

• Policy: π_θ(a|s_t) is the LLM's next-token distribution.

• Trajectory: y = (y_1, . . . , y_T) is a complete generation.

• Reward: R(x, y) ∈ {0, 1} is typically the correctness of the solution, applicable to both single-turn reasoning and multi-turn agentic tasks.

Standard RL objective. The goal is to maximize expected reward:

J(θ) = E_{y ∼ π_θ(·|x)} [R(x, y)].

Policy gradient. The gradient is:

∇_θ J(θ) = E_{y ∼ π_θ(·|x)} [R(x, y) ∇_θ log π_θ(y|x)],

where log π_θ(y|x) = Σ_{t=1}^{T} log π_θ(y_t|s_t) by the chain rule. The challenge arises when we must sample from one policy but compute gradients with respect to another: the off-policy setting underlying the training-inference mismatch.
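To make the estimator concrete, here is a minimal sketch of the score-function (REINFORCE) gradient for a single sampled response. The `model` interface (prefix token ids in, next-token logits out) and all names are illustrative assumptions, not the paper's training code.

```python
import torch
import torch.nn.functional as F

def policy_gradient_loss(model, prompt_ids, response_ids, reward):
    """Return -R(x, y) * sum_t log pi_theta(y_t | s_t).

    Its gradient is the (negated) per-sample REINFORCE estimate of
    grad J(theta) when (x, y) is sampled on-policy. `model` is assumed
    to map a 1-D tensor of prefix token ids to next-token logits.
    """
    log_prob_sum = torch.zeros(())
    context = prompt_ids
    for token in response_ids:
        logits = model(context)                    # logits over the vocabulary
        log_probs = F.log_softmax(logits, dim=-1)  # log pi_theta(. | s_t)
        log_prob_sum = log_prob_sum + log_probs[token]
        context = torch.cat([context, token.view(1)])
    return -reward * log_prob_sum
```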

Training instability in LLM reinforcement learning has deep mathematical roots. We diagnose this instability by analyzing the training-inference mismatch scenario.

Off-policy estimation must account for the ratio between the two policies' sequence probabilities. For problematic trajectories this ratio can explode, while its reciprocal, the inverse ratio, becomes vanishingly small for the same trajectories, causing high variance. Either way, large probability ratios destabilize gradient estimation.

To analyze which tokens are vulnerable, we must understand how the training-inference mismatch manifests at the token level. Even with identical parameters θ, the two systems produce different logits due to a fundamental property of floating-point arithmetic: non-associativity. We have (a ⊕ b) ⊕ c ≠ a ⊕ (b ⊕ c) in finite precision, so different computation orders yield different results [3].
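As a quick sanity check of the non-associativity claim, the following snippet (with arbitrary float32 values) shows two summation orders disagreeing:

```python
import numpy as np

# Floating-point addition is not associative: different reduction orders
# can give different results from the same inputs.
a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)

left = (a + b) + c   # = 1.0
right = a + (b + c)  # = 0.0, because b + c rounds back to -1e8
print(left, right, left == right)
```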

In practice, inference engines (vLLM [4], SGLang [8]) and training frameworks (Megatron-LM, FSDP) differ in multiple ways: (1) attention implementations: PagedAttention [4] vs. FlashAttention-2 [2] use different reduction orders for the softmax denominator; (2) numerical precision: FP8/INT8 KV-cache quantization vs. BF16/FP32 accumulation; (3) operator fusion: different kernel boundaries change intermediate rounding. We model the aggregate effect as a logit perturbation

z^infer = z^train + ε,

where ε represents the perturbation vector over the vocabulary. Since these numerical errors arise from bounded-precision arithmetic, the perturbations satisfy |ε_k| ≤ ϵ_max for some small ϵ_max.

With this perturbation model, we can now characterize which tokens are most vulnerable. Crucially, vulnerability is not uniform across the vocabulary.

Proposition 3.2. Under the logit perturbation model z^infer = z^train + ε with |ε_k| ≤ ϵ_max, the token-level log-probability mismatch satisfies

|∆_a| = |log π_θ^train(a|s) - log π_θ^infer(a|s)| ≤ 2 ϵ_max (1 - p_a),

where p_a = π_θ^train(a|s) is the training policy probability.

Proof in Appendix A.2. This reveals the asymmetric structure: high-probability tokens (p_a → 1) have (1 - p_a) → 0, so the bound vanishes; low-probability tokens (p_a → 0) have (1 - p_a) ≈ 1, so the bound remains at its maximum value 2ϵ_max.
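The asymmetry is easy to see numerically. The sketch below perturbs synthetic logits by bounded noise and compares the per-token log-probability shift against 2ϵ_max(1 - p_a) for the most and least probable tokens; the vocabulary size, logit scale, and ϵ_max are arbitrary choices for illustration only.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

rng = np.random.default_rng(0)
V, eps_max = 1000, 1e-3
z_train = rng.normal(scale=5.0, size=V)        # synthetic "training" logits
eps = rng.uniform(-eps_max, eps_max, size=V)   # bounded perturbation |eps_k| <= eps_max
z_infer = z_train + eps                        # perturbed "inference" logits

p = np.exp(log_softmax(z_train))
delta = np.abs(log_softmax(z_infer) - log_softmax(z_train))
bound = 2 * eps_max * (1 - p)                  # bound from Proposition 3.2

top, tail = p.argmax(), p.argmin()
print(f"head token: p={p[top]:.3f}   |delta|={delta[top]:.2e}  bound={bound[top]:.2e}")
print(f"tail token: p={p[tail]:.2e}  |delta|={delta[tail]:.2e}  bound={bound[tail]:.2e}")
```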

To understand the typical magnitude of mismatch (not just the worst case), we model the perturbations as i.i.d. with mean zero and variance σ².

Proposition 3.3 (Signature of Failure). Under the perturbation model with ε_k i.i.d. ∼ N(0, σ²), given that action a is sampled from π_θ^infer, the mode of the mismatch ∆′_a = -∆_a is approximately zero when p_a is large and strictly positive when p_a is small, where p_a = π_θ^train(a|s) and p′_a = π_θ^infer(a|s).

Proof in Appendix A.3. For high-probability tokens, the mode is near zero (benign mismatch). For low-probability tokens, the mode is strictly positive, implying the probability ratio π_θ^infer/π_θ^train is systematically inflated. This theoretical prediction aligns with prior empirical observations: Liu et al. [5] found that sampled low-probability tokens exhibit π_θ^infer ≫ π_θ^train in practice, contributing to training collapse.

The diagnosis is clear: (1) the vocabulary tail is a region of high instability risk; (2) vulnerability is asymmetric: the mismatch bound vanishes for high-probability tokens but remains at 2ϵ_max for low-probability tokens; (3) when a low-probability token is sampled, the token-level mismatch ∆_a tends to be negative (Proposition 3.3), meaning π_θ^infer(a|s) ≫ π_θ^train(a|s). Crucially, these per-token mismatches accumulate over sequences: ∆_y = Σ_t ∆_{y_t}. Sequences containing many low-probability tokens therefore have systematically negative ∆_y, leading to large sequence-level probability ratios exp(-∆_y) ≫ 1. This motivates a solution that excludes the tail from the learning objective.
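The accumulation effect is easy to quantify with made-up numbers: even a small, systematically negative per-token mismatch compounds multiplicatively over a long response. The per-token value and the lengths below are purely illustrative, not measured quantities.

```python
import numpy as np

# Illustrative only: assume each sampled token carries the same small,
# systematically negative mismatch Delta_{y_t}, so Delta_y = T * delta.
per_token_delta = -2e-3
for T in (256, 1024, 4096, 16384):
    Delta_y = T * per_token_delta
    print(f"T={T:6d}  sequence ratio exp(-Delta_y) = {np.exp(-Delta_y):.3g}")
```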

We pursue objective redesign over reactive patching: constrain the learning objective to a dynamically-pruned “safe” vocabulary.

Min-p sampling [6] retains tokens whose probability exceeds a fraction ρ of the maximum probability. We adapt this rule to define safe action sets.

Definition 4.1 (Min-P Safe Action Sets). Given threshold ρ ∈ (0, 1], the safe action sets are

V_S(s) = { a ∈ V : π_θ^train(a|s) ≥ ρ · max_k π_θ^train(k|s) },
V′_S(s) = { a ∈ V : π_θ^infer(a|s) ≥ ρ · max_k π_θ^infer(k|s) }.

The threshold ρ is typically extremely small (e.g., ρ = e^{-13} ≈ 2.3 × 10^{-6}), retaining a broad set of plausible tokens while pruning only the extreme tail.
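A minimal sketch of Definition 4.1 for one generation step, assuming `logits` are the training-policy logits; the function name and batching convention are ours, not the paper's.

```python
import math
import torch

def min_p_safe_set(logits: torch.Tensor, rho: float = math.exp(-13)) -> torch.Tensor:
    """Boolean mask over the vocabulary (last dim): keep token a iff
    pi(a|s) >= rho * max_k pi(k|s), as in Definition 4.1."""
    probs = torch.softmax(logits, dim=-1)
    return probs >= rho * probs.amax(dim=-1, keepdim=True)
```

The same routine applied to inference-engine logits would yield V′_S(s).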

Definition 4.2 (Min-P Constrained Policies). The constrained policies are

π_mp^train(a|s) = π_θ^train(a|s) / Z_θ(s) if a ∈ V_S(s), and 0 otherwise,
π_mp^infer(a|s) = π_θ^infer(a|s) / Z′_θ(s) if a ∈ V′_S(s), and 0 otherwise,

where Z_θ(s) = Σ_{k∈V_S(s)} π_θ^train(k|s) and Z′_θ(s) = Σ_{k∈V′_S(s)} π_θ^infer(k|s).

Our constrained RL objective is J_mp(θ) = E_{y∼π_mp^train}[R(x, y)], a different objective from J(θ) that avoids the unstable tail by design.

The constrained log-probability log π_mp^train(y|x) = Σ_t log π_mp^train(y_t|s_t) is well-defined whenever y_t ∈ V_S(s_t) for all t.

Remark 4.4 (Support Condition). When y_t ∈ V′_S(s_t) but y_t ∉ V_S(s_t), the probability ratio π_mp^train/π_mp^infer is zero: no bias, just wasted samples. The converse case (bias-inducing) is rare: by Proposition 3.2, high-probability tokens have small |∆_a|, so π_θ^train(a|s) ≈ π_θ^infer(a|s), ensuring V_S(s) ≈ V′_S(s) for the tokens that matter.

Remark 4.5 (Fixed Safe Sets in Gradient Computation). We treat V_S(s) as fixed during backpropagation (via torch.no_grad()), a standard approximation that introduces negligible error. See Appendix C for implementation details.
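Putting Definition 4.2 and Remark 4.5 together, here is a sketch of computing log π_mp^train(y_t|s_t) by logit masking, with the safe set selected under torch.no_grad(). Tensor shapes and names are our assumptions, not the authors' code.

```python
import math
import torch
import torch.nn.functional as F

def constrained_log_probs(logits, tokens, rho=math.exp(-13)):
    """log pi_mp^train(y_t | s_t) for each position.

    logits: [batch, seq, vocab] training-policy logits; tokens: [batch, seq].
    Tokens outside V_S(s_t) receive -inf and should be handled per Remark 4.4.
    """
    with torch.no_grad():                              # safe set treated as fixed
        probs = torch.softmax(logits, dim=-1)
        safe = probs >= rho * probs.amax(dim=-1, keepdim=True)
    masked_logits = logits.masked_fill(~safe, float("-inf"))
    log_probs = F.log_softmax(masked_logits, dim=-1)   # renormalized over V_S(s_t)
    return log_probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
```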

By constraining the objective to J_mp instead of J, we trade a small optimization bias for stable gradient estimation. The key benefit of vocabulary pruning is that it excludes tokens where the mismatch ∆_a is systematically biased. By Proposition 3.3, when a low-probability token is sampled from π_θ^infer, the mismatch ∆_a tends to be negative, meaning π_θ^infer(a|s) ≫ π_θ^train(a|s). These systematically negative mismatches accumulate over sequences: if many tokens have ∆_{y_t} < 0, then ∆_y = Σ_t ∆_{y_t} ≪ 0, causing exp(-∆_y) ≫ 1. By excluding the extreme tail, we avoid sampling tokens with systematically biased mismatches, preventing this accumulation.

The optimization bias introduced by replacing J with J_mp is bounded in terms of Z_min = min_s Z_θ(s), the minimum retained probability mass (proof in Appendix B). With ρ = e^{-13}, we have Z_θ(s) ≈ 1 in nearly all contexts, making the optimization bias negligible.

We evaluate Dynamic Vocabulary Pruning (DVP) on mathematical reasoning, using RLOO [1] as the base algorithm. We train on the filtered DAPO dataset¹ and evaluate on AIME25. We conduct full on-policy training with Qwen3-14B-Base, with both the rollout batch size and the mini-update size set to 32, a maximum response length of 16,384, and a group size of 16. For DVP, we use a min-p threshold of ρ = e^{-13}. For importance sampling, we adopt token-level Truncated Importance Sampling (TIS) [7] and Masked Importance Sampling (MIS) [5]. To mitigate variance and ensure reproducibility, we report avg@16 scores in Figure 1.

As shown in Figure 1, naive RLOO suffers from early collapse due to a massive training-inference PPL gap. While TIS attempts to mitigate this instability, it still exhibits a substantial PPL gap and fails to achieve competitive results. With DVP, the PPL gap remains stable throughout training, yielding significantly higher scores. Notably, the combination of MIS and DVP achieves a 26.55% improvement over naive RLOO's peak performance.

We analyzed training instability in LLM reinforcement learning, showing that it arises from distributional mismatch between inference and training systems. The vulnerability is asymmetric: the log-probability mismatch bound scales as (1 - p), vanishing for high-probability tokens but remaining large for low-probability tokens. Moreover, sampled low-probability tokens have systematically biased mismatches that accumulate over sequences, causing sequence-level probability ratios to grow large. Rather than applying post-hoc corrections, we propose dynamic vocabulary pruning: constraining the objective to a "safe" vocabulary that excludes the extreme tail. This avoids tokens with systematically biased mismatches at the cost of a small, bounded optimization bias. Our approach offers a principled path toward stable reinforcement learning for LLMs.

A.2 Proof of Proposition 3.2

Proof. Write f_a(z) = log softmax(z)_a, so that ∆_a = f_a(z^train) - f_a(z^infer). By the Mean Value Theorem, there is a point z^c on the segment between z^train and z^infer such that

f_a(z^infer) - f_a(z^train) = Σ_k (∂f_a/∂z_k)(z^c) · ε_k.

The gradient is ∂f_a/∂z_k = δ_{ak} - p_k(z^c). Therefore:

∆_a = -Σ_k (δ_{ak} - p_k(z^c)) ε_k.

By the triangle inequality:

|∆_a| ≤ ϵ_max Σ_k |δ_{ak} - p_k(z^c)| = ϵ_max [(1 - p_a(z^c)) + Σ_{k≠a} p_k(z^c)] = 2 ϵ_max (1 - p_a(z^c)).
A.3 Proof of Proposition 3.3: Signature of Failure

Proof. Let E_a denote the event "action a is sampled from π_θ^infer." Using Bayes' theorem with a Gaussian prior on the perturbations:

Setting derivatives to zero:

The constrained policy π_mp^train involves selecting the safe set V_S(s), a non-differentiable operation. We show that simple logit masking correctly implements the required gradient.

Define masked logits:

z_mp,k = z_k if k ∈ V_S(s), and -∞ otherwise.

Proposition C.1 (Masked Logit Correctness). For all a ∈ V_S(s):

  1. softmax(z_mp)_a = π_mp^train(a|s).
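A quick numerical check of this claim, comparing explicit renormalization over V_S(s) against the masked-logit route; the random logits and sizes are assumptions for illustration.

```python
import math
import torch

torch.manual_seed(0)
logits = torch.randn(50_000) * 4.0
rho = math.exp(-13)

probs = torch.softmax(logits, dim=-1)
safe = probs >= rho * probs.max()

# Explicit constrained policy: pi(a|s) / Z(s) on V_S(s), zero elsewhere.
explicit = probs * safe
explicit = explicit / explicit.sum()

# Masked-logit route: set unsafe logits to -inf, then softmax.
masked = torch.softmax(logits.masked_fill(~safe, float("-inf")), dim=-1)

print(torch.allclose(explicit, masked, atol=1e-6))  # expect True
```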

¹ https://huggingface.co/datasets/Jiawei415/DPAO_filter/tree/main/train

