A State-Transition Framework for Efficient LLM Reasoning
While Long Chain-of-Thought (CoT) reasoning significantly improves the performance of Large Language Models (LLMs) on complex reasoning tasks, the substantial computational and memory costs of generating long CoT sequences limit their efficiency and practicality. Existing studies usually enhance the reasoning efficiency of LLMs by compressing CoT sequences. However, this approach conflicts with test-time scaling, limiting the reasoning capacity of LLMs. In this paper, we propose an efficient reasoning framework that models the reasoning process of LLMs as a state-transition process. Specifically, we first apply a linear attention mechanism to estimate the LLM’s reasoning state, which records the historical reasoning information from previous reasoning steps. Then, based on the query prompt and the reasoning state, the LLM can efficiently perform the current reasoning step and update the state. With linear attention, each token in the current reasoning step can directly retrieve relevant historical reasoning information from the reasoning state, without explicitly attending to tokens in previous reasoning steps. In this way, the computational complexity of attention is reduced from quadratic to linear, significantly improving the reasoning efficiency of LLMs. In addition, we propose a state-based reasoning strategy to mitigate the over-thinking issue caused by noisy reasoning steps. Extensive experiments across multiple datasets and model sizes demonstrate that our framework not only improves the reasoning efficiency of LLMs but also enhances their reasoning performance.
💡 Research Summary
The paper tackles the prohibitive computational and memory costs associated with generating long Chain‑of‑Thought (CoT) sequences in large language models (LLMs). Traditional Transformers compute softmax attention with quadratic complexity O(C²) and store a KV‑cache that grows linearly with context length C, making long CoT generation impractical. Existing efficiency methods either compress the CoT text or force the model to produce shorter reasoning, which compromises the test‑time scaling benefits of longer CoT and can degrade reasoning ability.
The authors propose a novel “state‑transition” framework that treats the reasoning process as a sequence of discrete steps, each consisting of linguistic surface text and a compact set of reasoning information needed for later steps. They introduce a reasoning state matrix Sₜ that aggregates the essential reasoning information from all completed steps using a linear attention mechanism. Linear attention replaces the softmax exponential with a kernel ϕ(·) (often the identity), allowing the attention computation to be expressed as qₜ·Sₜ, where Sₜ = Σᵢ ϕ(kᵢ)ᵀvᵢ is updated incrementally. This reduces attention complexity from O(C²) to O(C) and collapses the KV‑cache memory to a constant‑size matrix, independent of the total number of tokens generated.
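The incremental state update described above can be sketched in a few lines of numpy. This is an illustrative toy (identity kernel ϕ, single head, no normalization), not the paper's implementation; it shows why memory stays constant: the state S is a fixed d×d matrix regardless of how many tokens have been processed.

```python
import numpy as np

def linear_attention_step(S, q, k, v):
    """One state-transition step of linear attention (identity kernel).

    S: (d, d) reasoning state, accumulating sum_i k_i^T v_i
    q, k, v: (d,) query/key/value vectors for the current token
    Returns the attention output and the updated state.
    """
    S = S + np.outer(k, v)   # rank-1 state update: S_t = S_{t-1} + k_t^T v_t
    o = q @ S                # retrieve history in O(d^2), independent of length
    return o, S

# Decoding loop: memory stays (d, d) no matter how many tokens are generated.
d = 4
rng = np.random.default_rng(0)
S = np.zeros((d, d))
for _ in range(10):
    q, k, v = rng.standard_normal((3, d))
    o, S = linear_attention_step(S, q, k, v)
```

By contrast, softmax attention would need to keep all previous k and v vectors in a KV-cache and attend over them at every step, which is the O(C²) cost the framework avoids.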
To integrate linear attention without sacrificing the expressive power of softmax attention, the paper introduces a Mixed Attention Module (MAM). MAM consists of two parallel sub‑modules:
- Softmax‑Attention (SA) sub‑module – identical to the original Transformer attention, but its KV‑cache is limited to the query prompt and the tokens of the current reasoning step. This preserves fluency and local context while eliminating the need to keep all past tokens.
- Linear‑Attention (LA) sub‑module – computes the reasoning state Sₜ and extracts historical reasoning information via oₜ = qₜ·Sₜ. A gating mechanism (σ(W_g·h_SA)) controls the contribution of LA, giving it higher weight early in a step when the model relies more on past reasoning, and gradually reducing its influence as the step proceeds.
The outputs of SA and LA are summed, passed through a linear projection, and then through the usual feed‑forward network. The LA sub‑module is implemented with LoRA adapters to keep parameter overhead low.
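A minimal sketch of one MAM decoding step, under the assumptions in the summary: local softmax attention over a small cache (prompt plus current-step tokens), a linear-attention readout from the state, and a sigmoid gate on the LA contribution. The names (`W_g`, `mam_step`) are hypothetical and the projection/FFN are omitted; this is not the paper's actual code.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def mam_step(q, k, v, local_K, local_V, S, W_g):
    """One Mixed Attention Module step (illustrative sketch).

    local_K, local_V: (n, d) cache holding ONLY the query prompt and
        current-step tokens (the SA sub-module's limited KV-cache)
    S: (d, d) fixed-size reasoning state (the LA sub-module)
    W_g: (d, d) hypothetical gating weight
    """
    # SA: ordinary softmax attention over the small local cache
    scores = local_K @ q / np.sqrt(q.size)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    h_sa = w @ local_V
    # LA: retrieve historical reasoning information from the state
    o_la = q @ S
    # Gate the LA contribution based on the SA representation
    g = sigmoid(W_g @ h_sa)
    out = h_sa + g * o_la            # summed outputs (projection/FFN omitted)
    # Update the state so later steps can retrieve this token's information
    S = S + np.outer(k, v)
    return out, S
```

The gate lets the model lean on the state early in a step, when past reasoning matters most, and fall back to local softmax attention as the step fills in.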
A second major contribution is a state‑based reasoning strategy to mitigate “over‑thinking” caused by noisy intermediate steps. Since the linear‑attention state update can be interpreted as a stochastic gradient descent (SGD) step on a simple loss L(S) = –⟨S·kₜ, vₜ⟩, each reasoning step yields a gradient ∇ₜ = Sₜ – Sₜ₋₁. The authors accumulate these gradients across steps using a momentum update to obtain a global gradient G, which represents the overall reasoning direction. During token generation, the model incorporates qₜ·G as an additional guidance signal, steering the current step toward the global direction and preventing divergence caused by erroneous steps.
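The momentum accumulation and guided readout described above might look as follows. This is a sketch under stated assumptions: `beta` and `alpha` are hypothetical coefficients (the summary does not give the paper's values), and the per-step "gradient" is taken directly as the state change Sₜ − Sₜ₋₁.

```python
import numpy as np

def update_global_gradient(G, S_prev, S_curr, beta=0.9):
    """Momentum accumulation of per-step state changes into a global direction.

    grad_t = S_t - S_{t-1} is the per-step 'gradient' under the SGD view of
    the linear-attention update; beta is an assumed momentum coefficient.
    """
    grad = S_curr - S_prev
    return beta * G + (1.0 - beta) * grad

def guided_readout(q, S, G, alpha=0.5):
    """Combine state retrieval q.S with global guidance q.G (alpha is
    a hypothetical mixing weight)."""
    return q @ S + alpha * (q @ G)
```

Because G smooths the state changes across many steps, a single noisy step shifts it only slightly, which is how the guidance term keeps generation aligned with the overall reasoning direction.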
Experimental Evaluation
The authors evaluate on seven public benchmarks covering mathematics, logical reasoning, and code generation, using three model sizes (7B, 13B, 34B). Baselines include standard CoT, compressed CoT, and recent efficient attention variants such as FlashAttention. Results show:
- Speedup: 2.3×–3.1× faster inference across all settings.
- Memory reduction: KV‑cache usage drops to <30 % of the original.
- Accuracy gain: 1.2–2.5 percentage points improvement over standard CoT.

Ablation studies reveal that removing the LA sub‑module or the momentum‑based global gradient harms performance, confirming the importance of both components. Varying the dimensionality of Sₜ shows that a moderate size (256–512) balances compression loss and expressive power.
Strengths and Contributions
- Recasting reasoning as a state‑transition process enables linear‑time attention and constant‑memory inference without shortening CoT.
- The Mixed Attention Module seamlessly blends softmax and linear attention, preserving fluency while exploiting the efficiency of state‑based retrieval.
- The global‑gradient guidance effectively curbs over‑thinking, a novel use of the test‑time training perspective of linear attention.
- Comprehensive empirical validation across model scales and tasks demonstrates simultaneous efficiency and performance gains.
Limitations and Future Work
The approach relies on a fixed‑size state matrix; choosing its dimension and kernel function ϕ can affect performance, and very deep multi‑step reasoning may still suffer from cumulative compression loss. The method currently focuses on text‑only CoT; extending it to multimodal reasoning and integrating dynamic state‑size adaptation are promising directions. Moreover, exploring reinforcement‑learning‑based optimization of the state‑transition policy could further enhance robustness.
In summary, the paper presents a well‑motivated, technically sound framework that advances the state of the art in efficient LLM reasoning by introducing a linear‑attention‑driven state representation and a gradient‑guided reasoning strategy, achieving notable speed, memory, and accuracy improvements.