Mixture-of-World Models: Scaling Multi-Task Reinforcement Learning with Modular Latent Dynamics
A fundamental challenge in multi-task reinforcement learning (MTRL) is achieving sample efficiency in visual domains where tasks exhibit substantial heterogeneity in both observations and dynamics. Model-based reinforcement learning offers a promising path to improved sample efficiency through world models, but standard monolithic architectures struggle to capture diverse task dynamics, resulting in poor reconstruction and prediction accuracy. We introduce Mixture-of-World Models (MoW), a scalable architecture that combines modular variational autoencoders for task-adaptive visual compression, a hybrid Transformer-based dynamics model with task-conditioned experts and a shared backbone, and a gradient-based task clustering strategy for efficient parameter allocation. On the Atari 100k benchmark, a single MoW agent trained once on 26 Atari games achieves a mean human-normalized score of 110.4%, competitive with the 114.2% achieved by STORM, which trains an ensemble of 26 task-specific models, while using 50% fewer parameters. On Meta-World, MoW achieves a 74.5% average success rate within 300k environment steps, establishing a new state of the art. These results demonstrate that MoW provides a scalable and parameter-efficient foundation for generalist world models.
💡 Research Summary
Mixture‑of‑World Models (MoW) tackles the longstanding challenge of sample‑efficient multi‑task reinforcement learning (MTRL) in visual domains, where tasks differ dramatically in both appearance and dynamics. Traditional model‑based RL (MBRL) approaches rely on a single world model that must simultaneously learn to reconstruct high‑dimensional observations and predict future states, a requirement that quickly overwhelms monolithic architectures when faced with heterogeneous tasks. MoW resolves this by introducing a modular architecture composed of three tightly integrated components: (1) task‑specific variational auto‑encoders (VAEs) for visual compression, (2) a hybrid Transformer‑based dynamics core that mixes a shared backbone with a set of expert Transformers selected via task‑level routing, and (3) a gradient‑based task clustering procedure that determines which tasks share which VAE and auxiliary predictor modules.
Perceptual module – Each task k is assigned a categorical VAE (encoder q_{ϕ,iₖ} and decoder p_{ϕ,iₖ}) that maps raw pixels into a stochastic latent code zₜᵏ drawn from a 32-by-32 discrete distribution. A learnable unit-norm task embedding eₖ conditions both the encoder and the decoder, enabling the VAE to specialize while still benefiting from shared parameters when tasks are clustered together. Clustering is performed during a warm-up phase: gradients of each task's loss are collected, and tasks with similar gradient directions are grouped, fixing the module assignment iₖ for the remainder of training.
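The warm-up clustering step can be sketched as follows. This is a minimal illustration, not the paper's exact procedure: it assumes per-task gradients are flattened into vectors and uses spherical k-means on their directions as a stand-in for whatever similarity-based grouping the authors actually employ; all names are hypothetical.

```python
import numpy as np

def cluster_tasks(task_grads: np.ndarray, n_clusters: int, seed: int = 0) -> np.ndarray:
    """Group tasks by cosine similarity of their loss gradients.

    task_grads: (n_tasks, n_params) matrix of per-task gradient vectors
    collected during the warm-up phase.  Returns one cluster id per task,
    which fixes the module assignment i_k for the rest of training.
    """
    # Normalize so clustering depends only on gradient *direction*, not magnitude.
    g = task_grads / (np.linalg.norm(task_grads, axis=1, keepdims=True) + 1e-8)

    # Spherical k-means on the unit vectors.
    rng = np.random.default_rng(seed)
    centers = g[rng.choice(len(g), size=n_clusters, replace=False)]
    assign = np.zeros(len(g), dtype=int)
    for _ in range(50):
        sims = g @ centers.T                  # cosine similarity to each center
        assign = sims.argmax(axis=1)
        for c in range(n_clusters):
            members = g[assign == c]
            if len(members):
                m = members.mean(axis=0)
                centers[c] = m / (np.linalg.norm(m) + 1e-8)
    return assign
```

Tasks that land in the same cluster then share one VAE and one set of auxiliary predictor modules.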
Temporal module – The dynamics model receives the sequence of latent codes and actions. For each expert j, an MLP m_{ϕ,j} fuses (zₜᵏ, aₜᵏ, eₖ) into a token m_{t,j,k}. A router network, taking eₖ as input, produces softmax scores Sₖ over all Nₑ experts; a Top‑K operation selects nₖ experts Jₖ and associated normalized weights Wₖ. Selected expert Transformers f_{ϕ,j} process their tokens independently; their outputs are concatenated into l₁:ₜᵏ, which is then fed to a shared Transformer F_{ϕ} together with eₖ, yielding a hidden state hₜᵏ. This design ensures that within a given task the same set of experts is active throughout an episode, preserving temporal coherence that token‑level routing would destroy.
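The task-level routing step can be sketched as below. This is an illustrative implementation under assumptions not stated in the summary: the router is taken to be a single linear layer over the task embedding, and the temperature parameter anticipates the annealing described later; the function names are hypothetical.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)   # shift for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def route_task(task_emb: np.ndarray, router_w: np.ndarray,
               top_k: int, temperature: float = 1.0):
    """Task-level Top-K routing.

    task_emb: (d,) task embedding e_k
    router_w: (n_experts, d) router weights (a hypothetical linear router)
    Returns (expert indices J_k, normalized weights W_k).  Because routing
    depends only on e_k, the same experts stay active for every timestep of
    the task, unlike token-level MoE routing.
    """
    logits = router_w @ task_emb
    scores = softmax(logits / temperature)        # scores S_k over all experts
    idx = np.argsort(scores)[-top_k:][::-1]       # Top-K expert indices, best first
    weights = scores[idx] / scores[idx].sum()     # renormalize the selected scores
    return idx, weights
```

Each selected expert Transformer then processes the fused tokens independently, and the shared backbone consumes their concatenated outputs together with eₖ.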
Prediction heads – From hₜᵏ the model predicts (i) the distribution of the next latent state Ẑ_{t+1}ᵏ, (ii) the scalar reward r̂ₜᵏ, (iii) a binary continuation flag ĉₜᵏ, and (iv) an auxiliary task label k̂. The auxiliary task-prediction loss forces hₜᵏ to be discriminative with respect to the task, which improves the quality of imagined rollouts.
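The four heads can be sketched as simple linear maps over the hidden state. This is illustrative only: the paper's heads are presumably MLPs, and the shapes here (a 32×32 categorical next-latent head matching the perceptual module, 26 task classes) are assumptions.

```python
import numpy as np

class PredictionHeads:
    """Illustrative linear heads over the shared hidden state h_t^k."""

    def __init__(self, d_hidden: int, n_tasks: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.W_z = rng.normal(scale=0.02, size=(32 * 32, d_hidden))  # next-latent logits
        self.W_r = rng.normal(scale=0.02, size=(1, d_hidden))        # reward head
        self.W_c = rng.normal(scale=0.02, size=(1, d_hidden))        # continuation head
        self.W_k = rng.normal(scale=0.02, size=(n_tasks, d_hidden))  # auxiliary task head

    def __call__(self, h: np.ndarray):
        z_logits = (self.W_z @ h).reshape(32, 32)            # logits for Z-hat_{t+1}^k
        reward = (self.W_r @ h).item()                       # r-hat_t^k
        cont = 1.0 / (1.0 + np.exp(-(self.W_c @ h).item()))  # c-hat_t^k in (0, 1)
        task_logits = self.W_k @ h                           # logits for k-hat
        return z_logits, reward, cont, task_logits
```

During imagination, rollouts sample the next latent from the categorical distribution defined by `z_logits` instead of decoding pixels, which is what makes latent-space planning cheap.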
Loss function – The total per‑task loss combines reconstruction (L_rec), reward (L_rew), continuation (L_con), task prediction (L_task), a KL‑based dynamics loss (L_dyn) and a reverse‑KL regularizer (L_rep). Hyper‑parameters β₁=0.5 and β₂=0.1 balance dynamics learning against representation learning. Temperature annealing in the router’s softmax gradually shifts routing from stochastic (encouraging balanced expert usage early) to near‑deterministic (allowing experts to specialize) as training proceeds.
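The loss combination and the router's temperature schedule can be sketched as follows. Which terms β₁ and β₂ multiply is an assumption (dynamics and representation KL terms respectively, per the "balance dynamics learning against representation learning" description), and the linear annealing schedule with these endpoints is illustrative, not taken from the paper.

```python
def total_loss(L_rec: float, L_rew: float, L_con: float,
               L_task: float, L_dyn: float, L_rep: float,
               beta1: float = 0.5, beta2: float = 0.1) -> float:
    """Per-task training loss: supervised terms plus weighted KL terms.

    beta1 = 0.5 and beta2 = 0.1 are the values stated in the summary; the
    assignment of beta1 to L_dyn and beta2 to L_rep is an assumption here.
    """
    return L_rec + L_rew + L_con + L_task + beta1 * L_dyn + beta2 * L_rep

def router_temperature(step: int, total_steps: int,
                       t_start: float = 2.0, t_end: float = 0.1) -> float:
    """Linear annealing of the router softmax temperature.

    High temperature early keeps routing stochastic and expert usage
    balanced; low temperature late makes routing near-deterministic so
    experts can specialize.  Schedule shape and endpoints are illustrative.
    """
    frac = min(max(step / total_steps, 0.0), 1.0)
    return t_start + frac * (t_end - t_start)
```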
Experiments – MoW was evaluated on two demanding benchmarks. On Atari-100k, a single MoW agent trained jointly on 26 games for 100k environment steps per game achieved a mean human-normalized score of 110.4%, comparable to the 114.2% of STORM, which uses 26 separate task-specific models. Crucially, MoW required roughly half the parameters of STORM. On Meta-World, MoW reached a 74.5% average success rate within 300k steps, surpassing the previous state of the art. Ablation studies demonstrated that removing any of the three core components (task-specific VAEs, task-level routing, or gradient-based clustering) leads to a noticeable performance drop, confirming their complementary contributions.
Discussion – MoW’s modularity provides a principled way to balance parameter sharing and task specialization. The task‑level router avoids the fragmentation problem of token‑level MoE, while the warm‑up clustering automatically discovers groups of tasks that can safely share visual encoders and auxiliary predictors, dramatically reducing the overall model size. Limitations include the need to pre‑define the number of experts Nₑ and the Top‑K value nₖ, and the fact that task embeddings are learned but fixed after warm‑up, which may hinder zero‑shot generalization to completely unseen tasks.
Conclusion and future work – Mixture-of-World Models presents a scalable, parameter-efficient foundation for generalist world models in visual MTRL. Future directions include meta-learning of task embeddings for rapid adaptation, online clustering to accommodate a continuously expanding task suite, and extending the architecture to handle partial observability and stochastic dynamics more explicitly. By unifying high-fidelity visual reconstruction with expert-driven latent dynamics, MoW bridges the gap between sample-efficient model-based RL and the diverse demands of real-world multi-task environments.