SOCKET: SOft Collision Kernel EsTimator for Sparse Attention
Exploiting sparsity during long-context inference is central to scaling large language models, as attention dominates the cost of autoregressive decoding. Sparse attention reduces this cost by restricting computation to a subset of tokens, but its effectiveness depends critically on efficient scoring and selection of relevant tokens at inference time. We revisit Locality-Sensitive Hashing (LSH) as a sparsification primitive and introduce SOCKET, a SOft Collision Kernel EsTimator that replaces hard bucket matches with probabilistic, similarity-aware aggregation. Our key insight is that hard LSH produces discrete collision signals and is therefore poorly suited for ranking. In contrast, soft LSH aggregates graded collision evidence across hash tables, preserving the stability of the relative ordering among the true top-$k$ tokens. This transformation elevates LSH from a candidate-generation heuristic to a principled and mathematically grounded scoring kernel for sparse attention. Leveraging this property, SOCKET enables efficient token selection without ad-hoc voting mechanisms, and matches or surpasses established sparse attention baselines across multiple long-context benchmarks using a diverse set of models. With a custom CUDA kernel for scoring keys and a Flash Decode Triton backend for sparse attention, SOCKET achieves up to 1.5$\times$ higher throughput than FlashAttention, making it an effective tool for long-context inference. Code is open-sourced at https://github.com/amarka8/SOCKET.
💡 Research Summary
The paper introduces SOCKET, a novel sparse‑attention mechanism that replaces traditional hard Locality‑Sensitive Hashing (LSH) with a soft, probabilistic version to rank keys for long‑context inference. In the pre‑fill phase, each key vector is projected with L independent random matrices and assigned to a single bucket per hash table, storing only the bucket IDs and the L2 norm of the corresponding value vectors. During decoding, a query vector is softly hashed: after a tanh transformation, the query’s projection is compared against all possible corner vectors, and a softmax (controlled by temperature τ) yields a probability distribution over buckets for each table. The soft LSH score for a key is the sum across tables of the probability mass assigned to the key’s bucket. Multiplying this score by the value norm produces a final weight b_w, and the top‑k keys by b_w are selected for exact attention computation.
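The scoring pipeline above can be sketched in NumPy. This is a minimal illustration, not the paper's implementation: it assumes sign-of-projection (hyperplane) bucketing, and the dimensions (`d`, `b`, `L`, `tau`) and variable names are hypothetical choices for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
d, b, L = 64, 4, 8            # head dim, bits per table, number of hash tables
n_keys, tau, k = 32, 0.5, 8   # cache size, softmax temperature, top-k budget

# L independent random projection matrices (assumed hyperplane LSH)
proj = rng.standard_normal((L, d, b))

# --- pre-fill: hard-hash each key to a single bucket per table ---
keys = rng.standard_normal((n_keys, d))
values = rng.standard_normal((n_keys, d))
key_bits = np.einsum('nd,ldb->lnb', keys, proj) > 0       # sign bits, (L, n, b)
key_buckets = key_bits @ (1 << np.arange(b))              # bucket IDs, (L, n)
value_norms = np.linalg.norm(values, axis=-1)             # stored with bucket IDs

# all 2^b bucket "corner" vectors in {-1, +1}^b
corners = np.array([[1 if (c >> i) & 1 else -1 for i in range(b)]
                    for c in range(2 ** b)])              # (2^b, b)

# --- decode: softly hash the query ---
query = rng.standard_normal(d)
q_proj = np.tanh(np.einsum('d,ldb->lb', query, proj))     # tanh-squashed projection
logits = q_proj @ corners.T / tau                         # similarity to each corner
probs = np.exp(logits - logits.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)                     # per-table bucket distribution

# soft collision score: probability mass at each key's bucket, summed over tables
score = probs[np.arange(L)[:, None], key_buckets].sum(0)  # (n_keys,)
b_w = score * value_norms                                 # final ranking weight
topk = np.argsort(-b_w)[:k]                               # keys kept for exact attention
```

Note how the temperature `tau` controls the hard/soft trade-off: as `tau` shrinks, each per-table distribution concentrates on the query's own bucket and the score degenerates toward a hard collision count.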
The authors prove (Theorem 3) that the soft LSH kernel closely approximates the softmax kernel, guaranteeing that the selected top‑k tokens capture most of the true attention mass. Implementation-wise, a custom CUDA kernel computes all soft bucket probabilities and b_w scores in a single pass, avoiding full key reads and dramatically reducing memory traffic. The selected keys are then processed by a FlashDecode Triton backend, achieving up to 1.5× higher throughput than FlashAttention while adding only a few bits per token beyond the KV cache.
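The guarantee in Theorem 3 concerns how much of the true attention mass the selected tokens retain. A small NumPy sketch of that final step, assuming the indices have already been ranked (here by the exact weights, purely for illustration; SOCKET would rank by its soft-collision weight `b_w`):

```python
import numpy as np

rng = np.random.default_rng(1)
d, n, k = 64, 32, 8

q = rng.standard_normal(d)
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

def sparse_attention(q, K, V, idx):
    """Exact softmax attention restricted to the selected token indices."""
    Ks, Vs = K[idx], V[idx]
    logits = Ks @ q / np.sqrt(q.shape[-1])
    w = np.exp(logits - logits.max())
    w /= w.sum()
    return w @ Vs

# full-attention weights, used only to measure captured mass
full = np.exp(K @ q / np.sqrt(d))
full /= full.sum()

idx = np.argsort(-full)[:k]       # stand-in for SOCKET's top-k selection
out = sparse_attention(q, K, V, idx)
mass = full[idx].sum()            # fraction of true attention mass retained
```

The quantity `mass` is what the theorem bounds: if the ranking preserves the true top-$k$, the sparse output closely approximates full attention while reading only $k$ of the $n$ cached keys.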
Extensive experiments on Llama‑3.1‑8B‑Instruct, Qwen3‑8B, and several long‑context benchmarks (RULER‑32K, etc.) show that SOCKET matches or exceeds the ranking quality of prior sparse‑attention methods such as Quest, Double‑Sparsity, PQCache, and MagicPig, while consistently delivering higher throughput and lower memory overhead. The work demonstrates that data‑agnostic random projections combined with soft collision scoring provide a principled, efficient, and theoretically grounded alternative to heuristic or learned sparsification techniques, advancing the state of the art in scalable transformer inference.