AXLearn: Modular, Hardware-Agnostic Large Model Training
AXLearn is a production system that facilitates scalable and high-performance training of large deep learning models. Compared to other state-of-the-art deep learning systems, AXLearn has a unique focus on modularity and support for hardware-agnostic training. AXLearn's internal interfaces between software components follow strict encapsulation, allowing different components to be assembled to facilitate rapid model development and experimentation on diverse hardware infrastructure. AXLearn maintains constant complexity as components in the system scale, compared to the linear or quadratic complexity of state-of-the-art training systems. This allows a feature such as Rotary Position Embeddings (RoPE) to be integrated into AXLearn across hundreds of modules with just 10 lines of code, compared to the hundreds of lines required in other systems. At the same time, AXLearn maintains performance equivalent to state-of-the-art training systems. Finally, we share our experience in the development and operation of AXLearn at Apple.
💡 Research Summary
AXLearn is a production‑grade framework designed to train large deep‑learning models with a focus on two often competing goals: extreme modularity and hardware‑agnostic execution. The authors begin by observing that modern AI services (e.g., ChatGPT, Gemini, large‑scale video‑conferencing tools) require rapid experimentation on ever‑growing model families, while large technology companies such as Apple cannot rely on a single accelerator vendor due to supply constraints, cost considerations, and the need to run workloads on public clouds as well as on‑premise hardware. Existing large‑model training systems (Megatron‑LM, DeepSpeed, PyTorch FSDP, etc.) typically achieve performance through aggressive GPU‑centric optimizations and by exposing a "flat" configuration that intertwines model definition with parallelism strategy. Moreover, they rely heavily on sub‑typing: a new layer is introduced by inheriting from a base class, which forces changes to propagate through the inheritance hierarchy. This leads to linear or even quadratic growth in code‑change size (measured in lines of code, LoC) when adding features such as Rotary Position Embeddings (RoPE) or Mixture‑of‑Experts (MoE).
AXLearn tackles these problems by building on top of JAX/XLA and GSPMD, thereby inheriting a hardware‑agnostic compilation model that can target GPUs, TPUs, and AWS Trainium2. The framework enforces strict encapsulation: each component (input pipeline, model, optimizer, checkpointing, trainer loop, etc.) is a self‑contained module with a well‑defined interface. Modules are configured through a hierarchical "Config" object rather than a monolithic flat file. This hierarchy mirrors the logical tree of a neural network, allowing child modules to declare only the parameters they need while parents propagate dimensions and other contextual information downstream.
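The parent-to-child propagation of dimensions can be illustrated with a minimal sketch. This is a hypothetical illustration of the pattern, not AXLearn's actual Config API; the class and field names are invented for clarity.

```python
from dataclasses import dataclass, field

# Hypothetical sketch of a hierarchical config in the spirit of the
# "Config" objects described above. Names are illustrative, not AXLearn's.
@dataclass
class LinearConfig:
    input_dim: int = -1   # -1 means "to be filled in by the parent"
    output_dim: int = -1

@dataclass
class FFNConfig:
    hidden_dim: int = 4096
    linear1: LinearConfig = field(default_factory=LinearConfig)
    linear2: LinearConfig = field(default_factory=LinearConfig)

    def propagate(self, model_dim: int) -> None:
        # The parent pushes shared dimensions down to its children, so each
        # child declares only the parameters it cannot infer from context.
        self.linear1.input_dim = model_dim
        self.linear1.output_dim = self.hidden_dim
        self.linear2.input_dim = self.hidden_dim
        self.linear2.output_dim = model_dim

ffn = FFNConfig(hidden_dim=1024)
ffn.propagate(model_dim=512)
# Children were never told model_dim directly; the parent propagated it.
assert ffn.linear1.input_dim == 512 and ffn.linear2.output_dim == 512
```

The point of the pattern is that a child layer's config stays self-contained: only the parent knows the surrounding model dimension, and it flows downward along the config tree rather than being repeated in a flat file.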
The key technical contribution is a formal metric for extensibility: the asymptotic LoC change required to add a new feature. For a sub‑typing‑based system, the authors show that the LoC complexity of integrating MoE is at least linear, Ω(N), where N is the number of modules that must be touched. In AXLearn, because MoE and FFN share the same input/output contract, the same 10‑line snippet can replace any FFN across a thousand experiments, yielding O(1) complexity. Empirical counts confirm this: DeepSpeed would need >4,000 lines to retrofit MoE across Apple's internal model zoo, whereas AXLearn needs only a handful of lines.
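Why the swap is O(1) can be sketched concretely: if MoE honors the same input/output contract as the FFN it replaces, a single generic helper can rewrite any experiment's config tree, regardless of how many modules contain an FFN. The sketch below is hypothetical; the `Config` class, `klass` field, and layer names are invented for illustration and do not reflect AXLearn's real API.

```python
# Hypothetical sketch of an O(1) feature swap. Because an MoE layer has the
# same input/output contract as an FFN, one helper rewrites any config tree;
# the edit size does not grow with the number of modules (hence O(1) LoC).
class Config(dict):
    """A minimal nested config: values are leaves or sub-Configs."""

def replace_ffn_with_moe(cfg: Config, num_experts: int = 8) -> None:
    for key, value in list(cfg.items()):
        if isinstance(value, Config):
            if value.get("klass") == "FFN":
                # Same contract, so a drop-in replacement is safe.
                cfg[key] = Config(klass="MoE",
                                  num_experts=num_experts,
                                  hidden_dim=value["hidden_dim"])
            else:
                replace_ffn_with_moe(value, num_experts)

model = Config(
    layer0=Config(attention=Config(klass="Attention"),
                  feed_forward=Config(klass="FFN", hidden_dim=4096)),
    layer1=Config(attention=Config(klass="Attention"),
                  feed_forward=Config(klass="FFN", hidden_dim=4096)),
)
replace_ffn_with_moe(model)
assert model["layer0"]["feed_forward"]["klass"] == "MoE"
```

In a sub-typing-based system the same change instead requires touching every class in the inheritance chain that constructs or configures an FFN, which is where the Ω(N) lower bound comes from.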
The “Composer” component takes a user‑written script that builds the hierarchical Config, materializes a full JAX program, automatically selects an appropriate mesh shape, annotates sharding, tunes XLA compilation flags, and inserts hardware‑specific kernels (e.g., FlashAttention). The resulting program is handed to the XLA compiler, which emits device‑specific binaries (CUDA kernels, TPU kernels, etc.). The “Runtime” then orchestrates execution on a Kubernetes cluster, providing fault‑tolerant checkpointing, monitoring, and dynamic scaling. Because the Composer and Runtime are cleanly separated, swapping out the runtime (e.g., moving from a GPU‑based cluster to a Trainium2 cluster) requires no changes to model definitions.
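The Composer/Runtime separation described above can be sketched as follows. This is a deliberately simplified, hypothetical illustration: the function names, mesh shapes, and kernel labels are invented, and the real Composer lowers to an XLA program rather than a Python dictionary.

```python
# Hypothetical sketch of the Composer/Runtime split. The Composer lowers a
# model config into a hardware-specific execution plan (mesh shape, kernel
# choice); the Runtime executes the plan. Swapping hardware targets touches
# only the plan, never the model definition.
def compose(model_cfg: dict, hardware: str) -> dict:
    # Pick a mesh shape and attention kernel appropriate to the target.
    mesh = {"gpu": (8, 1), "tpu": (4, 2), "trainium": (16, 1)}[hardware]
    kernel = "flash_attention" if hardware == "gpu" else "fused_attention"
    return {"model": model_cfg, "mesh": mesh, "attention_kernel": kernel}

def run(plan: dict) -> str:
    # Stand-in for launching the compiled program on a cluster.
    return f"running on mesh {plan['mesh']} with {plan['attention_kernel']}"

cfg = {"layers": 32, "model_dim": 4096}
plan_gpu = compose(cfg, "gpu")
plan_tpu = compose(cfg, "tpu")
# The model definition is shared untouched across hardware targets.
assert plan_gpu["model"] is cfg and plan_tpu["model"] is cfg
```

The design choice this illustrates is the clean interface boundary: because the model config never encodes hardware decisions, retargeting from GPUs to TPUs or Trainium2 changes only what the Composer emits.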
Performance experiments demonstrate parity with state‑of‑the‑art systems on identical hardware configurations. AXLearn matches Megatron‑LM and DeepSpeed in throughput and memory efficiency while offering dramatically simpler code paths for new features. The authors also report large‑scale internal adoption: thousands of models trained by hundreds of engineers across Google Cloud TPUs, AWS GPUs, and AWS Trainium2 clusters. The modular design has accelerated the onboarding of new research ideas and reduced the time‑to‑experiment from weeks to days.
Finally, AXLearn is released under the Apache 2.0 license, inviting the broader community to extend its modular components, contribute new hardware back‑ends, or integrate novel parallelism strategies. The paper concludes that a design grounded in strict encapsulation, hierarchical configuration, and hardware‑agnostic compilation can simultaneously achieve high performance, scalability, and maintainability—attributes essential for the next generation of AI research and production systems.