Object-Centric World Models for Causality-Aware Reinforcement Learning
📝 Abstract
World models have been developed to support sample-efficient deep reinforcement learning agents. However, it remains challenging for world models to accurately replicate environments that are high-dimensional, non-stationary, and composed of multiple objects with rich interactions since most world models learn holistic representations of all environmental components. By contrast, humans perceive the environment by decomposing it into discrete objects, facilitating efficient decision-making. Motivated by this insight, we propose \emph{Slot Transformer Imagination with CAusality-aware reinforcement learning} (STICA), a unified framework in which object-centric Transformers serve as the world model and causality-aware policy and value networks. STICA represents each observation as a set of object-centric tokens, together with tokens for the agent action and the resulting reward, enabling the world model to predict token-level dynamics and interactions. The policy and value networks then estimate token-level cause–effect relations and use them in the attention layers, yielding causality-guided decision-making. Experiments on object-rich benchmarks demonstrate that STICA consistently outperforms state-of-the-art agents in both sample efficiency and final performance.
📄 Content
Deep reinforcement learning (RL) has achieved success in a variety of fields (Mnih et al. 2013; Lillicrap et al. 2015; Mnih et al. 2016; Schulman et al. 2017; Haarnoja et al. 2018), including robotics (Levine et al. 2016; Andrychowicz et al. 2019; Kalashnikov et al. 2018) and autonomous driving (Sallab et al. 2017; Isele et al. 2017; Kendall et al. 2019). However, achieving high performance requires extensive interaction with the environment, which is costly and inefficient in real-world tasks when considering real-time operation and physical device failures (Moerland et al. 2023).
Humans understand and predict real-world dynamics through interaction with the environment (Sutton and Barto 1981; Friston 2010). Inspired by this mechanism, world models were proposed for Model-Based RL (MBRL) (Sutton 1990; Ha and Schmidhuber 2018). In this setting, agents train world models to replicate their environments (observations and actions) and optimize their policies within "imagined" environments generated by the world models, making their learning sample-efficient. Earlier MBRL agents adopted Recurrent Neural Networks (RNNs) as the dynamics model of the world model (Hafner et al. 2020, 2021, 2025; Kaiser et al. 2020; Deng, Jang, and Ahn 2022), whereas recent works have begun to explore the use of Transformers (Vaswani et al. 2017; Robine et al. 2023; Micheli, Alonso, and Fleuret 2023, 2024; Zhang et al. 2023; Burchi and Timofte 2025). Compared with RNNs, Transformers provide superior learning efficiency, generalization performance, and long-term prediction accuracy.
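To make the "imagination" idea concrete, the following is a minimal, hypothetical sketch of an MBRL rollout loop: a learned dynamics model, reward head, and policy are stood in for by toy linear functions (none of these correspond to the paper's actual networks). The point it illustrates is that the policy is evaluated entirely inside the world model, with no environment interaction.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for learned components (toy linear maps,
# NOT the paper's architecture): `dynamics` predicts the next latent
# state, `reward_head` predicts the reward, `policy` picks an action.
STATE_DIM, ACTION_DIM = 4, 2
W_dyn = rng.normal(size=(STATE_DIM + ACTION_DIM, STATE_DIM)) * 0.1
W_rew = rng.normal(size=STATE_DIM)
W_pol = rng.normal(size=(STATE_DIM, ACTION_DIM))

def dynamics(z, a):
    return np.tanh(np.concatenate([z, a]) @ W_dyn)

def reward_head(z):
    return float(z @ W_rew)

def policy(z):
    return np.tanh(z @ W_pol)

def imagine(z0, horizon=15):
    """Roll out a trajectory entirely inside the world model:
    the source of MBRL's sample efficiency is that these
    transitions cost no real environment interaction."""
    z, traj = z0, []
    for _ in range(horizon):
        a = policy(z)
        r = reward_head(z)
        z = dynamics(z, a)
        traj.append((z, a, r))
    return traj

traj = imagine(rng.normal(size=STATE_DIM))
print(len(traj))  # number of imagined transitions
```

In a real agent, the policy and value networks would then be trained on these imagined trajectories (e.g., by backpropagating value estimates through the rollout).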
However, even with Transformers, world models still struggle to replicate environments that are high-dimensional, non-stationary, and composed of multiple objects with rich interactions, even though such environments are common in real-world applications, including service robots and autonomous driving. This is because world models learn holistic representations of the environments, which may fail to capture the important relationships and interactions between individual objects (Santoro et al. 2017). When placed in such environments, humans perceive the environment by decomposing it into discrete concepts such as objects and events, enabling more efficient and causality-aware decision-making (Spelke and Kinzler 2007). Incorporating these cognitive mechanisms into world models could allow RL agents to operate more effectively even in complex settings.
Motivated by this insight, we propose Slot Transformer Imagination with CAusality-aware reinforcement learning (STICA), depicted in Figure 1. This is an RL agent built upon a unified framework in which the world model, policy network, and value network are all implemented using object-centric Transformers, and both the policy and value networks explicitly leverage causal information for more structured and effective decision-making. The slot-based autoencoder extracts object-centric representations $(z_t^1, \dots, z_t^n)$ from observations $o_t$ while excluding static background information $z_{\mathrm{BG}}$, as shown in Figure 1 (a) and (b). The Transformer-based dynamics model accepts these representations $(z_t^1, \dots, z_t^n)$, along with agent actions $a_t$ and obtained rewards $r_{t-1}$, as input tokens and predicts token-level dynamics and interactions. The policy and value networks estimate causal relationships $G$ among the input tokens, thereby facilitating policy learning for causality-aware decision-making, which is aware of "causality" in the sense of token-level dependency (not in the context of causal inference) (see Figure 1 (c)).
Figure 1: (a) The slot-based autoencoder extracts object-centric latent states $(z_t^1, \dots, z_t^n)$ from observation $o_t$, excluding static background information $z_{\mathrm{BG}}$ at time $t$ ($1 \le t \le T$). The Transformer-based dynamics model computes hidden states $(h_{1:t}^1, \dots, h_{1:t}^n)$ and $h'_{1:t}$ from latent states $(z_{1:t}^1, \dots, z_{1:t}^n)$, actions $a_{1:t}$, and rewards $r_{1:t-1}$, followed by multilayer perceptrons (MLPs) that predict the next latent states $(\hat{z}_{t+1}^1, \dots, \hat{z}_{t+1}^n)$, the reward $\hat{r}_t$, and the discount factor $\hat{\gamma}_t$. (b) Examples of object-centric representations for a Safety Gym benchmark task: the observation $o_t$, its reconstruction $\hat{o}_t$, the reconstructions from the extracted object-centric latent states $(z_t^1, \dots, z_t^5)$, and that from the static background information $z_{\mathrm{BG}}$. (c) Causal policy and value networks. They estimate causal relationships from the latent states to the action or value, based on a causal graph $G$ and causality scores $p_t^k$, and adjust the attention weights within the Transformers accordingly, enabling causality-aware decision-making. The latent states ($z_t^1$, $z_t^2$, and $z_t^5$) of goal-related objects or obstacles are expected to have stronger causal influence on the target token ($a'_t$ or $v'_t$), while the latent states ($z_t^3$ and $z_t^4$) of objects irrelevant to task completion have weaker causal influence.
The main contributions of this paper can be summarized as follows:
High-perform
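The causality-guided attention idea can be illustrated with a simplified, hypothetical single-head sketch: per-token causality scores (assumed to come from a learned estimator of the causal graph $G$; here simply given as input) bias the attention logits so that causally relevant object tokens receive more attention mass. The log-score bias used below is one simple way to realize this re-weighting, not necessarily the paper's exact mechanism.

```python
import numpy as np

def causal_attention(Q, K, V, causality_scores):
    """Single-head attention whose logits are biased by per-token
    causality scores in [0, 1]: with equal logits, the attention
    weights become proportional to the scores, so tokens judged
    causally relevant to the target dominate the output."""
    d = Q.shape[-1]
    logits = Q @ K.T / np.sqrt(d)                      # (n_q, n_k)
    logits = logits + np.log(causality_scores + 1e-8)  # bias toward causal tokens
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
n_tokens, d = 5, 8
Q = rng.normal(size=(1, d))         # query from the action/value target token
K = rng.normal(size=(n_tokens, d))  # keys from object-centric latent tokens
V = rng.normal(size=(n_tokens, d))

# Hypothetical scores: tokens 0, 1, 4 mimic goal-related objects,
# tokens 2 and 3 mimic task-irrelevant objects (cf. Figure 1 (c)).
scores = np.array([0.9, 0.8, 0.05, 0.05, 0.9])
out, attn = causal_attention(Q, K, V, scores)
print(attn.round(3))
```

In STICA these weights live inside the policy and value Transformers, so that the action and value predictions attend mostly to tokens with strong estimated causal influence.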
This content is AI-processed based on ArXiv data.