From $O(mn)$ to $O(r^2)$: Two-Sided Low-Rank Communication for Adam in Distributed Training with Memory Efficiency
As foundation models continue to scale, pretraining increasingly relies on data-parallel distributed optimization, making bandwidth-limited gradient synchronization a key bottleneck. Orthogonally, projection-based low-rank optimizers were mainly designed for memory efficiency, but remain suboptimal for communication-limited training: one-sided synchronization still transmits an $O(rn)$ object for an $m\times n$ matrix gradient and refresh steps can dominate peak communicated bytes. We propose TSR, which brings two-sided low-rank communication to Adam-family updates (TSR-Adam) by synchronizing a compact core $U^\top G V\in\mathbb{R}^{r\times r}$, reducing the dominant per-step payload from $O(mn)$ to $O(r^2)$ while keeping moment states in low-dimensional cores. To further reduce the peak communication from subspace refresh, TSR-Adam adopts a randomized SVD-based refresh that avoids full-gradient synchronization. We additionally extend low-rank communication to embedding gradients with embedding-specific ranks and refresh schedules, yielding additional communication and memory savings over keeping embeddings dense. Across pretraining from 60M to 1B model scales, TSR-Adam reduces average communicated bytes per step by $13\times$, and on GLUE fine-tuning it reduces communication by $25\times$, while achieving comparable performance; we further provide a theoretical stationarity analysis for the proposed update. Code is available at https://github.com/DKmiyan/TSR-Adam.
💡 Research Summary
The paper tackles the communication bottleneck that dominates data‑parallel training of large language models. While recent low‑rank optimizers such as GaLore were designed mainly for memory savings, they still require transmitting an O(r n) (or O(m r)) matrix each step because only one side of the low‑rank factor is synchronized. The authors propose TSR‑Adam, a two‑sided low‑rank communication scheme that reduces the per‑step payload to O(r²) by synchronizing only a compact core matrix C = Uᵀ G V (size r × r) for each weight matrix G ∈ ℝ^{m×n}. Both left and right orthonormal bases U ∈ ℝ^{m×r} and V ∈ ℝ^{n×r} are maintained locally; after each worker computes its local gradient, it projects the gradient onto the bases, forming the core C, which is then averaged across workers via an all‑reduce. The full gradient is reconstructed locally as Ĝ = U C̄ Vᵀ, and AdamW’s first‑ and second‑moment estimates are stored and updated directly in the r × r core space. This preserves the memory benefits of low‑rank optimizers while cutting communication from O(m n) (dense) or O(r n) (one‑sided) down to O(r²).
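The two-sided synchronization step described above can be sketched in a few lines of NumPy. This is a hedged illustration, not the authors' implementation: the all-reduce is simulated as a mean over a list of per-worker gradients, and all shapes, seeds, and hyperparameters are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, r, workers = 64, 32, 4, 4

# Local orthonormal bases U (m x r) and V (n x r), identical on every worker.
U, _ = np.linalg.qr(rng.standard_normal((m, r)))
V, _ = np.linalg.qr(rng.standard_normal((n, r)))

# Each worker holds a local gradient G_w and projects it to an r x r core.
grads = [rng.standard_normal((m, n)) for _ in range(workers)]
cores = [U.T @ G @ V for G in grads]          # O(r^2) payload per matrix

# All-reduce (simulated here as a mean) over the tiny cores only.
C_bar = np.mean(cores, axis=0)

# Every worker reconstructs the same low-rank gradient locally.
G_hat = U @ C_bar @ V.T

# Adam-style moments live in core space (r x r), not in m x n.
beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 1e-3
m_t = np.zeros((r, r)); v_t = np.zeros((r, r))
m_t = beta1 * m_t + (1 - beta1) * C_bar
v_t = beta2 * v_t + (1 - beta2) * C_bar**2
step_core = m_t / (np.sqrt(v_t) + eps)
W_update = -lr * (U @ step_core @ V.T)        # map the core step back to weight space
```

Note that by linearity, averaging the cores gives the same result as projecting the averaged gradient, which is why communicating only the r × r cores loses no information beyond the low-rank restriction itself.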
A major challenge is keeping the bases up‑to‑date without a full‑gradient all‑reduce, which would defeat the purpose. TSR‑Adam addresses this with a randomized SVD refresh performed every K steps. Each worker multiplies its local gradient by a shared random matrix Ω (oversampled by p columns), computes Y = G Ω, orthogonalizes Y to obtain Q, and then computes a small sketch B = Qᵀ G. These sketches are averaged across workers, and a final SVD on the averaged quantities yields updated bases U and V. Because only the sketches — each far smaller than the full m × n gradient — are communicated, the peak bytes during refresh stay well below a dense all‑reduce.
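The refresh step follows the standard randomized range-finder recipe. The sketch below is an assumption-laden illustration: the paper's exact distributed ordering may differ, and here the sketch averaging is simulated by averaging Y and B over a list of per-worker gradients (the shared Ω is reproduced on every worker from a common seed).

```python
import numpy as np

rng = np.random.default_rng(1)
m, n, r, p, workers = 64, 32, 4, 2, 4

grads = [rng.standard_normal((m, n)) for _ in range(workers)]

# Shared random test matrix Omega with p oversampling columns
# (identical on all workers via a common seed).
Omega = np.random.default_rng(42).standard_normal((n, r + p))

# Each worker sketches its local gradient: Y_w = G_w @ Omega.
Ys = [G @ Omega for G in grads]
Y_bar = np.mean(Ys, axis=0)                   # simulated all-reduce, m x (r+p)

# Orthonormalize the averaged range sketch, then form the small projection B.
Q, _ = np.linalg.qr(Y_bar)                    # m x (r+p), orthonormal columns
Bs = [Q.T @ G for G in grads]
B_bar = np.mean(Bs, axis=0)                   # (r+p) x n, simulated all-reduce

# An SVD of the small matrix B_bar yields the refreshed bases.
U_tilde, s, Vt = np.linalg.svd(B_bar, full_matrices=False)
U_new = Q @ U_tilde[:, :r]                    # m x r left basis
V_new = Vt[:r].T                              # n x r right basis
```

Both refreshed bases come out with orthonormal columns, so the core projection in the main loop remains well defined after each refresh.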
Embedding layers, which often dominate the communication volume due to large vocabularies, are treated specially. The authors assign a separate rank r_emb and refresh interval K_emb to embeddings, applying the same two‑sided core synchronization. This further reduces both communication and optimizer‑state memory for embeddings.
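To make the embedding-specific savings concrete, here is a back-of-the-envelope comparison. All numbers (vocabulary size, hidden size, r_emb, the refresh intervals, and the group names) are hypothetical and chosen only for illustration; the paper's actual settings may differ.

```python
# Illustrative embedding shape: vocab x hidden, kept dense vs. synchronized
# as an r_emb x r_emb core (sizes are hypothetical, not from the paper).
vocab, hidden, r_emb = 32000, 768, 32
dtype_bytes = 4                                # fp32

dense_bytes = vocab * hidden * dtype_bytes     # per-step dense embedding payload
core_bytes = r_emb * r_emb * dtype_bytes       # two-sided core payload
savings = dense_bytes / core_bytes             # communication reduction factor

# Hypothetical per-group schedule: embeddings get their own rank and a
# longer refresh interval than the transformer weight matrices.
tsr_groups = [
    {"group": "attention_mlp", "rank": 128,   "refresh_every": 200},
    {"group": "embedding",     "rank": r_emb, "refresh_every": 1000},
]
```

Even with these rough numbers, the per-step embedding payload shrinks by orders of magnitude, which is why embeddings are worth handling with their own rank and schedule rather than keeping them dense.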
Theoretical analysis shows that the synchronized core C̄ is an unbiased low‑rank approximation of the true averaged gradient, and that Adam updates performed in the core space behave, in expectation, like standard Adam applied to the projected gradient; the paper's stationarity analysis then establishes convergence of the proposed update under standard assumptions.
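The unbiasedness part of this claim follows directly from linearity of the projection. A one-line sketch, using the notation above with $W$ workers holding local gradients $G_w$:

```latex
\bar{C}
  = \frac{1}{W}\sum_{w=1}^{W} U^\top G_w V
  = U^\top \Big(\frac{1}{W}\sum_{w=1}^{W} G_w\Big) V
  = U^\top \bar{G}\, V,
\qquad
U \bar{C} V^\top = (U U^\top)\, \bar{G}\, (V V^\top).
```

That is, averaging the tiny cores is exactly equivalent to projecting the true averaged gradient onto the column spaces of $U$ and $V$, so no additional bias is introduced by communicating cores instead of full gradients.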
Empirically, TSR‑Adam is evaluated on pre‑training LLaMA models ranging from 60M to 1B parameters and on GLUE fine‑tuning. Across all scales, average bytes per step drop by a factor of 13, and on GLUE the reduction reaches 25×, while final loss and downstream performance remain on par with dense AdamW and with GaLore. Memory consumption is also reduced because only r × r moment tensors are stored.
In summary, TSR‑Adam delivers three key contributions: (1) two‑sided low‑rank core synchronization achieving O(r²) communication, (2) a lightweight randomized SVD refresh that controls peak bandwidth, and (3) embedding‑aware low‑rank handling. This communication‑first redesign enables more bandwidth‑efficient distributed training of ever‑larger models without sacrificing convergence or memory efficiency.