Small Generalizable Prompt Predictive Models Can Steer Efficient RL Post-Training of Large Reasoning Models
Reinforcement learning enhances the reasoning capabilities of large language models but often incurs high computational costs due to rollout-intensive optimization. Online prompt selection offers a promising remedy by prioritizing informative prompts to improve training efficiency. However, current methods either depend on costly exact evaluations or build prompt-specific predictive models that do not generalize across prompts. This study introduces Generalizable Predictive Prompt Selection (GPS), which performs Bayesian inference over prompt difficulty using a lightweight generative model trained on the shared optimization history. Its batch acquisition principle combines intermediate-difficulty prioritization with history-anchored diversity to select informative prompt batches. The small predictive model also generalizes at test time, enabling efficient allocation of inference compute. Experiments across varied reasoning benchmarks show that GPS substantially improves training efficiency, final performance, and test-time efficiency over strong baselines.
💡 Research Summary
The paper tackles the high computational cost of reinforcement learning with verifiable rewards (RL‑VR) used to fine‑tune large language models (LLMs) for reasoning tasks. While RL‑VR improves chain‑of‑thought generation, it requires massive rollouts because each training step must evaluate many prompts, and the usefulness of a prompt varies dramatically: very easy or very hard prompts produce near‑zero gradient signals, whereas intermediate‑difficulty prompts provide the most informative feedback. Existing online prompt‑selection methods mitigate this by either (i) over‑sampling a candidate pool and evaluating each prompt with the current model (e.g., Dynamic Sampling, SPEED‑RL), which adds a large extra rollout burden, or (ii) building prompt‑specific predictive models (PPMs) such as MoPPS that maintain independent Bayesian posteriors per prompt. The latter suffer from cold‑start problems for rarely sampled prompts and from lagged adaptation, because they neither share information across prompts nor account for the continuously changing policy parameters.
The authors propose Generalizable Predictive Prompt Selection (GPS), a framework that learns a lightweight, shared generative model of prompt difficulty using the entire optimization history. The key technical contribution is a global latent variable \(z_t\) (the “difficulty context”) that summarizes all past prompt‑reward pairs \(H_{t-1}\). A conditional prior \(p_\eta(z_t \mid H_{t-1})\) captures the non‑stationary nature of the training dynamics, while a decoder \(p_\psi(\gamma \mid \tau, z_t)\) maps a prompt embedding \(\tau\) and the context to a distribution over the latent success rate \(\gamma\). Variational inference with an encoder \(q_\phi(z_t \mid H_t)\) yields an evidence lower bound (ELBO) that is maximized jointly over the encoder, decoder, and prior. This design enables (1) transfer of difficulty information across semantically related prompts, (2) rapid adaptation to the evolving policy, and (3) prediction for completely unseen prompts by sampling from the prior and passing the samples through the decoder. The authors prove (Theorem 3.1) that conditioning on the full history strictly reduces mean‑squared error compared with prompt‑specific conditioning.
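To make the prediction step concrete, here is a minimal NumPy sketch of Monte‑Carlo difficulty prediction for an unseen prompt: draw latent contexts from a prior, push each sample through a decoder, and average the resulting success rates. The dimensions, the Gaussian prior, and the logistic decoder are all illustrative assumptions for this sketch, not the paper's actual architecture.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dimensions for prompt embeddings and the difficulty context z_t.
D_EMB, D_Z = 8, 4

def decoder(tau_emb, z, W_tau, W_z, b):
    """Map a prompt embedding and context samples to predicted success rates
    in (0, 1) via a logistic link -- a toy stand-in for p_psi(gamma | tau, z)."""
    logit = tau_emb @ W_tau + z @ W_z + b
    return 1.0 / (1.0 + np.exp(-logit))

def predict_difficulty(tau_emb, prior_mean, prior_std, params, n_samples=64):
    """Monte-Carlo estimate of the expected success rate for a prompt:
    sample z from a diagonal-Gaussian prior (standing in for p_eta(z | H)),
    decode each sample, and average."""
    W_tau, W_z, b = params
    z = prior_mean + prior_std * rng.standard_normal((n_samples, D_Z))
    return decoder(tau_emb, z, W_tau, W_z, b).mean()

# Toy parameters and a toy prompt embedding (all hypothetical).
params = (rng.standard_normal(D_EMB) * 0.1, rng.standard_normal(D_Z) * 0.1, 0.0)
tau_emb = rng.standard_normal(D_EMB)
gamma_hat = predict_difficulty(tau_emb, np.zeros(D_Z), np.ones(D_Z), params)
print(f"predicted success rate: {gamma_hat:.3f}")
```

Because the prior is shared, the same machinery produces a prediction for any prompt embedding, including ones never trained on.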
Prompt batch selection in GPS combines two principles: (a) intermediate‑difficulty prioritization, which assigns higher utility to prompts whose predicted success rate lies in a moderate range (e.g., 0.2–0.8), ensuring non‑vanishing gradients; and (b) history‑anchored diversity, which penalizes redundancy by measuring embedding distance and historical correlation among already‑selected prompts. The batch utility is a weighted sum \(U(\tau) = (1-\lambda)\, f_{\text{difficulty}}(\hat\gamma_\tau) + \lambda\, f_{\text{diversity}}(\tau, \mathcal{T}_{\text{partial}})\). A greedy algorithm repeatedly selects the prompt with maximal utility until the batch size \(B\) is reached.
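A sketch of the greedy acquisition, assuming simple illustrative choices for both utility terms (a plateau over the moderate difficulty range, and minimum embedding distance as the diversity proxy; the paper's exact shaping functions may differ):

```python
import numpy as np

def difficulty_utility(gamma_hat, lo=0.2, hi=0.8):
    """Full utility inside the moderate range [lo, hi], decaying toward 0
    for prompts that are predicted to be very easy or very hard."""
    if lo <= gamma_hat <= hi:
        return 1.0
    return max(0.0, 1.0 - min(abs(gamma_hat - lo), abs(gamma_hat - hi)) / lo)

def diversity_utility(emb, selected_embs):
    """Diversity proxy: minimum Euclidean distance to prompts already in the
    batch (larger = more diverse; 1.0 when the batch is still empty)."""
    if not selected_embs:
        return 1.0
    return min(np.linalg.norm(emb - s) for s in selected_embs)

def greedy_select(gammas, embs, batch_size, lam=0.3):
    """Repeatedly pick the prompt maximizing
    U = (1 - lam) * f_difficulty + lam * f_diversity until the batch is full."""
    remaining = list(range(len(gammas)))
    batch, selected_embs = [], []
    for _ in range(batch_size):
        utilities = [(1 - lam) * difficulty_utility(gammas[i])
                     + lam * diversity_utility(embs[i], selected_embs)
                     for i in remaining]
        best = remaining[int(np.argmax(utilities))]
        batch.append(best)
        selected_embs.append(embs[best])
        remaining.remove(best)
    return batch

rng = np.random.default_rng(1)
gammas = [0.05, 0.5, 0.55, 0.95, 0.4]            # predicted success rates
embs = [rng.standard_normal(4) for _ in gammas]  # toy prompt embeddings
batch = greedy_select(gammas, embs, batch_size=2)
print(batch)  # intermediate-difficulty prompts are favored over 0.05 / 0.95
```

The greedy loop is O(B · N) utility evaluations, which is negligible next to the rollout cost it replaces.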
Algorithm 1 outlines the full loop: (1) predict difficulty for every prompt using a Monte‑Carlo approximation of the posterior; (2) construct a batch via the utility function; (3) generate \(k\) responses per prompt, compute binary correctness rewards, and update the LLM policy with the GRPO objective; (4) augment the history with the new reward data; (5) update the PPM by maximizing the ELBO.
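The five steps can be sketched as a runnable toy loop. Everything here is a hypothetical stand‑in: the "policy" is a biased coin per prompt, the PPM is a running‑mean estimator rather than a variational model, and the GRPO and ELBO updates are reduced to placeholders, so only the control flow mirrors Algorithm 1.

```python
import random

class ToyPPM:
    """Stand-in for the predictive prompt model: a per-prompt running
    success-rate estimate (no latent variable, no ELBO)."""
    def __init__(self):
        self.est = {}
    def predict(self, p):
        return self.est.get(p, 0.5)   # neutral prior for unseen prompts
    def update(self, p, rewards):     # placeholder for the ELBO refit
        self.est[p] = sum(rewards) / len(rewards)

def difficulty_utility(g):
    """Peaks at gamma = 0.5 and vanishes at 0 and 1."""
    return 1.0 - 2.0 * abs(g - 0.5)

def run_gps(prompts, true_rates, steps=20, batch_size=2, k=16, seed=0):
    rng = random.Random(seed)
    ppm, history = ToyPPM(), []
    for _ in range(steps):
        gammas = {p: ppm.predict(p) for p in prompts}             # (1) predict
        batch = sorted(prompts,
                       key=lambda p: -difficulty_utility(gammas[p]))[:batch_size]  # (2) select
        for p in batch:
            # (3) k rollouts with binary rewards; the GRPO policy update
            # itself is elided in this toy version.
            rewards = [int(rng.random() < true_rates[p]) for _ in range(k)]
            history.append((p, rewards))                          # (4) extend history
            ppm.update(p, rewards)                                # (5) refit PPM
    return ppm, history

prompts = ["easy", "medium", "hard"]
true_rates = {"easy": 0.95, "medium": 0.5, "hard": 0.05}
ppm, history = run_gps(prompts, true_rates)
print(sorted((p, sum(q == p for q, _ in history)) for p in prompts))
```

After a few steps the estimator pushes the saturated "easy" and "hard" prompts out of the batch, so most of the rollout budget concentrates on the intermediate‑difficulty prompt.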
Experiments span four reasoning benchmarks (MATH, GSM‑8K, Codeforces, LogicalDeduction) and three LLM backbones (LLaMA‑2‑7B, LLaMA‑13B, GPT‑NeoX‑20B). Baselines include random sampling, Dynamic Sampling, SPEED‑RL, MoPPS, and evaluation‑based selection. Results show:
- Difficulty prediction: GPS reduces mean‑squared error by ~30 % relative to MoPPS.
- Training efficiency: For the same total number of rollouts, GPS achieves 1.6×–2.0× faster convergence and improves final accuracy by 1–3 percentage points over random sampling and by 0.8 pp over Dynamic Sampling.
- Cost vs. evaluation‑based selection: GPS attains comparable performance while using only 31 % of the extra rollouts required by evaluation‑based methods (≈ 69 % reduction).
- Test‑time compute allocation: The learned PPM can re‑allocate a fixed inference budget across prompts; under identical budgets, accuracy improves by 2.1–3.2 pp, or inference cost drops by up to 36.4 % with negligible accuracy loss (<0.3 pp).
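The exact test‑time allocation rule is not spelled out above; one simple scheme consistent with the idea is to split a fixed sampling budget in proportion to predicted hardness, with a per‑prompt minimum. This proportional rule and its parameters are assumptions for illustration.

```python
import math

def allocate_budget(gammas, total_budget, min_samples=1):
    """Split total_budget inference samples across prompts: lower predicted
    success rate (harder prompt) gets proportionally more samples."""
    weights = [1.0 - g for g in gammas]        # hardness weights
    total_w = sum(weights) or 1.0
    alloc = [max(min_samples, math.floor(total_budget * w / total_w))
             for w in weights]
    # Trim any overshoot from the most generously funded prompts.
    while sum(alloc) > total_budget:
        i = max(range(len(alloc)), key=lambda j: alloc[j])
        alloc[i] -= 1
    return alloc

gammas = [0.9, 0.5, 0.1]   # predicted success rates: easy, medium, hard
alloc = allocate_budget(gammas, total_budget=12)
print(alloc)  # → [1, 4, 7]
```

Compared with uniform allocation (4 samples each), the hard prompt gets nearly twice the samples while the near-solved one gets only the minimum, which is the mechanism behind the reported accuracy-per-budget gains.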
The paper demonstrates that a small, generalizable predictive model can dramatically steer the expensive RL fine‑tuning of large reasoning models. By sharing information across prompts and explicitly balancing difficulty with diversity, GPS reduces both training and inference costs while consistently boosting performance. Limitations include the use of a single global latent vector (potentially insufficient for highly heterogeneous prompt spaces), sensitivity to the number of Monte‑Carlo samples in the prediction step, and the additional engineering overhead of maintaining the variational model during large‑scale online training. Future work could explore multi‑latent or hierarchical contexts, incorporate meta‑features (e.g., prompt topic, length), compare against non‑Bayesian ensembles, and evaluate robustness in real‑world deployment pipelines.
In summary, GPS offers a principled, cost‑effective alternative to existing prompt‑selection strategies, confirming that “small models can guide large models” and opening a promising direction for scalable, efficient LLM post‑training.