DeMo: Decoupled Momentum Optimization
Scaling neural network training increasingly depends on synchronous data-parallelism, yet full-precision gradient all-reduce imposes a severe communication bottleneck. We propose Decoupled Momentum Optimization (DeMo), a drop-in replacement for any momentum-based optimizer that significantly reduces communication bandwidth while maintaining convergence. DeMo (i) decouples local momentum updates, (ii) applies a fast orthonormal transform (e.g., DCT) followed by top-k sparsification, and (iii) reuses the momentum buffer as error feedback via momentum subtraction. This design reduces per-step communication by up to two orders of magnitude with minimal computational overhead. Experiments on 300M- and 1B-parameter language models show that DeMo transmits up to 85x less data per GPU than AdamW-DDP while achieving comparable loss and accuracy. DeMo is topology-agnostic and enables training across multi-datacenter or Ethernet-based setups. Code is available at https://github.com/bloc97/DeMo
💡 Research Summary
The paper introduces Decoupled Momentum Optimization (DeMo), a communication‑efficient framework for distributed data‑parallel training that can replace any momentum‑based optimizer (e.g., SGD with momentum, AdamW, Signum, Muon) without changing its core update rule. The key insight is that the momentum buffer, which aggregates past gradients, is a compressible surrogate for raw gradients. DeMo therefore decouples the synchronization step: instead of all‑reducing dense gradients each iteration, each worker locally updates its momentum buffer Mᵢₜ = β Mᵢ₍ₜ₋₁₎ + (1‑β) Gᵢₜ and communicates a compressed version of this buffer.
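The per-worker accumulation step can be sketched as follows. This is a minimal illustration of the update Mᵢₜ = β Mᵢ₍ₜ₋₁₎ + (1‑β) Gᵢₜ described above, not the paper's implementation; the exact accumulation form (EMA here) would follow whichever base optimizer DeMo wraps.

```python
import numpy as np

def local_momentum_update(m_prev, grad, beta=0.999):
    """EMA-style momentum accumulation kept local to each worker.

    Mirrors M_t = beta * M_{t-1} + (1 - beta) * G_t from the text.
    No gradient all-reduce happens at this step; synchronization is
    deferred to the compressed-momentum exchange.
    """
    return beta * m_prev + (1.0 - beta) * grad

# Each worker maintains its own buffer independently.
m = np.zeros(4)
g = np.ones(4)
m = local_momentum_update(m, g, beta=0.9)  # each entry becomes 0.1
```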
Compression proceeds in three stages. First, tensors are partitioned into fixed‑size chunks (e.g., 64 × 64 for matrix weights). Second, each chunk undergoes a separable orthogonal linear projection T(·); the authors primarily use the Discrete Cosine Transform (DCT) because it is fast and can be pre‑computed, though random orthonormal matrices are also evaluated. Third, after projection, only the top‑k coefficients by magnitude are retained (top‑k sparsification). The selected coefficients are sent via an All‑Gather (or All‑Reduce) and averaged across workers. The inverse projection reconstructs a sparse approximation of the global momentum, which is then used for the parameter update. The update rule follows the base optimizer’s transformation ϕ(·) (e.g., identity for SGD, sign for Signum, a matrix‑norm‑based scaling for Muon) and includes weight decay.
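The three-stage pipeline can be sketched for a single chunk with an explicitly constructed orthonormal DCT-II basis. The function names, the flat index/value encoding, and the use of numpy are illustrative assumptions, not the paper's wire format; the only parts taken from the text are the 64 × 64 chunk size, the separable orthogonal projection, and top-k magnitude selection.

```python
import numpy as np

def dct_matrix(n):
    """Orthonormal DCT-II basis as an n x n matrix (satisfies D @ D.T = I)."""
    j, i = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
    D = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * j / (2 * n))
    D[0, :] /= np.sqrt(2.0)  # rescale the DC row for orthonormality
    return D

def compress_chunk(chunk, D, k):
    """Separable 2-D DCT followed by top-k magnitude sparsification."""
    coeffs = D @ chunk @ D.T                       # orthogonal projection T(.)
    flat = coeffs.ravel()
    idx = np.argpartition(np.abs(flat), -k)[-k:]   # k largest-magnitude coeffs
    return idx, flat[idx]

def decompress_chunk(idx, vals, D):
    """Inverse projection of the sparse coefficients back to parameter space."""
    n = D.shape[0]
    flat = np.zeros(n * n)
    flat[idx] = vals
    return D.T @ flat.reshape(n, n) @ D            # inverse of orthonormal DCT

D = dct_matrix(64)
chunk = np.random.default_rng(0).normal(size=(64, 64))
idx, vals = compress_chunk(chunk, D, k=2)          # only 2 of 4096 coefficients
approx = decompress_chunk(idx, vals, D)            # sparse approximation
```

In the full algorithm, the (index, value) pairs from all workers would be exchanged via All-Gather and averaged before the inverse projection.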
Error feedback, a common technique for sparsified training, is integrated without extra memory. After the global momentum is reconstructed, each worker subtracts the reconstructed contribution (scaled by a factor α ∈ (0,1]) from its local buffer. This subtraction stores the “uncommunicated” part of the momentum, ensuring that information omitted by sparsification is not lost but accumulated for future steps. Consequently, DeMo avoids the auxiliary error‑feedback buffers that other methods require.
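The subtraction step can be sketched as below. The hand-picked numbers are illustrative only; the point is that after subtracting the transmitted part (scaled by α), the buffer retains exactly the residual that sparsification dropped, so no separate error-feedback buffer is needed.

```python
import numpy as np

def momentum_subtraction(m_local, transmitted_local, alpha=1.0):
    """Reuse the momentum buffer as the error-feedback accumulator.

    Each worker subtracts alpha times its own reconstructed (transmitted)
    contribution; the untransmitted residual stays in the buffer and is
    carried into subsequent momentum updates.
    """
    return m_local - alpha * transmitted_local

m = np.array([3.0, 0.1, -2.0, 0.05])
# Suppose top-2 sparsification transmitted only the two largest entries:
transmitted = np.array([3.0, 0.0, -2.0, 0.0])
m = momentum_subtraction(m, transmitted, alpha=1.0)
# The buffer now holds only the untransmitted residual [0.0, 0.1, 0.0, 0.05].
```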
Theoretical analysis assumes standard stochastic optimization conditions (bounded variance, L‑smoothness, bounded gradients). Under these assumptions, the authors prove that DeMo converges with rate O(1/√T + 1/√N), where T is the number of steps and N the number of workers. The bound shows that the compression factor (k/N) does not degrade asymptotic convergence as long as k grows sublinearly with N.
Empirically, DeMo is evaluated on two decoder‑only transformer language models built with the OLMo framework: a 300 M‑parameter model (OLMo‑300M) and a 1.18 B‑parameter model (OLMo‑1B). Baselines use AdamW with standard DDP. Experiments run on 64 NVIDIA H100 GPUs (global batch 2048, sequence length 2048, gradient accumulation 4). The authors vary the top‑k sparsity budget k ∈ {1,2,4,8,16,32}. Results show:
- Communication reduction: With chunk size 64 × 64, the per‑step upload bandwidth is reduced by roughly (4096/k). For k = 2, the 300 M model transmits only ~7.5 MB per step versus 637 MB for AdamW‑DDP (≈85× reduction). For the 1 B model, k = 16 yields ~55 MB per step versus 2416 MB (≈44× reduction).
- Training loss: DeMo with k = 2 already matches or slightly outperforms AdamW in loss curves; increasing k gives diminishing returns.
- Downstream zero‑shot evaluation: On HellaSwag, ARC‑Easy, and PIQA, DeMo matches or exceeds AdamW‑DDP across both model sizes, confirming that aggressive compression does not harm final model quality.
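The end-to-end reduction factors quoted above can be checked directly from the reported per-step transfer sizes (the overall ratios are smaller than the per-chunk coefficient ratio 4096/k because transmitted indices and unchunked tensors add overhead):

```python
# Per-step transfer sizes as reported in the text (MB).
adamw_300m_mb, demo_300m_mb = 637.0, 7.5    # 300M model, DeMo with k = 2
adamw_1b_mb, demo_1b_mb = 2416.0, 55.0      # 1B model, DeMo with k = 16

reduction_300m = adamw_300m_mb / demo_300m_mb   # ~85x
reduction_1b = adamw_1b_mb / demo_1b_mb         # ~44x
```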
Ablation studies (not fully detailed in the excerpt) examine the impact of chunk size, choice of projection (DCT vs. random), momentum coefficient β, and subtraction factor α, finding that higher β (≈0.999) and α = 1 improve stability when momentum subtraction is active.
Complexity analysis shows that without chunking, projecting an N × N matrix costs O(N³) and O(N²) memory; chunking into C² blocks reduces these to O(N³/C) and O(N²/C²), respectively. Since all workers share the same projection matrices, the additional memory overhead is negligible.
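The chunking trade-off can be made concrete with simple operation counts. This is a back-of-the-envelope model under the stated assumptions (separable projection via two matrix multiplies per chunk, one shared basis matrix); the function name and counting convention are illustrative.

```python
def projection_costs(N, C):
    """Approximate flop and basis-memory counts for projecting an N x N matrix.

    Unchunked (C = 1): two N x N matmuls -> ~2 N^3 flops, N^2 basis entries.
    Chunked into C^2 blocks of side N/C: each block costs ~2 (N/C)^3 flops,
    so the total is ~2 N^3 / C, and the shared basis shrinks to (N/C)^2
    entries, i.e. O(N^2 / C^2) memory.
    """
    n = N // C
    flops = (C * C) * 2 * n**3   # two matmuls per chunk, C^2 chunks
    basis_mem = n * n            # single basis shared by all chunks and workers
    return flops, basis_mem

dense_flops, dense_mem = projection_costs(1024, 1)
chunked_flops, chunked_mem = projection_costs(1024, 16)
# Chunking with C = 16 cuts flops by 16x and basis memory by 256x.
```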
In summary, DeMo offers a practical, drop‑in replacement for momentum‑based optimizers that dramatically cuts communication bandwidth (up to two orders of magnitude) while preserving convergence speed and final accuracy. Its design—compressing momentum via orthogonal transforms, top‑k sparsification, and reusing the momentum buffer for error feedback—makes large‑scale language model training feasible on commodity Ethernet or geographically distributed clusters, reducing reliance on expensive high‑bandwidth interconnects.