DASH: Deterministic Attention Scheduling for High-throughput Reproducible LLM Training
Determinism is indispensable for reproducibility in large language model (LLM) training, yet it often exacts a steep performance cost. In widely used attention implementations such as FlashAttention-3, the deterministic backward pass can incur up to a 37.9% throughput reduction relative to its non-deterministic counterpart, primarily because gradient accumulation operations must be serialized to guarantee numerical consistency. This performance loss stems from suboptimal scheduling of compute and gradient-reduction phases, leading to significant hardware underutilization. To address this challenge, we formulate the backward pass of deterministic attention as a scheduling problem on a Directed Acyclic Graph (DAG) and derive schedules that minimize the critical path length. Building on this formulation, we present DASH (Deterministic Attention Scheduling for High-Throughput), which encapsulates two complementary scheduling strategies: (i) Descending Q-Tile Iteration, a reversed query-block traversal that shrinks pipeline stalls in causal attention, and (ii) Shift Scheduling, a theoretically optimal schedule within our DAG model that reduces pipeline stalls for both full and causal masks. Our empirical evaluations on NVIDIA H800 GPUs demonstrate that DASH narrows the performance gap of deterministic attention. The proposed strategies improve the throughput of the attention backward pass by up to 1.28$\times$ compared to the baseline, significantly advancing the efficiency of reproducible LLM training. Our code is open-sourced at https://github.com/SJTU-Liquid/deterministic-FA3.
💡 Research Summary
The paper tackles a critical bottleneck in deterministic back-propagation for large language model (LLM) training: the up-to-37.9% throughput loss observed in FlashAttention-3 when deterministic mode is enabled. Deterministic execution is required for bit-wise reproducibility across thousands of GPUs, but the current approach forces a serialized global reduction of the query-gradient (dQ) across streaming multiprocessors (SMs). This serialization creates pipeline bubbles, especially under causal masking, where dependencies force each SM to wait for the previous one, inflating the critical path to n·(c + r) + (n−1)·r.
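Plugging toy numbers into the quoted critical-path expression makes the cost of serialization concrete (the function name and the chosen values are illustrative, not measured figures from the paper):

```python
def serialized_critical_path(n, c, r):
    """Critical path of the serialized causal baseline as quoted above:
    n*(c + r) + (n - 1)*r, for n SMs with per-tile compute time c
    and per-tile reduction time r (arbitrary time units)."""
    return n * (c + r) + (n - 1) * r

print(serialized_critical_path(8, 3, 1))  # 8*4 + 7*1 = 39
```

Even with reduction an order of magnitude cheaper than compute, the (n−1)·r tail grows with the number of SMs, which is why the schedule rather than the arithmetic becomes the bottleneck.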
The authors formalize the backward pass as a directed acyclic graph (DAG). Each tile (i, j) consists of a compute node C(i,j) followed by a reduction node R(i,j). Edges within a KV‑tile are weighted by constant compute (c) and reduction (r) times, while zero‑weight edges encode the required ordering for deterministic accumulation. The optimization objective is to minimize the DAG’s critical‑path length, which directly reduces overall latency.
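The DAG formulation above can be sketched as a longest-path computation. This toy model puts the c and r costs on the nodes rather than the edges and chains successive reduction nodes to enforce the fixed accumulation order; it illustrates the formulation for a single KV tile, not the paper's actual kernel:

```python
from collections import defaultdict

def critical_path(succs, weight):
    """Length of the longest weighted path in a DAG.

    succs:  node -> list of successor nodes
    weight: node -> execution time of that node
    """
    indeg = defaultdict(int)
    for u in succs:
        for v in succs[u]:
            indeg[v] += 1
    frontier = [u for u in weight if indeg[u] == 0]
    finish = {u: weight[u] for u in weight}  # earliest finish time per node
    while frontier:  # process nodes in topological order, relaxing successors
        u = frontier.pop()
        for v in succs.get(u, []):
            finish[v] = max(finish[v], finish[u] + weight[v])
            indeg[v] -= 1
            if indeg[v] == 0:
                frontier.append(v)
    return max(finish.values())

# One KV tile, n query tiles: C(i) -> R(i), plus R(i-1) -> R(i)
# to pin the deterministic accumulation order of dQ updates.
c, r, n = 3, 1, 4  # toy compute/reduction times (arbitrary units)
weight = {}
succs = defaultdict(list)
for i in range(n):
    weight[("C", i)] = c
    weight[("R", i)] = r
    succs[("C", i)].append(("R", i))
    if i > 0:
        succs[("R", i - 1)].append(("R", i))

print(critical_path(succs, weight))  # c + n*r = 7: reductions serialize
```

With all compute nodes free to run in parallel, the critical path of this toy instance is dominated by the serialized reduction chain, which is exactly the structure the scheduling strategies target.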
Two complementary scheduling strategies are introduced:
- **Descending Q-Tile Iteration** – For causal masks, the query tiles are processed in reverse order. By completing the "short" tasks first, SMs become free earlier, allowing subsequent tiles to start their reductions without waiting. The resulting execution time approximates T_rev ≈ m·(n+1)·(c+r)/2 + (n−1)·r, substantially shrinking idle periods.
- **Shift Scheduling** – This strategy provides a provably optimal schedule under the DAG model. It cyclically shifts the assignment of query tiles to SMs (e.g., SM0 processes Q0→Q1→…→Qn−1, SM1 processes Q1→Q2→…→Q0, etc.). The cyclic shift creates a conflict-free reduction sequence where each SM's reduction phase follows the previous one without overlap, balancing the workload perfectly. The critical path collapses to 2c + 2r, the theoretical minimum.
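The cyclic assignment behind Shift Scheduling can be sketched in a few lines (the function name and the round-based view are illustrative; the kernel's actual tile-to-SM mapping may differ):

```python
def shift_schedule(n):
    """Cyclically shifted assignment of n query tiles to n SMs.

    Row k lists the query tiles SM k processes in order: in round t,
    SM k handles tile (k + t) mod n, so every round is a permutation
    of the tiles and, for each KV tile, the dQ reductions can be
    chained SM-after-SM with no two SMs contending for the same tile.
    """
    return [[(k + t) % n for t in range(n)] for k in range(n)]

for row in shift_schedule(4):
    print(row)
# [0, 1, 2, 3]
# [1, 2, 3, 0]
# [2, 3, 0, 1]
# [3, 0, 1, 2]
```

Because each round is a permutation, every SM stays busy in every round while the deterministic accumulation order is still respected, which is the intuition behind the conflict-free, balanced schedule described above.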
Experiments on NVIDIA H800 GPUs (8 SMs) evaluate both full‑mask and causal‑mask scenarios. Compared with the baseline deterministic FlashAttention‑3, Descending Q‑Tile Iteration yields a 1.12× speedup on full‑mask and 1.18× on causal‑mask. Shift Scheduling achieves 1.15× and 1.28× speedups respectively, effectively eliminating pipeline bubbles in the causal case. The gains scale linearly with the number of attention heads (m) and SMs (n), demonstrating that the approaches remain effective for very large models.
Key contributions include: (i) identifying the misalignment between tile execution and deterministic accumulation order as the root cause of performance loss; (ii) providing the first DAG-based formalization of deterministic attention backward scheduling; (iii) introducing two practical, complementary scheduling algorithms that together close the performance gap by up to 28% without sacrificing reproducibility; and (iv) open-sourcing the implementation for the community.
Overall, DASH shows that deterministic LLM training need not incur prohibitive overheads; by rethinking the execution schedule at the DAG level, one can achieve near‑optimal hardware utilization while preserving exact reproducibility.