Inference-Time Rethinking with Latent Thought Vectors for Math Reasoning
Standard chain-of-thought reasoning generates a solution in a single forward pass, committing irrevocably to each token and lacking a mechanism to recover from early errors. We introduce Inference-Time Rethinking, a generative framework that enables iterative self-correction by decoupling declarative latent thought vectors from procedural generation. We factorize reasoning into a continuous latent thought vector (what to reason about) and a decoder that verbalizes the trace conditioned on this vector (how to reason). Beyond serving as a declarative buffer, latent thought vectors compress the reasoning structure into a continuous representation that abstracts away surface-level token variability, making gradient-based optimization over reasoning strategies well-posed. Our prior model maps unstructured noise to a learned manifold of valid reasoning patterns, and at test time we employ a Gibbs-style procedure that alternates between generating a candidate trace and optimizing the latent vector to better explain that trace, effectively navigating the latent manifold to refine the reasoning strategy. Training a 0.2B-parameter model from scratch on GSM8K, our method with 30 rethinking iterations surpasses baselines with 10 to 15 times more parameters, including a 3B counterpart. This result demonstrates that effective mathematical reasoning can emerge from sophisticated inference-time computation rather than solely from massive parameter counts.
💡 Research Summary
Standard chain‑of‑thought (CoT) prompting generates a reasoning trace in a single forward pass, irrevocably committing to each token. Early mistakes therefore propagate unchecked, and the model has no mechanism to backtrack. This paper proposes a cognitively inspired separation between declarative knowledge (what to think) and procedural execution (how to express), realized through latent thought vectors (z). A prior encoder maps random Gaussian noise z₀ into a structured continuous vector z, which serves as a global conditioning signal for a Transformer decoder. The decoder cross‑attends to z at every layer, while its own self‑attention window is deliberately limited (w = 64 tokens) so that long‑range logical structure must be encoded in z rather than in local token context.
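The attention pattern described above can be sketched with plain mask-building helpers. This is a minimal illustration, not the paper's implementation: the helper names are hypothetical, and only the two constraints from the text are encoded, i.e. self-attention is causal and limited to a local window w, while cross-attention sees all latent tokens of z.

```python
# Self-attention is causal AND restricted to a local window w, so any
# long-range logical structure must flow through cross-attention to z.

def windowed_causal_mask(seq_len: int, window: int) -> list[list[bool]]:
    """mask[i][j] is True iff token i may attend to token j."""
    return [
        [0 <= i - j < window for j in range(seq_len)]
        for i in range(seq_len)
    ]

def cross_attention_mask(seq_len: int, num_latent_tokens: int) -> list[list[bool]]:
    """Every decoder position may attend to every latent token of z."""
    return [[True] * num_latent_tokens for _ in range(seq_len)]
```

With w = 64 as in the paper, a token 100 positions back is invisible to self-attention, so any dependency spanning that distance must be carried by z.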
Training follows a variational Bayes approach. For each training example a non‑amortized Gaussian posterior q(z₀)=N(μ,diag(σ²)) is introduced and directly optimized together with the global parameters θ = (α,β). The ELBO consists of the expected log‑likelihood of the observed question‑answer pair given z and a KL regularizer toward the standard normal prior. Optimization proceeds on two timescales: fast local updates of (μ,σ²) using a high learning rate (≈0.3) for a small number of steps (T_fast = 16), and slow global updates of the encoder and decoder parameters with a modest learning rate (≈4×10⁻⁴). This dual‑rate scheme enables rapid per‑instance adaptation of the latent thought while still accumulating shared knowledge across the dataset, and avoids posterior collapse that often plagues VAEs for language.
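The dual-rate scheme can be sketched on a toy 1-D latent. Only the closed-form KL to the standard normal and the fast/slow structure (lr ≈ 0.3, T_fast = 16 inner steps) come from the text; the quadratic stand-in for the decoder log-likelihood and its analytic gradients are illustrative assumptions.

```python
import math

def kl_to_standard_normal(mu: float, log_sigma: float) -> float:
    """KL( N(mu, sigma^2) || N(0, 1) ) in closed form."""
    var = math.exp(2.0 * log_sigma)
    return 0.5 * (mu * mu + var - 2.0 * log_sigma - 1.0)

def fast_posterior_update(x: float, mu: float = 0.0, log_sigma: float = 0.0,
                          lr: float = 0.3, steps: int = 16):
    """T_fast inner steps of gradient ascent on a toy ELBO for one example.

    Toy decoder likelihood: log p(x|z) = -(z - x)^2 / 2, evaluated at z = mu,
    so dELBO/dmu = (x - mu) - mu  and  dELBO/dlog_sigma = 1 - exp(2*log_sigma).
    """
    for _ in range(steps):
        mu += lr * ((x - mu) - mu)                      # data term vs. prior pull
        log_sigma += lr * (1.0 - math.exp(2.0 * log_sigma))
    return mu, log_sigma
```

In the full method these inner steps adapt (μ, σ²) per instance, after which a slow outer step (lr ≈ 4×10⁻⁴) updates the shared encoder and decoder parameters.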
Inference introduces “Inference‑Time Rethinking”, a Gibbs‑style iterative loop. Starting from an initial latent inferred from the question alone, the model alternates: (1) Generate a candidate reasoning trace xᵣ conditioned on the current z via the decoder; (2) Reflect by optimizing z₀ (or equivalently μ,σ²) to maximize the ELBO for the observed (question, generated trace). The second step effectively moves the latent vector toward a region of the learned manifold that better explains the trace, allowing the model to correct earlier mistakes. After T_rethink iterations (30 in the main experiments) the trace with the highest likelihood is selected as the final answer.
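The generate/reflect alternation above can be made concrete with a deliberately tiny toy. The real method decodes token traces and optimizes a high-dimensional z via the ELBO; here both are collapsed to one dimension so the alternation and the highest-likelihood trace selection are visible. The toy "decoder" and "log-likelihood" are assumptions for illustration only.

```python
def toy_generate(z: float) -> float:
    """Stand-in decoder: emits a trace pulled toward the correct answer 3.0."""
    return 0.5 * z + 1.5

def toy_log_likelihood(trace: float, z: float) -> float:
    """Stand-in for log p(trace | z): peaked where z explains the trace."""
    return -(trace - z) ** 2

def rethink(z: float = 0.0, iters: int = 30) -> float:
    """Alternate generation and latent refinement; keep the best-scoring trace."""
    best_trace, best_ll = None, float("-inf")
    for _ in range(iters):
        trace = toy_generate(z)            # (1) generate a trace from current z
        ll = toy_log_likelihood(trace, z)  # score it before reflecting
        if ll > best_ll:
            best_trace, best_ll = trace, ll
        z = trace                          # (2) reflect: the z that best explains trace
    return best_trace
```

In this toy, each reflect step moves z toward the trace it just produced, and successive traces contract toward the correct answer, mirroring how the real procedure walks the latent manifold toward regions that better explain the generated reasoning.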
Experiments are conducted on three math reasoning benchmarks: GSM8K (in‑domain), SVAMP (out‑of‑domain wording variations), and MultiArith (out‑of‑domain multi‑step problems). All models are trained from scratch on an augmented GSM8K dataset (≈385 K examples) without any pre‑training. The proposed 0.2 B‑parameter model (a Llama‑2‑style decoder plus a 2‑layer encoder with 64 latent tokens) is compared against much larger baselines (0.5 B and 3 B Qwen2.5 backbones) that use standard CoT fine‑tuning or other latent‑reasoning methods (iCoT‑SI, Coconut, CoLaR, CODI, MARCoS). Even the single‑pass version (Rethink‑1) already outperforms most baselines, achieving 25.93 % on GSM8K versus 20–22 % for 3 B CoT‑SFT. With 30 rethinking iterations (Rethink‑30) the model reaches 31.54 % on GSM8K, 51.50 % on SVAMP, and 68.00 % on MultiArith, setting a new state of the art across all three datasets despite being an order of magnitude smaller.
The authors attribute these gains to three factors: (i) latent thought vectors compress global reasoning structure, freeing the decoder from memorizing every pattern; (ii) the learned transport‑map prior shapes a manifold of valid reasoning strategies, making gradient‑based refinement effective; (iii) inference‑time computation provides a new scaling axis orthogonal to parameter count. Ablation analysis of the gap between Rethink‑1 and Rethink‑30 quantifies the benefit of iterative refinement (a ≈5.6‑point absolute improvement on GSM8K, from 25.93 % to 31.54 %).
Limitations are discussed. The approach relies on high‑quality, near‑expert demonstrations because likelihood is used as a proxy for correctness; noisy supervision could cause the optimizer to gravitate toward flawed reasoning patterns. To mitigate this, future work may incorporate a latent verifier that predicts correctness directly from z before any tokens are generated, turning rethinking into latent planning. Additionally, external symbolic feedback (e.g., code execution, theorem provers) could be used for test‑time policy‑gradient updates of the latent vector.
In summary, the paper demonstrates that sophisticated inference‑time computation—implemented as iterative optimization of declarative latent thoughts—can compensate for modest model size and achieve or surpass the performance of much larger language models on mathematical reasoning tasks. This establishes “thinking longer at test time” as a viable scaling strategy and opens avenues for richer latent‑space planning and verification in future LLM reasoning systems.