WildCat: Near-Linear Attention in Theory and Practice
We introduce WildCat, a high-accuracy, low-cost approach to compressing the attention mechanism in neural networks. While attention is a staple of modern network architectures, it is also notoriously expensive to deploy due to resource requirements that scale quadratically with the input sequence length $n$. WildCat avoids these quadratic costs by only attending over a small weighted coreset. Crucially, we select the coreset using a fast but spectrally-accurate subsampling algorithm – randomly pivoted Cholesky – and weight the elements optimally to minimise reconstruction error. Remarkably, given bounded inputs, WildCat approximates exact attention with super-polynomial $O(n^{-\sqrt{\log(\log(n))}})$ error decay while running in near-linear $O(n^{1+o(1)})$ time. In contrast, prior practical approximations either lack error guarantees or require quadratic runtime to guarantee such high fidelity. We couple this advance with a GPU-optimized PyTorch implementation and a suite of benchmark experiments demonstrating the benefits of WildCat for image generation, image classification, and language model KV cache compression.
💡 Research Summary
The paper introduces WildCat, a novel method for compressing the soft‑max attention mechanism that simultaneously achieves near‑linear computational complexity and super‑polynomial approximation accuracy. Traditional attention requires $Θ(n^2d)$ operations and $Θ(n^2)$ memory, which becomes prohibitive for long sequences. Existing fast approximations—sparse attention (e.g., Reformer), low‑rank kernel methods (e.g., Performer), or hybrid schemes—either lack strong theoretical error guarantees or need quadratic time to obtain high fidelity.
WildCat’s core idea is to replace the full set of $n$ keys with a small weighted coreset of size $r\ll n$. The coreset is selected using a Randomly Pivoted Cholesky (RPC) algorithm, which repeatedly samples a pivot index proportional to the diagonal of the current residual kernel $h^{\text{res}}_r(k_l,k_l)$. After each pivot, the inverse of the kernel matrix on the selected subset is updated via a rank‑one formula, and the residual diagonal is refreshed. This process accesses only $O(nr)$ kernel entries and runs in $O(nr^2+ndr)$ time; when $r=o(n)$ the overall cost becomes $O(n^{1+o(1)})$, i.e., almost linear in the sequence length.
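The sampling loop described above can be sketched in a few lines of NumPy. This is a minimal illustration under our own naming (`rpcholesky`, `S`, `F` are not from the paper), of residual-diagonal pivot sampling with rank-one Cholesky updates; the paper's GPU implementation will differ:

```python
import numpy as np

def rpcholesky(K, r, beta, rng=None):
    """Randomly pivoted Cholesky on the kernel H[i, j] = exp(beta * k_i . k_j).

    Samples r pivot indices proportionally to the residual diagonal,
    touching only O(n * r) kernel entries. Returns the pivot set S and a
    factor F with F @ F.T approximating H. (Illustrative sketch only.)
    """
    rng = np.random.default_rng(rng)
    n = K.shape[0]
    # Diagonal of H: h(k_l, k_l) = exp(beta * ||k_l||^2).
    d = np.exp(beta * np.sum(K * K, axis=1))
    F = np.zeros((n, r))
    S = []
    for i in range(r):
        p = rng.choice(n, p=d / d.sum())       # sample pivot ∝ residual diagonal
        col = np.exp(beta * (K @ K[p]))        # one column of the kernel matrix H
        col -= F[:, :i] @ F[p, :i]             # subtract the current approximation
        F[:, i] = col / np.sqrt(col[p])        # rank-one Cholesky update
        d = np.maximum(d - F[:, i] ** 2, 0.0)  # refresh residual diagonal
        S.append(p)
    return np.array(S), F
```

Sampling by the residual diagonal means already-covered directions get vanishing probability, which is what drives the spectral accuracy of the selected coreset.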
Once the coreset $K_S$ is obtained, WildCat applies the Nyström method to construct optimal weights $W = h(K_S,K_S)^{+}h(K_S,K)$. The resulting low‑rank approximation $\tilde A = h(Q,K_S)W$ approximates the exact attention matrix $A = \exp(\beta QK^\top)$. Lemma 1 shows that if $\tilde A$ is close to $A$ in the row‑wise $(2,\infty)$ norm, then the final soft‑max output $O$ is close to its exact counterpart, with a bound proportional to $\|V\|_{\max}$. Lemma 2 provides a row‑wise error bound for the Nyström approximation in terms of the operator norm of the residual kernel $h^{\text{res}}(K,K)$.
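The Nyström step can be written directly from these formulas. The sketch below (our own naming; the paper's batched GPU version will differ) forms $W$ with a pseudoinverse and assembles the low-rank factor $\tilde A = h(Q,K_S)W$:

```python
import numpy as np

def nystrom_attention(Q, K, S, beta):
    """Nyström low-rank approximation of the unnormalized attention matrix
    A = exp(beta * Q K^T), using coreset indices S. (Hedged sketch.)
    """
    KS = K[S]                                  # coreset keys, shape (r, d)
    # Optimal weights W = h(K_S, K_S)^+ h(K_S, K), shape (r, n).
    W = np.linalg.pinv(np.exp(beta * KS @ KS.T)) @ np.exp(beta * KS @ K.T)
    # Low-rank approximation tilde A = h(Q, K_S) W, shape (m, n).
    return np.exp(beta * Q @ KS.T) @ W
```

When $S$ contains every index, $W$ reduces to the identity and $\tilde A = A$ exactly; accuracy for small $r$ rests on the spectral decay of $h(K,K)$ discussed next.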
Theoretical guarantees hinge on the spectral approximability of the kernel matrix $H = h(K,K)$. Theorem 1 links the expected operator error of the RPC‑Nyström estimator to any low‑rank matrix $T\preceq H$, showing that $\mathbb{E}\|H-\tilde H_r\|_{\text{op}}\le\varepsilon$ whenever $r$ exceeds the effective rank of $T$ plus a logarithmic factor. By approximating the exponential kernel with its Taylor expansion of order $s$, one obtains a rank‑$O(s+d)$ matrix $T_s$ whose trace error decays as $n\exp(\beta\|K\|_{2,\infty}^2)e^{-\beta\|K\|_{2,\infty}^2(s+1)}$. Consequently, choosing $r = O(\log n)$ yields an overall attention error of $O\bigl(n^{-\sqrt{\log\log n}}\bigr)$—a super‑polynomial decay—while retaining near‑linear runtime.
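To see why the rate $n^{-\sqrt{\log\log n}}$ is super‑polynomial, note that its exponent grows without bound:

$$
n^{-\sqrt{\log\log n}} \;=\; \exp\!\bigl(-\sqrt{\log\log n}\,\log n\bigr) \;\le\; n^{-c}
\quad\text{for every constant } c>0 \text{ and all sufficiently large } n,
$$

since $\sqrt{\log\log n}\to\infty$. The error therefore eventually beats any fixed polynomial decay rate, albeit very slowly.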
Implementation details include recentring keys (subtracting the mean) and scaling queries/keys by $\beta=1/\sqrt d$ to improve numerical stability. The algorithm computes $h(K_S,K)$ and $h(Q,K_S)$ in a single batched matrix multiplication on GPU, then forms the final output as $D^{-1}UWV$, where $D$ is the diagonal normalizer. Memory usage is $O((m+n)(r+d))$, far lower than the $O(mn)$ required by exact attention.
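Putting the pieces together, a minimal end-to-end sketch of the normalized output $O = D^{-1}UWV$ looks as follows (again under our own naming; this is plain NumPy, not the paper's batched GPU kernel). Recentring the keys rescales each row of the attention matrix by a constant, which cancels in the soft‑max normalization, so the output is unchanged while the kernel entries stay better conditioned:

```python
import numpy as np

def wildcat_output(Q, K, V, S, beta):
    """Approximate soft-max attention output O = D^{-1} U W V, where
    U = h(Q, K_S), W = h(K_S, K_S)^+ h(K_S, K), and D holds the
    approximate row sums. Never materializes the (m, n) matrix.
    (Hedged sketch of the approach described in the text.)
    """
    K = K - K.mean(axis=0)              # recentre keys for numerical stability
    KS = K[S]
    U = np.exp(beta * Q @ KS.T)         # (m, r)
    W = np.linalg.pinv(np.exp(beta * KS @ KS.T)) @ np.exp(beta * KS @ K.T)
    WV = W @ V                          # (r, d_v): fold W into V first
    denom = U @ W.sum(axis=1)           # approximate row sums D, via (m,r)@(r,)
    return (U @ WV) / denom[:, None]
```

Folding $W$ into $V$ and into the row sums before multiplying by $U$ is what keeps memory at $O((m+n)(r+d))$ rather than the $O(mn)$ of exact attention.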
Empirical evaluation spans three domains: (1) image generation using a Stable Diffusion backbone, where WildCat speeds up sampling by 1.8× and improves FID by 0.12; (2) image classification on CIFAR‑10/100 and ImageNet‑1k, achieving 12–18 % faster training with <0.3 % accuracy loss compared to Reformer, Performer, FlashAttention, and other linear‑time baselines; (3) long‑context language modeling, where WildCat compresses the KV cache, reducing memory consumption by over 30 % while increasing perplexity by less than 0.2 %. All experiments run on NVIDIA A100 GPUs with identical batch sizes, demonstrating consistent speed‑accuracy trade‑offs.
The paper also discusses limitations. Selecting the coreset size $r$ currently requires empirical tuning; automatic, data‑driven strategies are an open problem. For very high‑dimensional inputs, the $O(nrd)$ cost may become a bottleneck on memory‑bandwidth limited hardware. The stochastic nature of RPC introduces variance; worst‑case performance could degrade, though the expected error bounds hold. Finally, the theoretical results assume bounded inputs ($\|Q\|_{2,\infty},\|K\|_{2,\infty}\le C$), so additional normalization may be needed for unbounded data.
In summary, WildCat bridges the gap between theory and practice in fast attention: it provides the first practically implementable algorithm that attains super‑polynomial error decay with near‑linear time, backed by rigorous proofs and extensive GPU‑optimized experiments. This makes it a compelling choice for large‑scale vision and language models where both speed and fidelity are critical.