DeMo: 분산 학습을 위한 압축 모멘텀 최적화
초록
DeMo는 기존 모멘텀 기반 옵티마이저를 그대로 사용하면서 로컬 모멘텀을 변환·압축하여 통신량을 크게 줄이는 방법이다. DCT와 같은 직교 변환 후 top‑k 스파시피케이션을 적용하고, 남은 정보를 모멘텀 버퍼에 그대로 저장해 오류 피드백을 구현한다. 실험 결과 300M·1B 규모 언어 모델에서 AdamW‑DDP 대비 85배까지 전송량을 감소시키면서도 손실·정확도는 동등하거나 약간 향상된다.
상세 분석
DeMo는 분산 데이터 병렬 학습에서 가장 큰 병목인 전체 정밀도 그래디언트 All‑Reduce를 회피하기 위해 “모멘텀 자체가 압축 가능한 정보원”이라는 핵심 가정을 세운다. 기존 DDP 파이프라인에서는 미니배치마다 계산된 그래디언트를 즉시 동기화하지만, DeMo는 각 워커가 로컬 모멘텀 버퍼 Mᵢₜ 를 β Mᵢ₍ₜ₋₁₎ + (1‑β) Gᵢₜ 로 업데이트하고, 이 버퍼를 바로 통신 대상으로 삼는다.
통신 효율을 높이기 위해 두 단계의 구조적 압축을 도입한다. 첫째, 텐서를 고정된 크기의 청크로 분할하고, 각 청크에 대해 다차원 직교 변환 T(·) 을 적용한다. 논문에서는 빠른 구현이 가능한 DCT와 무작위 직교 행렬을 사용했으며, DCT가 메모리와 연산 측면에서 가장 효율적이었다. 둘째, 변환된 계수 중 절댓값이 큰 k 개만 선택해 top‑k 스파시피케이션을 수행한다. 이렇게 하면 청크당 전송량이 (k₀·k₁)/k 배로 감소한다.
압축된 계수 ĤQₖ 는 All‑Gather 후 평균을 취해 복원된 모멘텀 M*ₜ 를 얻는다. 복원 과정에서는 역변환 IDCT 또는 T⁻¹ 을 사용한다. 중요한 점은 복원된 모멘텀을 바로 파라미터 업데이트에 활용한다는 것이다. SGD‑M, Signum, Muon 등 다양한 기본 옵티마이저에 맞춰 ϕ(M) 함수를 정의해 동일한 프레임워크 안에서 적용한다.
오류 피드백은 별도의 버퍼를 두지 않고 모멘텀 버퍼 자체를 활용한다. 통신 후 복원된 모멘텀을 α 비율만큼 현재 모멘텀에서 빼는 Mᵢₜ ← Mᵢₜ − α · T⁻¹(·) 연산을 수행함으로써, 아직 전송되지 않은 정보가 버퍼에 남아 다음 스텝에서 보강된다. 이는 기존 오류 피드백 기법이 요구하던 추가 메모리 오버헤드를 완전히 제거한다.
이론적 분석에서는 표준 가정(분산된 무작위 샘플링, L‑smooth, 유한 분산, 유계 그래디언트) 하에 수렴률 O(1/√T + 1/√N) 을 증명한다. 여기서 T 는 총 스텝 수, N 은 워커 수이며, 압축 비율 k/N 이 충분히 작아도 수렴에 큰 영향을 주지 않음을 보인다.
실험에서는 OLMo 프레임워크 위에 300 M 및 1 B 파라미터 규모의 디코더‑전용 트랜스포머를 학습시켰다. 기본 AdamW‑DDP와 비교했을 때, k = 2 ~ 8 정도의 작은 top‑k 값만으로도 전송량을 85배(300 M)·44배(1 B)까지 감소시키면서 손실 곡선과 사전학습 후 HellaSwag, ARC‑Easy, PIQA 등 zero‑shot 벤치마크에서 동등하거나 약간 높은 정확도를 달성했다. 특히 k = 1 ~ 2 수준에서도 거의 동일한 수렴 속도를 보였으며, k를 늘릴수록 통신량은 증가하지만 정확도 향상은 미미했다.
복잡도 측면에서 청크 기반 변환은 O(N³/C) 연산으로 전체 비용을 선형적으로 감소시키고, 메모리 사용량도 O(N²/C²)로 크게 줄인다. 변환 행렬은 모든 워커가 공유하므로 추가 메모리는 무시할 수준이다.
요약하면, DeMo는 (1) 모멘텀을 직접 압축 대상로 삼고, (2) 빠른 직교 변환 + top‑k 스파시피케이션으로 통신량을 극단적으로 감소시키며, (3) 모멘텀 버퍼 자체를 오류 피드백 메커니즘으로 활용해 구현 복잡도와 메모리 오버헤드를 최소화한다. 이는 대규모 언어 모델을 고대역폭 인터커넥트가 없는 데이터센터나 멀티‑데이터센터 환경에서도 효율적으로 학습할 수 있게 만든다.
댓글 및 학술 토론
Loading comments...
의견 남기기