State-Space Models for Tabular Prior-Data Fitted Networks
Recent advancements in foundation models for tabular data, such as TabPFN, have demonstrated that pretrained Transformer architectures can approximate Bayesian inference with high predictive performance. However, Transformers suffer from quadratic complexity with respect to sequence length, motivating the exploration of more efficient sequence models. In this work, we investigate the potential of using Hydra, a bidirectional linear-time structured state space model (SSM), as an alternative to Transformers in TabPFN. A key challenge lies in SSMs' inherent sensitivity to the order of input tokens, an undesirable property for tabular datasets, where row order is semantically meaningless. We investigate to what extent a bidirectional approach can preserve efficiency and enable symmetric context aggregation. Our experiments show that this approach reduces order dependence, achieving predictive performance competitive with the original TabPFN model.
💡 Research Summary
The paper investigates replacing the Transformer backbone of Tabular Prior‑Data Fitted Networks (TabPFN) with structured state‑space models (SSMs) that operate in linear time, focusing on two variants: Mamba and its bidirectional extension Hydra. TabPFN, a foundation model for tabular classification, achieves near‑Bayesian inference by training a Transformer on millions of synthetic tasks and then performing a single forward pass on a new dataset. While this yields impressive few‑shot performance, the self‑attention mechanism incurs O(N²) memory and compute costs with respect to the number of rows N, limiting practical use to tables of a few thousand rows.
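The quadratic attention cost can be made concrete with a back-of-the-envelope memory estimate. The sketch below is illustrative only (a single fp32 score matrix, ignoring the multiple heads and layers of a real model), but it shows why table size quickly becomes the bottleneck:

```python
# Rough memory estimate for a single self-attention score matrix over N table
# rows. Illustrative assumption: one head, fp32 (4 bytes per entry).
def attention_matrix_bytes(n_rows: int, dtype_bytes: int = 4) -> int:
    """Memory for one N x N attention score matrix."""
    return n_rows * n_rows * dtype_bytes

for n in (2**10, 2**16, 2**17):
    gib = attention_matrix_bytes(n) / 2**30
    print(f"N = {n:>6}: {gib:10.4f} GiB")
```

Even a single 2^17 x 2^17 fp32 matrix already takes 64 GiB, close to the 80 GB VRAM of a modern accelerator, while a linear-time model only needs state proportional to N.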
Mamba addresses the quadratic bottleneck by representing the sequence through recurrent state updates and quasi‑separable matrix multiplications, achieving O(N) complexity. However, Mamba is causal: it processes the sequence in a single direction, making its representations highly sensitive to the order of rows. In tabular data, row order is semantically meaningless, so this sensitivity can degrade performance and stability.
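The core of SSM-style sequence mixing is a linear recurrence that runs in O(N). The minimal sketch below uses a fixed scalar state per step; real Mamba layers use learned, input-dependent parameters, so all names and values here are illustrative:

```python
import numpy as np

def causal_scan(x: np.ndarray, a: float = 0.9, b: float = 0.1) -> np.ndarray:
    """Causal linear recurrence h_t = a * h_{t-1} + b * x_t, O(N) time."""
    h = np.zeros_like(x)
    state = 0.0
    for t in range(len(x)):
        state = a * state + b * x[t]
        h[t] = state
    return h

# The response to an impulse at position 0 decays with distance: each output
# depends on *preceding* positions only, so permuting the rows changes h.
print(causal_scan(np.array([1.0, 0.0, 0.0, 0.0])))
```

The decaying impulse response makes the order sensitivity tangible: the same rows in a different order produce different per-position states.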
Hydra builds on Mamba by adding a bidirectional mixing layer that simultaneously incorporates forward and backward context using quasi‑separable matrix mixers. This bidirectional design reduces order dependence while preserving the linear‑time advantage. The authors replace the Transformer encoder in TabPFN with a stack of Hydra layers, keep the original embedding scheme (concatenating feature values with class labels), and retrain the model on the same synthetic task distribution used for the original TabPFN. No other changes to the training pipeline are required.
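In spirit, bidirectional mixing can be sketched as one scan forward and one backward, combined per position, so every row sees context from both directions at O(N) total cost. This is a hypothetical simplification; Hydra's actual quasi-separable mixers are more general:

```python
import numpy as np

def causal_scan(x: np.ndarray, a: float = 0.9, b: float = 0.1) -> np.ndarray:
    """Causal linear recurrence h_t = a * h_{t-1} + b * x_t."""
    h, state = np.zeros_like(x), 0.0
    for t in range(len(x)):
        state = a * state + b * x[t]
        h[t] = state
    return h

def bidirectional_mix(x: np.ndarray) -> np.ndarray:
    """Combine a forward and a reversed scan; still O(N) overall."""
    fwd = causal_scan(x)
    bwd = causal_scan(x[::-1])[::-1]   # scan right-to-left, realign
    return fwd + bwd
```

With this symmetric aggregation, an impulse at either end of the sequence influences every position, rather than only the positions after it.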
To further mitigate order sensitivity, the authors propose Repeated Context Permutations (RCP). During inference, the entire table (including the test instance) is randomly permuted r times; the model predicts on each permuted version, and the resulting probability vectors are averaged. This simple ensemble‑like technique increases inference time linearly with r but empirically reduces the KL‑divergence between predictions on different permutations and modestly improves accuracy.
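The RCP procedure described above can be sketched in a few lines. Here `predict_proba` stands in for a single forward pass of the pretrained model and is an assumption for illustration, not the authors' API:

```python
import numpy as np

def rcp_predict(predict_proba, X_train, y_train, X_test, r=5, seed=0):
    """Repeated Context Permutations (sketch): permute the context rows r
    times, predict on each permutation, average class probabilities.
    `predict_proba(X_ctx, y_ctx, X_test)` is a hypothetical model interface."""
    rng = np.random.default_rng(seed)
    probs = []
    for _ in range(r):
        perm = rng.permutation(len(X_train))              # shuffle context rows
        probs.append(predict_proba(X_train[perm], y_train[perm], X_test))
    return np.mean(probs, axis=0)                          # ensemble over r runs
```

The r forward passes are independent, so the extra cost is linear in r and trivially parallelizable across permutations.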
Experiments are conducted on 30 multiclass classification datasets from the OpenML CC‑18 benchmark, each filtered to ≤ 2000 rows, ≤ 100 features, and ≤ 10 classes, with 16 random train‑test splits per dataset. The evaluation measures (1) inference time and memory consumption, (2) predictive performance (accuracy and macro‑averaged AUC), and (3) order dependence quantified by KL‑divergence between predictions on two random permutations.
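The order-dependence measure can be sketched as the mean KL divergence between the class distributions a model predicts under two random row permutations of the same table; a small `eps` clip guards against log(0). The exact averaging convention is an assumption here:

```python
import numpy as np

def mean_kl(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float:
    """Mean KL(p || q) over test instances; rows of p, q are class
    probability vectors from two permutations of the same table."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return float(np.mean(np.sum(p * np.log(p / q), axis=1)))

# A perfectly permutation-invariant model yields identical predictions and
# therefore zero divergence.
print(mean_kl(np.array([[0.7, 0.3]]), np.array([[0.7, 0.3]])))
```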
Key findings:
- Scalability – Hydra processes inputs up to 2¹⁷ rows (≈130 k) on an H100 GPU before hitting PyTorch’s 32‑bit indexing limit, whereas the Transformer fails beyond 2¹⁶ rows due to the quadratic attention matrix exceeding the 80 GB VRAM. Mamba shows similar scaling to Hydra.
- Predictive performance – Hydra’s average accuracy is within 1.1 % of the Transformer baseline, and its variance across datasets is lower than Mamba’s. On several datasets Hydra even outperforms the Transformer. Mamba exhibits higher variance and a slightly larger performance gap, confirming that bidirectional context is beneficial for tabular inference.
- Order sensitivity – Without RCP, KL‑divergence between predictions on two random permutations is noticeably higher for Mamba than for Hydra, reflecting the latter’s reduced order bias. Applying RCP (r = 5) cuts KL‑divergence substantially for both models and yields a modest accuracy gain (≈0.5 %). AUC does not show a consistent improvement, suggesting that the benefit is primarily in stabilizing class probability estimates.
The authors compare their approach to other methods that alleviate Transformer memory pressure, such as FlashAttention‑2 and linear attention, noting that SSMs attack the root cause (quadratic complexity) rather than merely optimizing the attention kernel.
In conclusion, the study demonstrates that linear‑time SSMs, particularly the bidirectional Hydra, can serve as a practical drop‑in replacement for Transformers in TabPFN, enabling processing of much larger tables while preserving near‑state‑of‑the‑art predictive quality. The RCP technique further enhances robustness to arbitrary row ordering. Future work is suggested on (i) extending experiments to truly large tables (>10 k rows), (ii) exploring optimal row orderings that could further boost SSM performance, and (iii) hybridizing SSMs with advanced attention‑speedup tricks for even greater efficiency.