LoRDO: 통신은 적게, 성능은 그대로, 분산 저랭크 최적화의 새로운 기준
초록
분산 학습의 통신 병목 현상을 해결하기 위해, LoRDO는 저랭크 최적화와 드문 동기화를 결합한 프레임워크입니다. 글로벌 프로젝션으로 안정적인 저차원 공간을 유지하면서도, 풀랭크 쿼시-하이퍼볼릭 업데이트를 도입해 공간 탐색을 복원합니다. 125M~720M 규모의 언어 모델에서 기존 저랭크 DDP 대비 통신량을 약 10분의 1로 줄이면서도 유사한 성능을 달성했으며, 특히 메모리가 제한된 저랭크/소배치 설정에서 더 큰 성능 향상을 보입니다.
상세 분석
본 논문이 제안하는 LoRDO 프레임워크는 분산 학습의 두 가지 핵심 과제인 통신 오버헤드와 옵티마이저 상태의 메모리 부담을 동시에 해결하고자 합니다. 기존 드문 통신 전략은 동기화 빈도를 줄이지만, 여전히 옵티마이저 상태의 통신 및 저장에 대한 부담이 남아 있었습니다. 한편, GaLore나 LDAdam과 같은 저랭크 옵티마이저는 메모리와 통신 부담을 줄일 수 있지만, 각 워커가 전체 배치 그래디언트에 접근할 수 없는 드문 동기화 환경에서는 프로젝션 노이즈가 커져 성능이 저하되는 문제가 있습니다.
LoRDO의 핵심 기여는 다음과 같은 설계 선택에 있습니다. 첫째, ‘글로벌 프로젝션’ 전략을 채택합니다. 각 워커가 지역적 그래디언트로 개별적으로 프로젝션 행렬을 계산하는 대신, 동기화 시점에 모든 워커의 변화를 집계한 ‘의사 그래디언트’로부터 공통의 프로젝션 행렬을 계산합니다. 이는 효과적 배치 크기를 늘려 프로젝션의 안정성을 높이고, 모든 워커가 동일한 저랭크 부분공간에서 최적화를 진행하도록 보장합니다. 그러나 이 방법은 최적화 궤적이 해당 저랭크 부분공간에 영구적으로 제한되는 ‘정체’ 문제를 야기합니다.
둘째, 이 정체 문제를 해결하기 위해 ‘풀랭크 쿼시-하이퍼볼릭 모멘텀’ 업데이트를 도입합니다. 저랭크로 투영된 모멘텀 업데이트에, 원본 풀랭크 그래디언트 신호를 적절히 스케일링하여 더하는 방식입니다. 이는 의사 그래디언트에 풀랭크 성분을 주입함으로써, 전체 매개변수 공간에 대한 탐색을 가능하게 하고 글로벌 프로젝션의 이론적 장점을 유지하면서도 정체를 방지합니다.
또한, LoRDO는 새로운 프로젝션 행렬 계산 시 모멘텀 상태의 회전 정렬과 지역 최적화 과정의 성능 향상을 위한 오류 피드백 메커니즘을 통합했습니다. 실험 결과, LoRDO는 125M에서 720M 파라미터 규모의 언어 모델링 및 다운스트림 작업에서 저랭크 DDP와 거의 동등한 성능(퍼플렉서티 격차 <1%)을 유지하면서도 통신량을 약 10배 가량 줄였습니다. 더욱이 메모리가 극도로 제한되어 낮은 랭크와 작은 배치 크기를 사용해야 하는 설정에서는 오히려 DDP 대비 3.36%~4.7%의 성능 향상을 보여, 극한의 제약 조건에서의 효용성을 입증했습니다.
댓글 및 학술 토론
Loading comments...
의견 남기기