Step-resolved data attribution for looped transformers
We study how individual training examples shape the internal computation of looped transformers, where a shared block is applied for $τ$ recurrent iterations to enable latent reasoning. Existing training-data influence estimators such as TracIn yield a single scalar score that aggregates over all loop iterations, obscuring when during the recurrent computation a training example matters. We introduce \textit{Step-Decomposed Influence (SDI)}, which decomposes TracIn into a length-$τ$ influence trajectory by unrolling the recurrent computation graph and attributing influence to specific loop iterations. To make SDI practical at transformer scale, we propose a TensorSketch implementation that never materialises per-example gradients. Experiments on looped GPT-style models and algorithmic reasoning tasks show that SDI scales efficiently, matches full-gradient baselines with low error, and supports a broad range of data attribution and interpretability tasks, offering per-step insight into the latent reasoning process.
💡 Research Summary
The paper addresses a fundamental gap in the interpretability of looped (weight‑tied or depth‑recurrent) transformers, a class of models that apply a shared transformer block repeatedly for a fixed or dynamic number of steps τ to enable latent reasoning. While existing training‑data influence estimators such as TracIn provide a single scalar quantifying how a training example affects a test prediction, they aggregate over all loop iterations and therefore hide when during the recurrent computation the influence occurs. This limitation is especially problematic for looped transformers because the loop horizon τ is a test‑time compute knob; practitioners may wish to know whether a training example shapes early processing (e.g., grounding) or later refinement (e.g., iterative solving).
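The looped (weight-tied) computation pattern can be sketched in a few lines; here the body F is an illustrative tanh map standing in for a full transformer block, and the names and dimensions are ours, not the paper's model:

```python
import numpy as np

def looped_forward(h0, e_inj, W, b, tau):
    """Apply one shared (weight-tied) body F for tau iterations:
    h_t = F(h_{t-1}, e_inj; w_body). Here F is an illustrative
    tanh(W @ h + e_inj + b) stand-in for a transformer block."""
    h = h0
    trajectory = [h]
    for _ in range(tau):               # the same W, b are reused at every step
        h = np.tanh(W @ h + e_inj + b)
        trajectory.append(h)
    return h, trajectory               # final state plus all intermediate states

rng = np.random.default_rng(0)
d = 8
W = rng.normal(scale=0.3, size=(d, d))
b = np.zeros(d)
h0 = rng.normal(size=d)
e_inj = rng.normal(size=d)            # injected input embedding, re-fed each step

h_final, traj = looped_forward(h0, e_inj, W, b, tau=6)
```

Because the parameters are shared across all τ applications, the gradient with respect to them is a sum of per-step contributions, which is exactly the structure SDI exploits.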
Key Contributions
- Step‑Decomposed Influence (SDI) – The authors formalize a step‑wise decomposition of TracIn. By unrolling the total loss gradient with respect to the recurrent body parameters w_body into a sum over τ steps and L tokens (Proposition 1), they define a per‑step influence term Iₜ(z, z′) that measures the dot product between the full training‑example gradient and the t‑th step gradient of a test example. Summing Iₜ over t exactly recovers the original TracIn score, establishing a lossless conservation identity. This yields a length‑τ influence trajectory SDI(z, z′) = (I₁,…,I_τ).
- Sketch‑During‑Backprop Implementation – Computing and storing the full per‑example gradients ϕₜ for every step, checkpoint, and example is infeasible at transformer scale. The authors introduce a memory‑efficient pipeline that sketches gradients on the fly during back‑propagation, never materialising the high‑dimensional vectors. Vector‑valued parameters (biases, layer‑norm scalars) are sketched with CountSketch (CS), while matrix‑valued parameters (attention and MLP weight matrices) are sketched with TensorSketch (TS), which efficiently sketches outer products u ⊗ v by convolving the CountSketches of u and v. Because both CS and TS are linear, the sketched inner products preserve expectations and have provably tighter variance bounds than prior outer‑product sketches. The global sketch Sₘ concatenates independent CS/TS maps per tensor, yielding an αm‑dimensional vector (α = number of parameter tensors).
- Theoretical Guarantees – The paper proves that TS provides an unbiased estimator of the outer product with variance strictly lower than earlier methods, ensuring accurate influence estimation even with modest sketch dimensions (e.g., m = 1024).
- Empirical Validation – Experiments span three regimes: (a) recovering finite‑state automata circuits in a parity task, (b) linking test‑time compute scaling to late‑stage influence in Sudoku, and (c) revealing an implicit geometric growth of influence in a 330 M‑parameter looped LLM (NanoChat). Across all settings, SDI matches full‑gradient TracIn with mean absolute error < 0.02 and Pearson correlation > 0.99, while offering per‑step insight unavailable to scalar methods. Notably, SDI uncovers signal cancellation (early positive, late negative influence) and identifies an “influence horizon” beyond which additional loop steps cease to affect the prediction, enabling compute savings.
- Practical Applications – The step‑resolved trajectory enables (i) calibrating test‑time compute by detecting when influence plateaus, (ii) depth‑targeted data curation (selecting examples that affect specific loop phases), and (iii) fine‑grained model debugging (visualising which training points drive particular reasoning steps).
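The CS/TS machinery behind the implementation can be illustrated in a few lines of numpy. This is a minimal sketch under our own hash construction and names, not the paper's code: CountSketch hashes coordinates into m buckets with random signs, and TensorSketch sketches an outer product u ⊗ v by circularly convolving the two CountSketches (done here in the Fourier domain), so that inner products of sketches estimate inner products of the original outer products.

```python
import numpy as np

def make_countsketch(d, m, rng):
    """Random CountSketch map: bucket hashes h and sign flips s."""
    h = rng.integers(0, m, size=d)
    s = rng.choice([-1.0, 1.0], size=d)
    def cs(x):
        out = np.zeros(m)
        np.add.at(out, h, s * x)   # scatter-add signed coordinates into buckets
        return out
    return cs

def make_tensorsketch(d1, d2, m, rng):
    """TensorSketch of u (x) v: circular convolution of two independent
    CountSketches, computed via the FFT."""
    cs_u = make_countsketch(d1, m, rng)
    cs_v = make_countsketch(d2, m, rng)
    def ts(u, v):
        return np.fft.irfft(np.fft.rfft(cs_u(u)) * np.fft.rfft(cs_v(v)), n=m)
    return ts

rng = np.random.default_rng(0)
d, m = 64, 4096
ts = make_tensorsketch(d, d, m, rng)

u, v = rng.normal(size=d), rng.normal(size=d)
up, vp = rng.normal(size=d), rng.normal(size=d)

# <TS(u ⊗ v), TS(u' ⊗ v')> estimates <u ⊗ v, u' ⊗ v'> = <u, u'> <v, v'>
est = ts(u, v) @ ts(up, vp)
exact = (u @ up) * (v @ vp)
```

Linearity of both maps is what lets per-token contributions δ ⊗ a be sketched and summed on the fly instead of materialising the full gradient matrices.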
Methodology Details
- The recurrent body updates hidden states hₜ = F(hₜ₋₁, e_inj; w_body), where e_inj is the injected input embedding. The loss gradient ∂ℓ/∂w_body is expressed as Σₜ ϕₜ, where each ϕₜ aggregates token‑wise contributions δₜ,ⱼ ⊗ aₜ,ⱼ (for matrix weights) and δₜ,ⱼ (for biases).
- Step‑decomposed influence: Iₜ(z, z′) = Σₖ ηₖ ∇_{w_body} ℓ(wₖ; z) · ϕₜ(wₖ; z′), where the sum runs over training checkpoints wₖ with learning rates ηₖ, as in TracIn.
- Sketching: For each token j and step t, the forward activation aₜ,ⱼ and back‑prop signal δₜ,ⱼ are captured; TS(δₜ,ⱼ, aₜ,ⱼ) is computed and summed across tokens, yielding a sketched version of ϕₜ. The training‑example gradient is similarly sketched via CS/TS during the backward pass of the training step. The final SDI estimate is the inner product of the two sketched vectors, which, by linearity, equals the true inner product in expectation.
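The decomposition and its conservation identity can be checked end to end on a toy weight-tied recurrence. This is our own simplification, not the paper's setup: the body is a linear map hₜ = W hₜ₋₁ + e, there is a single checkpoint with η = 1, and gradients are materialised rather than sketched. Each ϕₜ is the step-t outer product δₜ ⊗ hₜ₋₁, and summing the per-step influences recovers the plain TracIn dot product exactly.

```python
import numpy as np

def forward(W, e, h0, tau):
    hs = [h0]
    for _ in range(tau):
        hs.append(W @ hs[-1] + e)     # weight-tied linear body
    return hs

def per_step_grads(W, e, h0, y, tau):
    """phi_t = delta_t (x) h_{t-1}: the step-t contribution to dl/dW
    for l = 0.5 * ||h_tau - y||^2 under weight tying."""
    hs = forward(W, e, h0, tau)
    delta = hs[-1] - y                 # gradient at the final state
    phis = []
    for t in range(tau, 0, -1):        # backprop through the unrolled loop
        phis.append(np.outer(delta, hs[t - 1]))
        delta = W.T @ delta            # carry the signal to the previous step
    return phis[::-1]                  # phi_1, ..., phi_tau

rng = np.random.default_rng(1)
d, tau = 6, 5
W = rng.normal(scale=0.3, size=(d, d))

def example():
    return rng.normal(size=d), rng.normal(size=d), rng.normal(size=d)

e_tr, h0_tr, y_tr = example()          # training example z
e_te, h0_te, y_te = example()          # test example z'

g_train = sum(per_step_grads(W, e_tr, h0_tr, y_tr, tau))   # full dl/dW at z
phis_test = per_step_grads(W, e_te, h0_te, y_te, tau)

# SDI trajectory: I_t = <grad(z), phi_t(z')>, one entry per loop step
sdi = np.array([np.sum(g_train * phi) for phi in phis_test])

# Conservation identity: the trajectory sums to the scalar TracIn score
tracin = np.sum(g_train * sum(phis_test))
```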
Limitations and Future Work
- The current formulation assumes that the recurrent body consists of linear maps (weight matrices plus biases) where per‑example gradients factor as outer products. Extending to more complex non‑linear modules (e.g., gating mechanisms, dynamic routing) may require alternative sketching strategies.
- The approach focuses on static τ; handling dynamically sampled τ or input‑dependent loop depths would involve additional bookkeeping to align step indices across training and test trajectories.
- While the paper demonstrates utility on algorithmic reasoning tasks and a mid‑scale LLM, scaling to multi‑billion‑parameter models and real‑world corpora remains an open engineering challenge.
Impact
SDI provides the first principled, lossless decomposition of training‑data influence across the temporal dimension of recurrent transformer computation. By coupling this analytical tool with a scalable sketch‑based implementation, the authors make step‑resolved data attribution feasible at transformer scale. This opens new avenues for interpretability, data‑centric model improvement, and compute‑efficient inference in the rapidly growing ecosystem of looped and recurrent language models.