Causal-JEPA: Learning World Models through Object-Level Latent Interventions
World models require robust relational understanding to support prediction, reasoning, and control. While object-centric representations provide a useful abstraction, they are not sufficient to capture interaction-dependent dynamics. We therefore propose C-JEPA, a simple and flexible object-centric world model that extends masked joint embedding prediction from image patches to object-centric representations. By applying object-level masking that requires an object’s state to be inferred from other objects, C-JEPA induces latent interventions with counterfactual-like effects and prevents shortcut solutions, making interaction reasoning essential. Empirically, C-JEPA leads to consistent gains in visual question answering, with an absolute improvement of about 20% in counterfactual reasoning compared to the same architecture without object-level masking. On agent control tasks, C-JEPA enables substantially more efficient planning by using only 1% of the total latent input features required by patch-based world models, while achieving comparable performance. Finally, we provide a formal analysis demonstrating that object-level masking induces a causal inductive bias via latent interventions. Our code is available at https://github.com/galilai-group/cjepa.
💡 Research Summary
Causal‑JEPA (C‑JEPA) introduces a simple yet powerful inductive bias for object‑centric world modeling by masking at the level of object slots rather than image patches. A frozen Slot‑Attention encoder extracts a fixed set of N object embeddings from each video frame. During training, a random subset of these object slots is masked across a history window, leaving only the earliest observation of each masked object as an “identity anchor”. The masked token is formed by linearly projecting this anchor and adding a learnable temporal embedding, which can be interpreted as a latent‑level intervention. A bidirectional Vision‑Transformer predictor then jointly reconstructs the masked history tokens and forecasts future object slots, optionally conditioned on auxiliary variables such as actions and proprioception. The loss combines a masked‑latent prediction term (L2 distance on all masked tokens) with a forward‑prediction term.
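The masking step and the two-term loss described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation. The function names (`mask_object_slots`, `cjepa_loss`), the `mask_ratio` and `weight` parameters, the use of squared L2 error, and the unweighted sum of the two terms are all hypothetical choices made for this example. The Slot-Attention encoder and the Transformer predictor are omitted.

```python
import numpy as np

def mask_object_slots(slots, mask_ratio, W_proj, temporal_emb, rng):
    """Object-level masking over a history window (illustrative sketch).

    slots:        (T, N, D) slot embeddings from a frozen encoder
    W_proj:       (D, D) linear projection applied to the identity anchor
    temporal_emb: (T, D) per-timestep embedding (learnable in a real model)
    Returns the predictor input tokens (T, N, D) and a boolean mask (T, N).
    """
    T, N, D = slots.shape
    num_masked = max(1, int(round(mask_ratio * N)))  # assumed masking schedule
    masked_ids = rng.choice(N, size=num_masked, replace=False)

    mask = np.zeros((T, N), dtype=bool)
    mask[1:, masked_ids] = True  # t = 0 stays visible as the identity anchor

    tokens = slots.copy()
    anchor = slots[0, masked_ids] @ W_proj  # (num_masked, D)
    # Masked token at time t = projected anchor + temporal embedding for t.
    tokens[1:, masked_ids] = anchor[None, :, :] + temporal_emb[1:, None, :]
    return tokens, mask

def cjepa_loss(pred_hist, target_hist, mask, pred_future, target_future, weight=1.0):
    """Masked-latent term plus forward-prediction term (squared L2 assumed)."""
    hist_err = ((pred_hist - target_hist) ** 2).sum(axis=-1)  # (T, N)
    masked_loss = hist_err[mask].mean()  # averaged over masked tokens only
    future_loss = ((pred_future - target_future) ** 2).sum(axis=-1).mean()
    return masked_loss + weight * future_loss

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, N, D = 4, 5, 3
    slots = rng.normal(size=(T, N, D))
    tokens, mask = mask_object_slots(
        slots, mask_ratio=0.4, W_proj=np.eye(D),
        temporal_emb=rng.normal(size=(T, D)), rng=rng,
    )
    print("masked tokens per frame:", mask.sum(axis=1))
```

Because a masked slot carries only its first-frame identity and a time index, a predictor fed these tokens has to recover the slot's later states from the other, unmasked objects, which is the behaviour the summary attributes to object-level masking.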
The authors provide a formal analysis showing that object-level masking induces a causal inductive bias via latent interventions: because a masked object's state must be inferred from the other objects rather than from its own trajectory, the model is forced to rely on interactions between entities. This prevents shortcut solutions such as simple temporal interpolation and builds the bias directly into the learning objective.
Empirically, C-JEPA is evaluated on two fronts. On the CLEVRER video question answering benchmark, it yields an absolute improvement of about 20% on counterfactual questions compared with the same architecture without object-level masking, demonstrating stronger relational reasoning. In a model-predictive control (MPC) setting on the Push-T manipulation task, C-JEPA achieves success rates comparable to a strong patch-based world model (DINO-WM) while using only about 1% of the total latent input features, resulting in more than an 8× speed-up in planning.
Overall, C-JEPA shows that a modest change, masking whole object representations instead of image patches, can improve counterfactual and relational reasoning while sharply reducing the input size and planning cost of a world model, making it a compelling approach for both visual reasoning and control applications.