FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling
Attention, as a core layer of the ubiquitous Transformer architecture, is the bottleneck for large language models and long-context applications. While FlashAttention-3 optimized attention for Hopper GPUs through asynchronous execution and warp specialization, it primarily targets the H100 architecture. The AI industry has rapidly transitioned to deploying Blackwell-based systems such as the B200 and GB200, which exhibit fundamentally different performance characteristics due to asymmetric hardware scaling: tensor core throughput doubles while other functional units (shared memory bandwidth, exponential units) scale more slowly or remain unchanged. We develop several techniques to address these shifting bottlenecks on Blackwell GPUs: (1) redesigned pipelines that exploit fully asynchronous MMA operations and larger tile sizes, (2) software-emulated exponential and conditional softmax rescaling that reduces non-matmul operations, and (3) leveraging tensor memory and the 2-CTA MMA mode to reduce shared memory traffic and atomic adds in the backward pass. We demonstrate that our method, FlashAttention-4, achieves up to 1.3$\times$ speedup over cuDNN 9.13 and 2.7$\times$ over Triton on B200 GPUs with BF16, reaching up to 1613 TFLOPs/s (71% utilization). Beyond algorithmic innovations, we implement FlashAttention-4 entirely in CuTe-DSL embedded in Python, achieving 20-30$\times$ faster compile times compared to traditional C++ template-based approaches while maintaining full expressivity.
💡 Research Summary
FlashAttention‑4 addresses the emerging performance imbalance on NVIDIA’s Blackwell‑based datacenter GPUs (B200/GB200), where tensor‑core throughput has doubled while shared‑memory bandwidth, exponential‑unit (MUFU) capacity, and general‑purpose ALU resources have not kept pace. The authors begin with a detailed roofline‑style analysis showing that, for typical attention workloads, the dominant bottlenecks shift from matrix‑multiply (MMA) to shared‑memory traffic and element‑wise exponential operations, which can consume 25‑60 % of execution cycles.
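The cycle-share argument above can be made concrete with a toy model: per attention score there are roughly 4·d MMA FLOPs (one 2·d dot product for Q·Kᵀ and another for P·V) and a single exponential on the MUFU. The per-cycle MMA throughputs below are assumed values for illustration only (the 16 exp/cycle figure is the one stated in this summary; the rest are not measured B200 specs):

```python
def cycle_shares(d, mma_flops_per_cycle, exp_ops_per_cycle=16):
    # Per attention score (one entry of softmax(Q K^T) V): roughly
    # 2*d FLOPs in the Q K^T MMA, 2*d FLOPs in the P V MMA, and one
    # exponential on the MUFU. Returns (mma_share, exp_share) of cycles.
    mma_cycles = 4 * d / mma_flops_per_cycle
    exp_cycles = 1 / exp_ops_per_cycle
    total = mma_cycles + exp_cycles
    return mma_cycles / total, exp_cycles / total

# Assumed per-SM MMA throughputs; doubling the tensor cores while the
# MUFU stays at 16 exp/cycle pushes up the exponential's cycle share.
hopper_like = cycle_shares(d=128, mma_flops_per_cycle=2048)
blackwell_like = cycle_shares(d=128, mma_flops_per_cycle=4096)
```

Under these assumed numbers the exponential's share of cycles grows from 20 % to 33 % when MMA throughput doubles, illustrating (not reproducing) the bottleneck shift the analysis describes.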
To reclaim the performance headroom, the paper proposes a co‑design of algorithmic changes and kernel implementation that explicitly targets these non‑MMA resources. The key innovations are:
- Fully asynchronous pipeline with larger tiles – Blackwell's 128 × 128 (or 256) MMA tiles and the new 256 KB per‑SM tensor memory (TMEM) allow accumulation results to be written directly to TMEM without register staging. Two warp groups per CTA take producer and consumer roles: while one issues MMA instructions, the other concurrently performs softmax, exponential, and reduction work, maximizing overlap between the compute and memory stages.
- Software‑emulated exponential via polynomial approximation – Since the MUFU still delivers only 16 exponential ops per cycle, the authors replace hardware exponentials with a low‑degree Chebyshev‑style polynomial evaluated with fused multiply‑add (FMA) instructions on the regular floating‑point units. Conditional softmax rescaling further skips unnecessary rescaling steps, cutting overall exponential work by roughly 30 %.
- Tensor‑memory‑backed backward pass and 2‑CTA MMA mode – For the backward pass, intermediate gradients (dQ, dK, dV) are stored in TMEM, halving shared‑memory reads and writes. The 2‑CTA mode lets a pair of CTAs cooperatively execute a single large MMA, each loading half of the B operand into shared memory, which reduces shared‑memory bandwidth pressure and halves the number of atomic adds required for the dQ reduction.
- Fine‑grained scheduling and register allocation – Exploiting Blackwell's asynchronous capabilities, the scheduler interleaves MMA, exponential, and shared‑memory copy operations at the warp level, reaching ~71 % of the theoretical 2.25 PFLOPs/s peak (≈ 1613 TFLOPs/s). Register pressure is alleviated because accumulators reside in TMEM, letting each CTA run with ~30 % fewer registers than a Hopper‑based implementation.
- CuTe‑DSL implementation – The entire kernel stack is written in the CuTe domain‑specific language embedded in Python. CuTe abstracts tensor‑core instructions, TMEM management, and 2‑CTA synchronization while preserving low‑level performance. This yields a 20‑30× reduction in compile time compared with traditional C++ template metaprogramming, dramatically improving developer productivity and enabling rapid prototyping of new attention variants.
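The producer/consumer structure in the first bullet reduces to a double-buffered loop: tile k+1 is staged while tile k is consumed. On the GPU the two stages run truly asynchronously across warp groups; this plain-Python sketch (function names are my own) shows only the pipelining structure, not any actual concurrency:

```python
def pipelined(tiles, produce, consume):
    # Double-buffered producer/consumer loop: `produce` stands in for
    # the async-copy/MMA-issuing warp group, `consume` for the warp
    # group doing softmax/exponential work. In hardware the two stages
    # overlap in time; here only the buffering structure is shown.
    if not tiles:
        return []
    results = []
    buf = produce(tiles[0])                    # prologue: fill first buffer
    for k in range(len(tiles)):
        nxt = produce(tiles[k + 1]) if k + 1 < len(tiles) else None
        results.append(consume(buf))           # would overlap with the produce
        buf = nxt
    return results

# Example: "load" scales a tile, "compute" adds one.
print(pipelined([1, 2, 3], lambda t: t * 10, lambda b: b + 1))  # [11, 21, 31]
```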
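The software-emulated exponential in the second bullet relies on range reduction plus a short polynomial evaluated with FMAs. The sketch below uses plain Taylor coefficients for 2^f on [0, 1) because their accuracy is easy to check; the actual kernel would use a tuned minimax/Chebyshev fit, and `exp2_poly`/`exp_poly` are illustrative names, not the paper's API:

```python
import math

# Degree-6 Taylor coefficients for 2**f = exp(f * ln 2) on f in [0, 1).
# A production kernel would use tuned minimax coefficients instead.
LN2 = math.log(2.0)
COEFFS = [LN2 ** k / math.factorial(k) for k in range(7)]

def exp2_poly(x):
    # Range reduction: x = n + f with integer n and f in [0, 1), so
    # 2**x = 2**n * 2**f. The 2**f factor comes from the polynomial;
    # the 2**n factor is a cheap exponent adjustment (ldexp).
    n = math.floor(x)
    f = x - n
    # Horner evaluation: one multiply-add per coefficient, which maps
    # onto FMA units instead of the MUFU exponential unit.
    r = 0.0
    for c in reversed(COEFFS):
        r = r * f + c
    return math.ldexp(r, n)

def exp_poly(x):
    # Natural exp via base 2: exp(x) = 2**(x / ln 2).
    return exp2_poly(x / LN2)
```

With a degree-6 polynomial the relative error stays well below 10⁻⁴ across the reduced range, which is ample for BF16 attention scores.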
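The conditional rescaling mentioned in the same bullet comes from online softmax: a running maximum and denominator are kept per row, and the accumulator is only rescaled when a new K/V tile actually raises the maximum. A minimal single-query sketch in plain Python (tile size and function names are illustrative, not the FA-4 kernel's actual logic):

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def tiled_softmax_attention(q, keys, values, tile=3):
    # Single-query attention with an online softmax over K/V tiles.
    # Conditional rescaling: the accumulator and denominator are only
    # rescaled when a tile raises the running maximum `m`.
    d = len(q)
    m = float("-inf")             # running max of scores
    denom = 0.0                   # running softmax denominator
    acc = [0.0] * len(values[0])  # running weighted sum of V rows
    for start in range(0, len(keys), tile):
        scores = [dot(k, q) / math.sqrt(d) for k in keys[start:start + tile]]
        m_new = max(m, max(scores))
        if m_new > m:             # skip the rescale when the max is unchanged
            scale = math.exp(m - m_new) if m != float("-inf") else 0.0
            denom *= scale
            acc = [a * scale for a in acc]
            m = m_new
        for s, v in zip(scores, values[start:start + tile]):
            p = math.exp(s - m)
            denom += p
            acc = [a + p * vi for a, vi in zip(acc, v)]
    return [a / denom for a in acc]
```

The tiled result is exact (not approximate): online softmax reproduces the full-softmax output while never materializing the whole score row, which is what lets the rescale multiply be skipped on tiles that do not raise the maximum.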
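The operand splitting in the third bullet can be pictured as two workers each staging half of B and producing the matching half of the output, so together they complete one large product. The split dimension chosen below (columns of B) is an assumption made for illustration; the real tcgen05 2-CTA instruction's operand partitioning differs in detail:

```python
def matmul(a, b):
    # Plain dense matmul on nested lists.
    n = len(b[0])
    return [[sum(ai * b[t][j] for t, ai in enumerate(row))
             for j in range(n)] for row in a]

def two_cta_matmul(a, b):
    # Two "CTAs" each stage half of B's columns in their own shared
    # memory and produce the matching half of the output tile;
    # concatenating the halves yields the full product. Illustrative
    # sketch only, not the exact tcgen05 2-CTA semantics.
    h = len(b[0]) // 2
    b0 = [row[:h] for row in b]   # CTA 0's shared-memory half of B
    b1 = [row[h:] for row in b]   # CTA 1's half
    c0 = matmul(a, b0)
    c1 = matmul(a, b1)
    return [r0 + r1 for r0, r1 in zip(c0, c1)]  # column-wise concat
```

Because each CTA stages only half of B, per-CTA shared-memory traffic for that operand halves, which is the bandwidth saving the bullet describes.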
Experimental evaluation on a B200 GPU (148 SMs, 2.25 PFLOPs BF16) shows:
- Speedups – 1.3× over cuDNN 9.13 and 2.7× over Triton’s best‑available attention kernel in BF16.
- Throughput – Peak 1613 TFLOPs/s, corresponding to 71 % of the hardware’s theoretical maximum.
- Scalability – For long sequences (8K–64K tokens) FlashAttention‑4 outperforms FlashAttention‑3 by 1.8–2.2×, confirming that the redesigned pipeline successfully hides the increased MMA latency behind softmax and exponential work.
The authors open‑source the implementation under a permissive license and plan integration with popular frameworks such as PyTorch and HuggingFace. Future work includes adaptive polynomial degree selection for upcoming Blackwell‑300 GPUs, where the exponential unit throughput doubles, and exploring mixed‑precision strategies that further exploit TMEM’s capacity.
In summary, FlashAttention‑4 demonstrates that careful co‑design of algorithmic structure and kernel pipelines can reclaim performance on GPUs where compute units scale faster than memory and auxiliary functional units. The paper provides a concrete blueprint for future accelerator‑aware deep‑learning kernels in an era of increasingly asymmetric hardware scaling.