Understanding Dynamic Compute Allocation in Recurrent Transformers
Token-level adaptive computation seeks to reduce inference cost by allocating more computation to harder tokens and less to easier ones. However, prior work is primarily evaluated on natural-language benchmarks using task-level metrics, where token-level difficulty is unobservable and confounded with architectural factors, making it unclear whether compute allocation truly aligns with underlying complexity. We address this gap through three contributions. First, we introduce a complexity-controlled evaluation paradigm using algorithmic and synthetic language tasks with parameterized difficulty, enabling direct testing of token-level compute allocation. Second, we propose ANIRA, a unified recurrent Transformer framework that supports per-token variable-depth computation while isolating compute allocation decisions from other model factors. Third, we use this framework to conduct a systematic analysis of token-level adaptive computation across alignment with complexity, generalization, and decision timing. Our results show that compute allocation aligned with task complexity can emerge without explicit difficulty supervision, but such alignment does not imply algorithmic generalization: models fail to extrapolate to unseen input sizes despite allocating additional computation. We further find that early compute decisions rely on static structural cues, whereas online halting more closely tracks algorithmic execution state.
💡 Research Summary
The paper tackles a fundamental gap in token‑level adaptive computation research: the lack of a controlled setting where token difficulty is observable and can be directly compared to the amount of compute a model allocates to each token. To fill this gap, the authors make three major contributions.
First, they construct a suite of algorithmic and synthetic language tasks in which the difficulty of each token is explicitly parameterized. Examples include recursive parenthesis matching, stack‑based arithmetic, and synthetic language generation where the number of required recurrent steps varies deterministically with token position or with a hidden counter. Because the ground‑truth computational demand for each token is known, the authors can measure whether a model’s compute allocation aligns with true complexity.
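As an illustration of what "explicitly parameterized difficulty" can look like, the sketch below generates balanced parenthesis strings and labels each token with the nesting level of the bracket pair it belongs to. The function names and the choice of nesting level as the difficulty proxy are assumptions made for this example, not the paper's actual task definition.

```python
import random

def balanced_parens(n_pairs, rng):
    """Sample a balanced parenthesis string containing n_pairs pairs."""
    out, open_left, depth = [], n_pairs, 0
    while open_left > 0 or depth > 0:
        # Must open when nothing is open; must close when no opens remain;
        # otherwise flip a fair coin.
        if depth == 0 or (open_left > 0 and rng.random() < 0.5):
            out.append("(")
            open_left -= 1
            depth += 1
        else:
            out.append(")")
            depth -= 1
    return "".join(out)

def nesting_level_labels(seq):
    """Hypothetical difficulty label: nesting level of each token's bracket pair."""
    labels, depth = [], 0
    for ch in seq:
        if ch == "(":
            depth += 1
            labels.append(depth)
        else:
            labels.append(depth)
            depth -= 1
    return labels

sample = balanced_parens(4, random.Random(0))
print(sample, nesting_level_labels(sample))
```

With labels of this kind, a model's per-token exit depth can be compared directly against a known quantity instead of an inferred one.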
Second, they introduce ANIRA (Adaptive Neural Iterative Reasoning Architectures), a unified recurrent‑Transformer framework designed specifically for controlled experiments. ANIRA follows a Prelude‑Recurrent‑Coda architecture: a shallow Prelude encodes the input, a shared recurrent core can be iterated up to a maximum depth D, and a Coda maps the final hidden states to next‑token logits. Crucially, the recurrent core is the only part where depth can vary per token, isolating compute‑allocation decisions from other architectural factors. Two decision mechanisms are implemented:
- ANIRA‑E (early) predicts a token‑specific exit depth from the Prelude representation using a depth‑decider network. The prediction is a categorical distribution over {1,…,D} and is sampled during training via a straight‑through Gumbel‑Softmax estimator.
- ANIRA‑O (online) makes halting decisions after each recurrent iteration using a halting‑decider that outputs a probability α(d) for each token at step d. The resulting per‑token exit‑depth distribution is obtained by multiplying each α(d) by the probability mass that has not yet halted at earlier steps, and discrete depths are sampled with an inverse‑CDF trick, again using a straight‑through estimator.
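The online variant's exit-depth distribution and inverse-CDF sampling can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors' implementation: it assumes the final step absorbs any leftover mass so that the distribution sums to one, and it omits the straight-through gradient estimator entirely. All function names are hypothetical.

```python
def exit_depth_distribution(alphas):
    """Convert per-step halting probabilities alpha(1..D) into q(d).

    q(d) = alpha(d) * prod_{j<d} (1 - alpha(j)).
    Assumption of this sketch: the last step takes all remaining mass.
    """
    q, remaining = [], 1.0
    for d, a in enumerate(alphas, start=1):
        halt = remaining if d == len(alphas) else remaining * a
        q.append(halt)
        remaining -= halt
    return q

def sample_exit_depth(q, u):
    """Inverse-CDF sampling: smallest depth whose cumulative mass reaches u."""
    cdf = 0.0
    for d, mass in enumerate(q, start=1):
        cdf += mass
        if u <= cdf:
            return d
    return len(q)  # guard against floating-point round-off

q = exit_depth_distribution([0.5, 0.5, 0.5])
print(q, sample_exit_depth(q, 0.6))
```

In training, `u` would be drawn uniformly from [0, 1); passing it explicitly here keeps the example deterministic.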
Both variants are trained with a combined loss L = L_CE + γ L_C, where L_C is a KL‑divergence regularizer that forces the per‑token depth distribution q_i(d) toward a prior p(d) ∝ b⁻ᵈ (b≥1). This regularizer simultaneously penalizes expected depth (via the log‑b term) and encourages high‑entropy, non‑degenerate distributions (via –H(q_i)).
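The two effects attributed to the regularizer follow from expanding the KL term. The derivation below is a sketch under two assumptions that the summary implies but does not state: the prior is normalized over {1,…,D} with constant Z, and the divergence is taken as KL(q_i ‖ p).

```latex
p(d) = \frac{b^{-d}}{Z}, \qquad Z = \sum_{d'=1}^{D} b^{-d'}

\mathrm{KL}\!\left(q_i \,\|\, p\right)
  = \sum_{d=1}^{D} q_i(d) \log \frac{q_i(d)}{p(d)}
  = -H(q_i) + (\log b)\, \mathbb{E}_{q_i}[d] + \log Z
```

Since log Z does not depend on q_i, minimizing the regularizer lowers the expected depth (weighted by log b) and raises the entropy H(q_i). At b = 1 the depth penalty vanishes and only the entropy term remains.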
Third, the authors use ANIRA to conduct a systematic empirical study across three dimensions: (1) alignment with task complexity, (2) generalization to unseen input sizes, and (3) the effect of decision timing (early vs. online).
Key findings include:
- Complexity alignment – Both ANIRA‑E and ANIRA‑O learn to allocate more recurrent steps to tokens that are intrinsically harder, despite receiving no explicit difficulty supervision. ANIRA‑O shows finer‑grained alignment because its halting decisions can condition on intermediate hidden states that reflect the evolving execution state of the algorithm.
- Generalization – When evaluated on longer sequences or on difficulty configurations not seen during training, models increase their average depth but do not achieve the same accuracy gains as on in‑distribution data. This demonstrates that simply allocating more compute does not guarantee algorithmic generalization; the model still fails to extrapolate the underlying algorithm to larger problem instances.
- Decision timing – ANIRA‑E’s depth predictions rely heavily on static cues available after the Prelude (e.g., token position, syntactic markers). Consequently, its allocation policy is largely fixed before any iterative reasoning begins. In contrast, ANIRA‑O’s online halting tracks the actual state of the recurrent core (e.g., a counter reaching zero, stack depth), leading to decisions that more closely mirror the true execution progress of the algorithm.
- Training dynamics – Across experiments, a two‑phase training regime emerges. An initial phase uses relatively high compute to learn the underlying algorithmic pattern; a later phase, driven by the compute regularizer, gradually compresses the policy, reducing average depth while preserving performance on the training distribution.
The paper also introduces practical engineering contributions. An “allocation‑aware KV cache” stores keys and values only up to each token’s exit depth, allowing frozen tokens to remain visible to later attention without incurring unnecessary computation. Inference compute and memory in the recurrent core scale with the mean allocated depth d̄, so the cost is roughly a fraction d̄/D of that of a non‑adaptive model that always runs D steps.
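The bookkeeping behind this scaling can be sketched as follows. The sketch assumes that a frozen token is read by later queries at its last computed depth; the summary does not spell out the lookup rule, so `lookup_depth` is a guess at one reasonable reading, and all names are hypothetical.

```python
def cached_entries(exit_depths):
    """Number of (token, depth) key/value entries stored: one per executed step."""
    return sum(exit_depths)

def lookup_depth(query_depth, key_exit_depth):
    """Depth of the cache entry that a query at `query_depth` reads for an earlier token.

    Assumption: once a token has exited, later steps reuse its final entry.
    """
    return min(query_depth, key_exit_depth)

def relative_cost(exit_depths, max_depth):
    """Cache size (and recurrent-step count) relative to always running max_depth steps."""
    return cached_entries(exit_depths) / (len(exit_depths) * max_depth)

exit_depths = [1, 3, 2, 4]  # hypothetical per-token exit depths, with D = 4
print(cached_entries(exit_depths), relative_cost(exit_depths, 4))
```

Here the mean depth is 2.5, so the relative cost equals 2.5 / 4, matching the mean-depth-over-D ratio described above.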
In summary, this work provides a rigorous benchmark for token‑level adaptive computation, demonstrates that complexity‑aligned compute allocation can emerge without supervision, but also shows that such alignment is not sufficient for algorithmic generalization. Moreover, it highlights that the design of the decision mechanism (early static prediction vs. online halting) fundamentally shapes what information the model uses to allocate compute. These insights point toward future research directions: richer state representations for halting, regularizers that explicitly encourage extrapolation, and architectural designs that couple compute allocation more tightly with algorithmic reasoning.