TiledAttention: a CUDA Tile SDPA Kernel for PyTorch
TiledAttention is a scaled dot-product attention (SDPA) forward operator for attention-kernel research on NVIDIA GPUs. Implemented in cuTile Python (TileIR) and exposed as a PyTorch-callable function, it is easier to modify than low-level CUDA templates while retaining realistic behavior via online softmax and tiled $K,V$ streaming. The approach is both performant and directly editable at the schedule level from Python (tile shapes, staging, shared-memory layout), enabling rapid, reproducible kernel research without template-heavy CUDA/CUTLASS rewrites. We benchmark TiledAttention on an NVIDIA DGX GB10 node with a reproducible harness and compare against PyTorch SDPA (auto-dispatch) and explicit unfused baselines across sequence length, head dimension, and precision (FP16/BF16). While production fused baselines remain stronger overall, TiledAttention delivers large speedups over standard eager attention paths and is available for direct use within PyTorch workflows, providing a practical balance between performance and customizability.
💡 Research Summary
TiledAttention introduces a forward implementation of scaled dot-product attention (SDPA) that is both high-performance and highly modifiable, targeting modern NVIDIA GPUs. The kernel is written in cuTile Python (TileIR), a high-level tile-programming DSL that compiles to CUDA tile IR, allowing researchers to expose and adjust schedule parameters (tile dimensions, staging depth, and shared-memory layout) directly from Python without diving into low-level CUDA or CUTLASS templates.
The algorithm follows the “online‑softmax” approach popularized by FlashAttention: the sequence dimension S is partitioned into tiles; each cooperative thread array (CTA) owns a block of query rows (Q‑tile) and streams corresponding K and V tiles. For each streamed tile the kernel computes partial dot‑product scores, applies a mask, and updates running softmax statistics (maximum m_i, normalizer ℓ_i) and the output accumulator o_i using FP32 accumulation for numerical stability. After all tiles have been streamed, the final output O_i is obtained by normalizing o_i with ℓ_i. No full S×S score matrix is materialized, which dramatically reduces memory pressure for long sequences.
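The tiled loop described above can be sketched in plain NumPy. This is a minimal illustrative reference, not the cuTile kernel itself: the function name, default tile sizes, and pure-CPU formulation are all ours, and the real kernel runs one Q-tile per CTA on GPU shared-memory tiles with FP32 accumulators.

```python
import numpy as np

def tiled_sdpa_forward(Q, K, V, TM=64, TN=64, causal=False):
    """Reference sketch of the tiled online-softmax SDPA forward pass.

    Q, K, V: (S, D) arrays. Tile sizes TM/TN mirror the kernel's
    schedule parameters; everything here is an illustrative stand-in.
    """
    S, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    O = np.empty((S, D), dtype=np.float64)
    for qs in range(0, S, TM):                      # each CTA owns one Q-tile
        q = Q[qs:qs + TM].astype(np.float64) * scale
        m = np.full(q.shape[0], -np.inf)            # running row maximum m_i
        l = np.zeros(q.shape[0])                    # running normalizer l_i
        o = np.zeros((q.shape[0], D))               # unnormalized accumulator o_i
        for ks in range(0, S, TN):                  # stream K and V tiles
            k = K[ks:ks + TN].astype(np.float64)
            v = V[ks:ks + TN].astype(np.float64)
            s = q @ k.T                             # partial score tile
            if causal:
                rows = np.arange(qs, qs + q.shape[0])[:, None]
                cols = np.arange(ks, ks + k.shape[0])[None, :]
                s = np.where(cols <= rows, s, -np.inf)
            m_new = np.maximum(m, s.max(axis=1))
            alpha = np.exp(m - m_new)               # rescale old statistics
            p = np.exp(s - m_new[:, None])
            l = l * alpha + p.sum(axis=1)
            o = o * alpha[:, None] + p @ v
            m = m_new
        O[qs:qs + TM] = o / l[:, None]              # final normalization by l_i
    return O
```

Note that no S×S score matrix ever exists: only one TM×TN score tile is live at a time, which is what keeps memory pressure flat as S grows.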
Key tunable parameters are TM (rows per Q‑tile) and TN (columns per K/V tile). The authors explore typical values TM ∈ {64,128} and TN ∈ {64,128,256}. A systematic sweep shows that the optimal pair depends on head dimension D and sequence length S; for example, (TM, TN) = (64, 128) is best for S = 1024, D = 128, while using (64, 64) incurs a 3.9 % slowdown. Staging depth and shared‑memory swizzling further affect bank conflicts and latency hiding.
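A sweep over (TM, TN) pairs like the one described takes only a few lines to script. In this sketch, `time_kernel` is a hypothetical callable standing in for JIT-compiling and timing one cuTile variant; it is an assumed interface, not part of any published API.

```python
import itertools

def best_tile_config(time_kernel, TMs=(64, 128), TNs=(64, 128, 256)):
    """Pick the fastest (TM, TN) pair for one problem shape.

    time_kernel(TM, TN) is assumed to compile the kernel with those
    tile sizes and return a median runtime; the name is illustrative.
    """
    timings = {(tm, tn): time_kernel(tm, tn)
               for tm, tn in itertools.product(TMs, TNs)}
    best = min(timings, key=timings.get)
    return best, timings

# Illustrative use with a fake cost model in which (64, 128) wins
# and (64, 64) is 3.9% slower, mirroring the S=1024, D=128 example:
fake = {(64, 64): 1.039, (64, 128): 1.000, (64, 256): 1.070,
        (128, 64): 1.120, (128, 128): 1.050, (128, 256): 1.150}
best, timings = best_tile_config(lambda tm, tn: fake[(tm, tn)])
```

Because the optimum shifts with D and S, such a sweep would need to be rerun per problem shape rather than once globally.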
Benchmarking is performed on an NVIDIA DGX GB10 node with a reproducible harness: explicit warm‑up runs, CUDA event timing, median and 95th‑percentile reporting, correctness checks against a high‑precision reference for small shapes, and isolation of GPU streams. The workload grid covers sequence lengths S ∈ {512, 1024, 2048, 4096, 8192}, head dimensions D ∈ {64, 96, 128, 160}, and both FP16 and BF16 dtypes, with causal and non‑causal masking.
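The warm-up-then-measure methodology can be sketched as below. This is a CPU-timer stand-in for illustration only: the actual harness times the GPU kernel with CUDA events on an isolated stream and synchronizes before reading each measurement; the function name and signature are ours.

```python
import math
import statistics
import time

def bench(fn, *, warmup=10, iters=100, timer=time.perf_counter):
    """Report median and 95th-percentile latency of fn, in milliseconds.

    Warm-up iterations absorb one-time costs (JIT compilation, cache
    population) so they do not contaminate the measured samples.
    """
    for _ in range(warmup):
        fn()                                   # untimed warm-up runs
    samples = []
    for _ in range(iters):
        t0 = timer()
        fn()
        samples.append((timer() - t0) * 1e3)   # seconds -> ms
    samples.sort()
    p95_idx = min(iters - 1, math.ceil(0.95 * iters) - 1)
    return {"median_ms": statistics.median(samples),
            "p95_ms": samples[p95_idx],
            "n": iters}
```

Reporting the 95th percentile alongside the median makes launch-latency jitter visible, which matters most in the short-sequence regime where kernel launch overhead dominates.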
Results show two distinct regimes. For short sequences (S ≤ 2048) the overhead of kernel launch and reduction dominates; in this region TiledAttention comes closest to the fused PyTorch SDPA baseline, achieving 0.766× of the fused throughput in a shape‑aware configuration. For long sequences (S ≥ 4096) memory bandwidth and shared‑memory pressure become limiting, and the fused PyTorch implementation remains faster, with TiledAttention achieving on average 0.632× (median 0.634×) of fused throughput across all 80 configurations. The gap is smallest for D = 128 (0.947×), indicating that the tile sizes chosen align well with this head dimension.
When compared against the unfused “math” SDPA path and a naïve eager‑attention implementation, TiledAttention delivers dramatic speedups: 28.15× faster than math‑SDPA and 14.36× faster than eager attention on average. Profiling with Nsight Compute and Nsight Systems reveals that for long S the kernel is memory‑bound, with shared‑memory staging depth and layout being the primary levers for improvement. For short S, insufficient parallelism limits performance, and larger TM values can increase occupancy at the cost of register pressure.
The authors discuss practical adoption guidance. For production workloads requiring maximum out‑of‑the‑box throughput across a wide range of shapes, the fused PyTorch SDPA remains the recommended path. However, for research, rapid kernel iteration, custom masking, or layout experiments, TiledAttention offers a clear advantage: schedule changes are a few lines of Python, recompilation is JIT‑cached, and profiling is reproducible.
Limitations include the current focus on the forward pass only; the backward pass, KV‑cache support, and fused epilogues are left for future work. Portability beyond the benchmarked GPU class may require additional schedule variants. The tuning space explored is deliberately modest; larger automated searches could uncover further gains.
In conclusion, TiledAttention demonstrates that a tile‑programming DSL can deliver a practical, modifiable SDPA kernel that bridges the gap between research flexibility and realistic performance on modern GPUs. While it does not surpass production‑grade fused kernels for the longest contexts, it provides substantial speedups over baseline eager implementations and enables fast, reproducible kernel research within the PyTorch ecosystem.