모듈형·하드웨어 무관 대규모 모델 학습 프레임워크 AXLearn
초록
AXLearn은 엄격한 캡슐화를 기반으로 모듈성을 극대화하고, XLA와 GSPMD 위에 구축해 GPU·TPU·Trainium 등 다양한 가속기에 동일한 성능을 제공하는 대규모 딥러닝 훈련 시스템이다. 기존 프레임워크가 서브타이핑에 의존해 코드 복잡도가 선형·이차적으로 증가하는 반면, AXLearn은 계층적 설정과 구성 트리를 이용해 새로운 기능(예: RoPE, MoE)을 O(1) 수준의 코드 변경으로 적용한다. 실험 결과, 동일한 하드웨어에서 Megatron‑LM, DeepSpeed 등과 비교해 학습 속도·자원 효율성에서 차이가 없으며, 수천 명의 엔지니어가 수천 개 모델을 빠르게 프로토타이핑하고 배포할 수 있었다.
상세 분석
AXLearn의 핵심 설계 철학은 “엄격한 캡슐화와 모듈 교체 가능성”이다. 기존 딥러닝 프레임워크는 레이어를 베이스 클래스로부터 상속받아 구현하고, 새로운 기능을 추가하려면 상위 클래스까지 수정해야 하는 서브타이핑 방식을 채택한다. 이 방식은 기능 추가 시 전체 계층 구조에 걸쳐 코드가 퍼지는 O(N) 복잡도를 초래한다. 논문에서는 DeepSpeed에 MoE를 적용할 때 200줄 이상의 변경이 필요함을 예시로 들며, 이러한 복잡도가 실제 대규모 프로젝트에서 수천 줄로 확대될 수 있음을 강조한다.
AXLearn은 레이어와 모듈을 “입출력 인터페이스만 일치하면 교체 가능”하도록 설계했다. MoE와 같은 새로운 레이어는 기존 FFN과 동일한 입력·출력 형태만 맞추면 Transformer 레이어 내부에서 단순히 교체될 수 있다. 이를 가능하게 하는 것이 계층적 Config 객체와 “Config Modifier” 함수이다. Config는 각 모듈의 파라미터를 캡슐화하고, 부모 모듈이 자식 모듈에 필요한 정보를 전달하면서도 자식이 독립적으로 정의될 수 있게 한다. 따라서 새로운 기능을 적용하려면 10줄 내외의 스크립트만 추가하면 된다.
하드웨어 무관성을 확보하기 위해 AXLearn은 JAX/XLA와 GSPMD를 기반으로 한다. XLA는 다양한 백엔드(GPU, TPU, Trainium)에서 동일한 연산 그래프를 컴파일하지만, 최적 성능을 위해서는 “힌트” 제공, 커스텀 커널(예: FlashAttention) 적용, 리머티리얼라이제이션 전략 조정이 필요하다. AXLearn은 이러한 최적화 과정을 자동화하는 “Composer” 단계에서 메쉬 셰이핑, 셰어링 어노테이션, 컴파일 옵션 튜닝 등을 수행한다. 결과적으로 엔지니어는 하드웨어 선택에 구애받지 않고 동일한 파이프라인을 재사용할 수 있다.
시스템 구조는 크게 두 부분으로 나뉜다. (1) AXLearn Composer는 사용자 정의 Config를 기반으로 전체 JAX 프로그램을 생성하고, 하드웨어별 최적화를 적용한다. (2) AXLearn Runtime은 생성된 프로그램을 쿠버네티스 기반 클러스터에 배포하고, 체크포인팅, 모니터링, 장애 복구 등을 담당한다. 이 두 층은 명확히 분리돼 있어, 런타임을 교체하거나 새로운 클라우드 제공자를 추가해도 기존 모델 정의는 그대로 유지된다.
실험 결과, AXLearn은 Megatron‑LM, DeepSpeed, PyTorch FSDP 등과 비교해 동일한 모델·배치·하드웨어 조건에서 학습 속도와 메모리 사용량에서 차이가 없으며, 특히 MoE와 RoPE 같은 최신 기법을 적용할 때 코드 변경량이 현저히 적다. 또한 Apple 내부에서 수천 개 모델을 수천 명 엔지니어가 사용하면서, 모듈 교체와 하드웨어 전환이 원활히 이루어졌다는 실사용 사례를 제시한다.
요약하면, AXLearn은 “복잡도 O(1)”, “하드웨어 무관”, “계층적 Config 기반”이라는 세 축을 통해 대규모 모델 훈련의 생산성을 크게 향상시킨다. 이는 앞으로 다양한 하드웨어와 모델 아키텍처가 급변하는 AI 환경에서 확장성과 유지보수성을 동시에 만족시키는 중요한 설계 패러다임을 제시한다.
댓글 및 학술 토론
Loading comments...
의견 남기기