BLASST: Dynamic BLocked Attention Sparsity via Softmax Thresholding
The growing demand for long-context inference capabilities in Large Language Models (LLMs) has intensified the computational and memory bottlenecks inherent to the standard attention mechanism. To address this challenge, we introduce BLASST, a drop-in sparse attention method that dynamically prunes the attention matrix without any pre-computation or proxy scores. Our method uses a fixed threshold and existing information from online softmax to identify negligible attention scores, skipping softmax computation, Value block loading, and the subsequent matrix multiplication. This fits seamlessly into existing FlashAttention kernel designs with negligible latency overhead. The approach is applicable to both prefill and decode stages across all attention variants (MHA, GQA, MQA, and MLA), providing a unified solution for accelerating long-context inference. We develop an automated calibration procedure that reveals a simple inverse relationship between optimal threshold and context length, enabling robust deployment across diverse scenarios. Maintaining high accuracy, we demonstrate a 1.62x speedup for prefill at 74.7% sparsity and a 1.48x speedup for decode at 73.2% sparsity on modern GPUs. Furthermore, we explore sparsity-aware training as a natural extension, showing that models can be trained to be inherently more robust to sparse attention patterns, pushing the accuracy-sparsity frontier even further.
💡 Research Summary
The paper introduces BLASST (Dynamic Blocked Attention Sparsity via Softmax Thresholding), a lightweight, drop‑in modification to FlashAttention that dynamically prunes attention blocks during the forward pass without any pre‑computation, proxy scores, or architectural changes. The key observation is that FlashAttention already computes a running maximum of attention scores while processing blocks. If the local maximum of a candidate block falls below the running maximum by more than |ln λ|, where λ < 1 is a fixed threshold, every unnormalized softmax weight exp(s − m) in that block is below λ and contributes almost nothing to the output. Consequently, BLASST skips three expensive operations for such blocks: (1) the exp(·) computation and row‑sum reduction, (2) loading the corresponding value (V) block from high‑bandwidth memory, and (3) the matrix multiplication of the softmax weights with the V block. This decision requires only a single comparison per block and reuses the row‑max statistics that the kernel already computes, adding virtually no overhead.
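The reason a skipped block is negligible can be written as a one-line bound. The notation here is an assumption chosen to match the summary's description: m̃ is the maximum score inside the candidate block, m is the running maximum after including that block, and s is any score in the block.

```latex
% Skip condition implies every unnormalized weight in the block is below lambda.
\[
  \tilde{m} - m < \ln\lambda
  \quad\Longrightarrow\quad
  e^{\,s - m} \;\le\; e^{\,\tilde{m} - m} \;<\; \lambda
  \qquad \text{for every score } s \text{ in the block.}
\]
```

Because the running maximum can only grow as later blocks are processed, the same bound holds relative to the final row maximum.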
Algorithm 1 shows the integration: after computing the block’s attention scores Sᵢⱼ = QᵢKⱼᵀ, the algorithm computes the block’s local max m̃ᵢⱼ, updates the running max mᵢⱼ, and checks whether m̃ᵢⱼ − mᵢⱼ < ln λ. If true, the block is skipped; otherwise the usual softmax weight computation, accumulation, and output update proceed. Because the decision is made at block granularity, the method works with all common attention variants (multi‑head, grouped‑query, multi‑query, and multi‑head latent attention) and with sliding‑window attention.
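The skip rule can be illustrated with a small NumPy sketch of blocked online-softmax attention. This is not the paper's Algorithm 1 or its CUDA kernel: the function names, the block sizes, and the choice to skip a tile only when every query row in it satisfies the condition are assumptions made for illustration.

```python
import numpy as np

def dense_attention(Q, K, V):
    """Reference softmax attention, used only to check the sketch."""
    s = (Q @ K.T) / np.sqrt(Q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ V

def blasst_attention(Q, K, V, lam, block_q=4, block_k=4):
    """Blocked online-softmax attention with threshold-based tile skipping.

    Returns (output, fraction of tiles skipped). Hypothetical helper:
    a real kernel would make this decision on-chip, per tile.
    """
    n_q, d = Q.shape
    n_k = K.shape[0]
    scale = 1.0 / np.sqrt(d)
    log_lam = np.log(lam) if lam > 0 else -np.inf  # lam == 0 disables skipping
    out = np.zeros((n_q, V.shape[1]))
    skipped = total = 0

    for i in range(0, n_q, block_q):
        q = Q[i:i + block_q]
        m = np.full(q.shape[0], -np.inf)           # running row max
        l = np.zeros(q.shape[0])                   # running row sum of exp
        acc = np.zeros((q.shape[0], V.shape[1]))   # unnormalized output
        for j in range(0, n_k, block_k):
            total += 1
            s = (q @ K[j:j + block_k].T) * scale
            m_local = s.max(axis=1)                # local max of this tile
            m_new = np.maximum(m, m_local)
            # Assumption: skip only if every row is below the threshold.
            # In that case m_new == m, so no state needs updating.
            if np.all(m_local - m_new < log_lam):
                skipped += 1
                continue                           # no exp, no V load, no matmul
            p = np.exp(s - m_new[:, None])
            alpha = np.exp(m - m_new)              # rescale earlier partial sums
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ V[j:j + block_k]
            m = m_new
        out[i:i + block_q] = acc / l[:, None]
    return out, skipped / total
```

The first key block of each row group is never skipped, because its local max equals the running max and the difference is 0, which is not below a negative ln λ. Setting `lam=0` turns the sketch back into ordinary blocked attention, which makes it easy to compare against `dense_attention`.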
To make the method practical across varying sequence lengths, the authors conduct an extensive calibration study on Llama‑3.1‑8B using the RULER benchmark across context lengths from 8K to 64K tokens. They find that accuracy degradation is primarily a function of sparsity ratio, not of dataset or length: performance remains stable up to roughly 60‑70% sparsity, after which it drops sharply. However, achieving a fixed sparsity requires different λ values for different lengths; empirically, λ follows an inverse relationship with context length L: λ = a / L, where a is a model‑specific constant. Algorithm 2 automates the calibration: for a target sparsity S, it searches for the λ that yields S at each of several lengths, fits a linear regression on (1/L, λ) to obtain a, and then uses λ(L) = a/L at inference time. This ensures predictable sparsity and thus predictable speedup across deployments.
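The calibration loop described above can be sketched in a few lines. Everything here is an assumption for illustration rather than the paper's Algorithm 2: the helper names, the log-space bisection, the no-intercept least-squares fit, and especially the toy sparsity model, which stands in for actually running the kernel on calibration data.

```python
import numpy as np

def search_lambda(measure_sparsity, target, lo=1e-12, hi=1.0, iters=60):
    """Bisect in log space for the lambda whose sparsity reaches `target`.

    Assumes measure_sparsity(lam) is non-decreasing in lam, since a larger
    lambda makes the skip condition easier to satisfy.
    """
    for _ in range(iters):
        mid = np.sqrt(lo * hi)          # geometric midpoint
        if measure_sparsity(mid) < target:
            lo = mid
        else:
            hi = mid
    return float(np.sqrt(lo * hi))

def fit_threshold_scale(lengths, lambdas):
    """Least-squares fit of lambda = a * (1 / L), with no intercept term."""
    x = 1.0 / np.asarray(lengths, dtype=float)
    y = np.asarray(lambdas, dtype=float)
    return float(x @ y / (x @ x))

def threshold_for_length(a, length):
    """Inference-time threshold from the fitted constant."""
    return a / length

# Toy stand-in for "run the kernel at length L and measure sparsity".
# The functional form is made up; it is NOT a measurement from the paper.
def toy_sparsity(lam, length):
    return min(1.0, lam * length / 100.0)

lengths = [8192, 16384, 32768, 65536]
lambdas = [search_lambda(lambda lam, L=L: toy_sparsity(lam, L), target=0.5)
           for L in lengths]
a = fit_threshold_scale(lengths, lambdas)
```

With a real model, `toy_sparsity` would be replaced by a measurement over calibration prompts, and the fitted `a` would then be reused for lengths outside the calibrated set through `threshold_for_length`.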
Performance results on modern GPUs (NVIDIA H200, B200) show up to 1.62× speedup for the prefill phase at 74.7% sparsity and 1.48× speedup for the decode phase at 73.2% sparsity, with less than 1% absolute accuracy loss. The prefill kernel benefits mainly from reduced CUDA core and Tensor Core usage (compute‑bound), while the decode kernel gains from lower memory bandwidth consumption (memory‑bound) by skipping V‑block loads.
Beyond inference, the authors explore sparsity‑aware training: during training they apply the same block‑masking rule (with a fixed λ) so the model learns to be robust to the sparsity pattern. This training regime pushes the accuracy‑sparsity frontier further, allowing higher sparsities (≈80%) with minimal degradation, demonstrating that BLASST can be used both as a post‑training optimizer and as a training‑time regularizer.
In summary, BLASST contributes: (1) a near‑zero‑overhead, online sparsity decision based solely on running maxima, (2) an automated calibration method that maps context length to a suitable λ, (3) specialized CUDA kernels for both prefill and decode that achieve substantial real‑world speedups, and (4) a sparsity‑aware training extension that further improves robustness. The method is hardware‑friendly, requires no model re‑architecting, and can be dropped into existing FlashAttention‑based inference stacks to accelerate long‑context LLM inference.