Original Info
- Title:
- ArXiv ID: 2512.20605
- Date:
- Authors: Unknown

Abstract
Large-scale autoregressive models pretrained on next-token prediction and finetuned with reinforcement learning (RL) have achieved unprecedented success on many problem domains. During RL, these models explore by generating new outputs, one token at a time. However, sampling actions token-by-token can result in highly inefficient learning, particularly when rewards are sparse. Here, we show that it is possible to overcome this problem by acting and exploring within the internal representations of an autoregressive model. Specifically, to discover temporally-abstract actions, we introduce a higher-order, non-causal sequence model whose outputs control the residual stream activations of a base autoregressive model. On grid world and MuJoCo-based tasks with hierarchical structure, we find that the higher-order model learns to compress long activation sequence chunks onto internal controllers. Critically, each controller executes a sequence of behaviorally meaningful actions that unfold over long timescales and are accompanied with a learned termination condition, such that composing multiple controllers over time leads to efficient exploration on novel tasks. We show that direct internal controller reinforcement, a process we term "internal RL", enables learning from sparse rewards in cases where standard RL finetuning fails. Our results demonstrate the benefits of latent action generation and reinforcement in autoregressive models, suggesting internal RL as a promising avenue for realizing hierarchical RL within foundation models.

Full Content
RL efficiency can be greatly increased by starting from an autoregressive sequence model that has been pretrained on a wide range of behaviors, such as a large language model (LLM). From an RL standpoint, self-supervised pretraining can be seen as imitation learning under partial observability, where not only is noise introduced and intermediate steps occluded, but also latent variables, such as task descriptors, agent rewards and goals, and their mental states, are unknown. This setup imbues the resulting models with latent variable inference capabilities [4,5] (commonly referred to as in-context learning [6]) that allow adapting to new tasks and environments quickly. Moreover, pretrained autoregressive models serve as rich action priors from which diverse, meaningful sequences can be sampled, enabling efficient exploration from the start.
Efficient, long-horizon exploration is key for RL to succeed, in particular when rewards are sparse. This leads us to an important problem that autoregressive models face: because these models produce sequences one token at a time, RL exploration is driven entirely by token-level variations. However, solely relying on token-by-token variability to explore can be insufficient to make progress on hard, sparse-reward problems which require generating multiple tokens correctly before obtaining a reward. This observation, which is at the center of the present study, has motivated a long line of research on hierarchical RL. Hierarchical RL attempts to exploit the fact that real-world problems are typically amenable to a hierarchical approach, wherein a final solution is expressed in terms of temporally-abstract actions, i.e., reusable subroutines that run for extended time periods (sometimes called "options") [7]. Evidence suggests that humans approach problem solving using such temporal abstractions [8], which implies that this may be a very efficient way to learn. Importantly, if temporally-abstract subroutines exist, exploration can occur at higher levels of temporal abstraction, drastically reducing the search space relative to token-by-token exploration. However, discovering appropriate subroutines via deep RL remains a longstanding challenge. While policy gradient methods have been derived (e.g., the option-critic [9]), these approaches have theoretical issues and tend to fail in practice, often converging to degenerate options [10].
In this paper, we pursue an alternative approach for temporally-abstract action discovery that builds directly upon autoregressive modeling. Based on their in-context latent variable inference capabilities, we hypothesize that autoregressive action models implicitly learn temporally-abstract actions represented in their internal activations, despite being trained to predict only one token at a time. This hypothesis leads us to introduce an internal neural network controller in charge of steering the internal activations of a base model. Critically, the controller learns through an unsupervised variational inference algorithm [11][12][13][14], which does not require per-time-step abstract action labels, in contrast to standard model steering techniques [15,16]. We evaluate our approach on a family of RL tasks that are constructed in a hierarchical, compositional manner. We consider both a classic discrete grid world environment [17,18], and a more challenging hierarchical continuous control environment implemented on the MuJoCo physics simulator [19]. The latter requires an agent to master both low-level continuous motor control as well as planning at a higher level of temporal abstraction to exploit the underlying discrete, compositional task structure. We find that the internal controller discovers how to generate higher-order sequences of temporally-abstract actions that switch sparsely in time. These abstract actions enable efficient exploration by drastically reducing the search space size in novel tasks and simplify credit assignment by reducing the effective time horizon of the policy. The final product is a novel hierarchical RL method that directly reinforces internal activations to solve sparse reward tasks that token-level approaches cannot solve. Our results demonstrate the benefits of latent action generation for RL applied to pretrained autoregressive models.
We illustrate our approach in Fig. 1, and preview our main contributions below:
• Next-action predictors inherently develop temporally-abstract action representations. We analyze transformers and state-space models (SSMs) trained to autoregressively predict the actions of goal-directed agents, whose goals are unknown. We find that the networks learn to represent (and infer in-context) a belief about an agent's goals in their residual stream activations.
• Linearly controllable temporally-abstract actions. These temporally-abstract representations are also easily controllable: a linear residual stream controller near mid-depth suffices to turn the sequence model into a closed-loop goal-optimizing policy, capable of executing a long-horizon plan.
• Compositional generalization in the residual stream. We show that such controllers can be sequenced in time. Residual stream controller sequencing enables compositional generalization, yielding agents that combine multiple goals in ways not seen during training.
• A new neural architecture for autoregressive model control, which discovers temporally-abstract actions without supervision. We develop a metacontroller neural network that reads from the sequence model residual stream, and in turn applies a linear controller to it. The metacontroller learns to generate goal-optimizing controllers that exhibit temporal abstraction: it keeps applying the same controller for a variable number of time steps before switching to a new one. To discover appropriate temporally-abstract actions without any supervision signals, our method relies on two key properties: (i) reading from and writing back to the residual stream of a pretrained autoregressive model, and (ii) future-conditioning: during training, the metacontroller is non-causal, and is conditioned on a sequence embedding obtained by performing a first pass through the entire sequence.
• A new "internal RL" paradigm, many orders of magnitude faster than standard RL finetuning in hierarchically-structured tasks. We introduce internal RL: performing RL directly within the residual stream of the base model, taking internal activations as observations and metacontroller outputs as actions.
We show that internal RL significantly outperforms both standard RL finetuning as well as a strong prior hierarchical RL method [CompILE; 17], achieving both higher initial success rates and more efficient credit assignment than the baseline methods in hierarchically-structured tasks.
Before diving into the description of our internal RL model, we first analyze the internal activations of autoregressive models pretrained to predict the behavior of goal-directed agents. Our goal here is to verify that a model trained on next-token prediction can learn temporally-abstract actions in its internal activations that we can leverage for internal RL. To do this, we pretrain our models from scratch on a behavioral dataset D comprising observation-action sequences produced by different expert agents that solve tasks via stochastic policies of varying degrees of optimality. The autoregressive model can thus be thought of as a sequence model of likely observation-action trajectories. Each element of D is a sequence (o_1, a_1, ..., a_T, o_{T+1}) comprising the initial sensory observation o_1, the actions a_t taken by an agent, and the resulting sensory observations o_{t+1} at time steps t ∈ {1, ..., T}. Like behavioral datasets collected at scale (e.g., those used to train LLMs), D does not contain rewards, nor any explicit agent goal and task descriptors.
The analyses presented in this section seek to determine if, and how, autoregressive models infer abstract patterns in long-horizon, goal-directed action sequences. We collect behavior from two classes of environments where agents perform navigation tasks. Importantly, the tasks are hierarchically-structured (cf. Fig. 2): though basic movement skills are a prerequisite, any given task can be solved with a combination of sub-routines composed of common sequences of basic movements. More concretely, we study both a discrete grid world environment that was previously introduced as a testbed for hierarchical RL [17,18], as well as a continuous-observation, continuous-action adaptation implemented by us in the MuJoCo physics simulator [19], where a quadrupedal robot (the 'ant' [20,21]) must be controlled at joint-level. In both environments, an agent needs to follow a course that arrives at certain colored locations in a specific order. In other words, the agents need to navigate between subgoals while also ignoring distractors (non-goal colored locations), all while avoiding collisions with randomly placed walls. Any task is described by a sequence of subgoals, which are either a single colored location for the ant, or two consecutive colored locations for the grid world. A given task can be mapped to different spatial configurations of the subgoals, the distractors, and the walls; see Appendix A for more details on the environments. In these environments, abstract actions are equivalent to moving towards a specific subgoal, hence we use the terms "abstract action" and "subgoal" interchangeably in this paper.
Figure 2 | (a) To complete a task, an agent must visit in sequence a number of subgoal locations, each marked with a specific color. The tasks are performed either in a discrete grid world or in a continuous motor control environment, illustrated above, where a quadrupedal robot (the 'ant') must be actuated at joint level. A task can be described as an abstract action sequence (the subgoal locations that must be visited), or as a sequence of low-level motor commands. (b) We pretrain autoregressive action models and metacontrollers on unlabeled behavioral datasets containing observation-action sequences of expert agents performing different tasks. These sequences do not contain rewards or subgoal labels. We then test the ability of the models to learn with RL tasks that comprise longer subgoal sequences, combined in new orders not seen during pretraining and metacontroller training.

Given behavioral data collected for a set of easy tasks, referred to as the pretraining task set (see Appendix A and C.1 for more details on the tasks and how the behavioral data are collected), we proceed with autoregressive sequence model pretraining, here a standard causal transformer [1] for discrete grid world data, and an efficient SSM (Hawk [22]) for ant control data. The models are pretrained from scratch by minimizing the cross-entropy

E_{(o_{1:T+1}, a_{1:T}) ∼ D} [ Σ_t − ln p_θ(a_t | o_{1:t}) − λ ln p_θ(o_{t+1} | o_{1:t}, a_{1:t}) ],

with p_θ the sequence model, and θ its parameters. For the case of continuous actions, the likelihood p_θ(a_t | o_{1:t}) is modeled as a Gaussian with learned diagonal covariance matrix. For discrete actions, the likelihood is parameterized as a categorical distribution with probabilities provided by the softmax over the output logits. Note that while the main objective here is behavioral (next-action) prediction, the models are also trained on next-observation prediction, the objective of world (dynamics) modeling [23][24][25]. The weight of this auxiliary loss is determined by a scalar hyperparameter λ ≥ 0; we analyze its role in the Appendix Fig. A2. Additional optimization and architectural details may be found in Appendix C and D.
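For concreteness, below is a minimal sketch of this combined objective for the discrete-action case; the function and argument names are illustrative, and a squared error stands in for the Gaussian next-observation likelihood.

```python
import torch
import torch.nn.functional as F

def pretraining_loss(action_logits, obs_pred, actions, next_obs, lam=0.1):
    """Next-action cross-entropy plus a lambda-weighted next-observation term.

    action_logits: (B, T, n_actions) categorical logits for the discrete actions.
    obs_pred:      (B, T, obs_dim)   predicted next observations.
    actions:       (B, T)            groundtruth discrete actions.
    next_obs:      (B, T, obs_dim)   groundtruth next observations.
    lam:           weight of the auxiliary world-modelling loss (lambda >= 0).
    """
    # Behavioral objective: cross-entropy of the next action.
    action_nll = F.cross_entropy(action_logits.flatten(0, 1), actions.flatten(0, 1))
    # Auxiliary world-modelling objective (squared-error stand-in).
    obs_loss = F.mse_loss(obs_pred, next_obs)
    return action_nll + lam * obs_loss
```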
To determine whether the internal activations of the pretrained autoregressive models learn to identify temporal abstractions related to the subgoals, we analyze the internal activations of the models using two common mechanistic interpretability techniques [26], linear probing [27] and causal model intervention [28,29]. For the former (linear probing), we train linear classifiers to decode the agent subgoals c_t ∈ {1, ..., G} on the grid world environment from the instantaneous (time step t) residual stream activation vector h_{t,l} ∈ ℝ^{n_h} after the l-th model block. Fig. 3 shows that linear decoder probability mass concentrates on the correct latent subgoal as time t increases, i.e. as more evidence about the current agent subgoal is gathered. Moreover, linear decoding likelihood increases with layer depth l, peaking close to the final embedding used by the transformer decoder. Thus, despite being trained only on one-step action prediction, the autoregressive models learn to represent temporally-abstract subgoals. This result is in line with the infinite-data theory of in-context Bayesian inference in sequence predictors [30], and adds more evidence to the linear representation hypothesis in neural sequence models [31][32][33]. For causal model intervention, we ask whether the internal representations of the autoregressive model can be leveraged to create a subgoal-optimizing policy. Inspired by the effectiveness of LoRA finetuning [34], we introduce a low-rank linear residual stream controller with parameters U ∈ ℝ^{n_h × n_h}, which modifies the instantaneous residual stream activations in between model blocks at a given depth l following the update

h_{t,l} ← h_{t,l} + U_t h_{t,l}.    (1)
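A minimal sketch of such a controller is shown below, with the low-rank matrix U factored as A B by analogy with LoRA; the class and attribute names are our own.

```python
import torch

class ResidualController(torch.nn.Module):
    """Low-rank linear residual-stream controller implementing Eq. 1."""

    def __init__(self, n_h, rank):
        super().__init__()
        # U = A @ B, initialized so that the controller starts as a no-op.
        self.A = torch.nn.Parameter(torch.zeros(n_h, rank))
        self.B = torch.nn.Parameter(torch.randn(rank, n_h) * 0.01)

    def forward(self, h):
        # h: (..., n_h) residual stream activations after block l.
        # Returns h + U h, applied row-wise.
        return h + (h @ self.B.T) @ self.A.T
```

In practice, one such module per subgoal can be attached at mid-depth of the frozen base model (e.g., via a forward hook) and selected according to the active subgoal.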
Note that we allow the controller parameters U_t to vary in time. In this section, we maintain a set of G separate controllers {U^{(g)}}_{g=1}^{G}, one per subgoal, and manually select which controller U_t to apply at every time step t using the groundtruth subgoal label c_t. (We will eliminate the use of ground-truth subgoal labels later on.) To train the controllers, we condition generation upon the correct subgoal-specific controller U^{(g)}, and minimize the cross-entropy E_{(o_{1:T+1}, a_{1:T}) ∼ D*} [ Σ_t − ln p_{θ,U}(a_t | o_{1:t}, U_t) ] w.r.t. the controller parameters U (while holding θ fixed) on a behavioral dataset D*. This dataset contains behavioral sequences that are generated in the same way as those in the pretraining dataset D, but with increased optimality, see Appendix C.4.4. Here and throughout, U refers to controller parameters that were not part of the pretrained model p_θ, and p_{θ,U} denotes a controlled model.
We evaluate the subgoal-optimizing controllers on a post-training OOD task set that requires both length and compositional generalization: as shown in Fig. 2 and detailed in Appendix A, the post-training tasks recombine subgoals in orders seen neither during pretraining nor during controller training. They also comprise longer subgoal trajectories. Fig. 4 shows that these novel tasks can be solved with a high success rate by simply activating the corresponding subgoal controllers in the correct order, without any autoregressive sequence model retraining. More detailed descriptions of these mechanistic interpretability experiments and some additional experimental results are presented in Appendix B and C.
Our analysis further reveals a distinction between latent variable belief state representation (at least w.r.t. a linear decoder) and internal representation control. Whereas linear subgoal decoding is possible from mid-depth up until the final layer, subgoal-conditioning is best achieved by inserting a linear controller in the middle of the pretrained sequence model, see Fig. 4. There is an intuitive appeal to this result: the mapping from abstract subgoals spanning many time steps to actual per-time-step low-level actions is implemented over multiple model layers. Our findings join two recent studies [35,36] that identify the first half of language models as the strongest for transfer learning, and as exerting the strongest influence on predicting future tokens. Given these results, in what follows, and unless noted otherwise, controllers always read from and write back to the residual stream at mid-depth of the autoregressive sequence model.
The analyses above show that simple internal activation controllers can steer a pretrained next-action sequence model to execute temporally-abstract actions, here navigation to a sequence of subgoals. We have so far assumed access to subgoal labels, similarly to how current model steering methods [37] are trained using detailed supervision information (e.g., on the truthfulness of an answer [38] or on personality traits [39]). We now turn to the challenging unsupervised setting with no groundtruth labels, where the model must both discover temporally-abstract actions from an unlabeled behavioral dataset ๐ท, and learn a selection mechanism that generates appropriate sequences of subgoals, and related abstract actions, in order to achieve a larger goal.
Figure 4 | In both grid world (a) and ant (b) environments, inserting the controller near the middle layer results in better controllability, as measured by the success rate on the post-training tasks, which require both length and compositional generalization. To produce this analysis, we trained one controller per subgoal using groundtruth labels; to evaluate success rates we activated the controllers in the correct order, again using groundtruth subgoal labels. Results averaged over 5 seeds.

To simultaneously learn abstract actions and orchestrate their execution, we freeze the autoregressive model after training on D, then we augment it with a metacontroller that can generate the controllers, U_t, for the residual stream activations in the sequence model. As before, we continue training on D* with θ fixed. But, now, we do not condition the controller on the groundtruth subgoal; instead the metacontroller learns how to generate the appropriate controllers at the appropriate times. We describe the model in full in Appendix D.2, and illustrate it in Fig. 5. Briefly, the metacontroller is a generative stochastic recurrent neural network with an encoder-decoder architecture that enables sampling controllers sequentially. Because it outputs the parameters U_t of a controller and not directly a control vector, the metacontroller can be qualified as a recurrent hypernetwork [40]. The decoder is a feedforward network that produces a controller, U_t, from a controller code, z_t. The encoder is a recurrent network based on the gated recurrent unit [41] that specifies the mean μ_t and variance Σ_t of a Gaussian distribution over a random controller code z̃_t ∼ N(μ_t, Σ_t). Importantly, the encoder is non-causal, because it receives an embedding, e(h_{1:T}), of the whole sequence of latent activities. We justify such future-conditioning using a formal latent variable modeling argument in Appendix E.1. Additionally, the metacontroller includes a recurrent switching unit that operates between the encoder and decoder. This unit determines a time-varying continuous switching gate β_t ∈ [0, 1], which controls the interpolation between the previous controller code z_{t−1} and a newly sampled code z̃_t:
z_t = β_t ⊙ z̃_t + (1 − β_t) ⊙ z_{t−1},    (2)

where ⊙ denotes elementwise multiplication. Despite its simplicity, this temporal integrator is critical for the metacontroller to learn to generate the appropriate temporally-abstract actions, as we will confirm through ablation experiments at the end of this section.
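Read as code, the temporal integrator is a convex combination of the previous and the freshly sampled code (a sketch; shapes and names are illustrative):

```python
import torch

def integrate_code(z_prev, z_new, beta):
    """Temporal integrator of Eq. 2: z_t = beta_t * z~_t + (1 - beta_t) * z_{t-1}.

    z_prev: (B, n_z) previous controller code z_{t-1}.
    z_new:  (B, n_z) newly sampled code z~_t ~ N(mu_t, Sigma_t).
    beta:   (B, 1)   switching gate in [0, 1]; beta near 1 switches, near 0 holds.
    """
    return beta * z_new + (1.0 - beta) * z_prev
```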
The metacontroller parameters ψ are trained through the minimization of a self-supervised learning objective, comprising (low-level) next-action prediction and an additional prior-matching regularizer,

E_{(o_{1:T+1}, a_{1:T}) ∼ D*} [ Σ_t − ln p_{θ,U_t}(a_t | o_{1:t}) + α D_KL( N(μ_t, Σ_t) ‖ N(0, I) ) ],    (3)

where D_KL(· ‖ ·) denotes the Kullback-Leibler divergence [42]. The inclusion of this regularizer (with weight determined by the hyperparameter α ≥ 0) promotes the generation of meaningful sequences when sampling controller codes z_t from a standard normal distribution, a property that we exploit in the next section to develop a novel hierarchical RL algorithm. From an information-theoretic perspective, α also controls the variational bottleneck by regulating the information flow from the acausal encoder to the controller. As shown in our later analysis, this bottleneck is instrumental in driving the model toward sparse, subgoal-aligned switching patterns that mirror the underlying task structure. Moreover, the choice of an unconditional prior (i.e., where next abstract action proposals are independent of past ones) promotes the development of compositional representations, which match well our hierarchical tasks. In Appendix E.1, we derive Eq. 3 formally using a variational information-theoretic approach [43].
The derivation is standard, and follows closely previous calculations for stochastic recurrent models [e.g., 44,45].
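A sketch of this objective is given below, assuming a diagonal Gaussian encoder; all names are illustrative, and the distortion term is the next-action NLL computed under the controlled, frozen base model.

```python
import torch

def metacontroller_loss(action_nll, mu, logvar, alpha):
    """Variational objective of Eq. 3, sketched as distortion + alpha * rate.

    action_nll: scalar next-action negative log-likelihood under p_{theta, U_t}.
    mu, logvar: (B, T, n_z) parameters of the encoder's Gaussian over codes.
    alpha:      weight of the prior-matching KL regularizer.
    """
    # KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over code dimensions,
    # averaged over batch and time.
    kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0).sum(-1).mean()
    return action_nll + alpha * kl
```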
Ultimately, the metacontroller both discovers the temporally-abstract actions that underlie the observed agents' behavior, and learns to sequence them appropriately in time by implementing respective termination conditions via the switching gate. In Figs. 6 and A3, we analyze the residual stream controllers discovered by the metacontroller by plotting the switching gate values β_t against the groundtruth abstract actions c_t. We find that the metacontroller recovers the groundtruth abstract action switching times. After training, the switch gate learns to behave in a quasi-binary, sparsely-switching fashion, despite not being explicitly regularized to do so. This is a notable finding in light of the critical role that switching regularization methods play in hierarchical RL [46], and given the simplicity of the temporal integrator (Eq. 2). The resulting temporal segmentation is essentially perfect, despite the fact that both observations and actions are continuous for the ant environment. Moreover, the metacontroller learns to generate latent controller codes which correspond to meaningful temporally-abstract actions (e.g., "go to color blue"), that generalize to new task configurations and switching times (see Appendix B.3.2 for an analysis).
We next study what happens when the autoregressive base model parameters θ are not kept frozen, and instead co-trained with the metacontroller parameters ψ through variational inference (the minimization of Eq. 3, now w.r.t. both θ and ψ). This baseline is conceptually close to previous hierarchical RL methods that use variational inference to learn abstractions from unlabeled demonstrations (e.g., [17,18]), while using our particular neural network architecture. To compare the abstract action representations developed when the base model is frozen vs. when it is not, we resort to a rate-distortion analysis [43], obtained by varying the value of the hyperparameter α (which controls the rate-distortion trade-off in Eq. 3) over a wide interval, see Appendix C.6 for additional details. We trace rate-distortion curves for both our standard metacontroller (which steers a pretrained, frozen autoregressive model) and for the co-trained metacontroller, see Fig. 7.
Intriguingly, we find that a horizontal gap appears on the rate-distortion curve between metacontrollers with subgoal-aligned switching (with rate-distortion points marked in Fig. 7), and those with slightly less rate. This indicates that at that rate level, a small increase in rate dramatically improves the distortion. In contrast, for the co-trained metacontroller, although the variational objective is minimized, this structure is lost. For most values of α, the model converges to a degenerate solution characterized by a single switch at the very beginning of the sequence. The fact that subgoal-aligned switching corresponds to this improved distortion with frozen autoregressive models, but not with co-trained models, shows that pretraining builds an internal representation that aligns well with abstract actions. Furthermore, this also has optimization implications: for a given value of α, the variational objective (Eq. 3) is minimized on the point of the rate-distortion curve which has a tangent of slope −1/α. A gap like the above, with a slope discontinuity, indicates that for a large range of values of α, the variational objective is minimized precisely at the region with subgoal-aligned switching. This analysis therefore confirms that controlling a frozen autoregressive action predictor is essential for the discovery of temporally-abstract actions.
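The tangent argument can be made explicit with a one-line calculation over the rate-distortion frontier R(D) traced by the sweep:

```latex
% Minimizing the variational objective D + \alpha R over achievable (D, R) pairs
% constrains the solution to lie on the frontier R(D); at the optimum,
\frac{\mathrm{d}}{\mathrm{d}D}\left[ D + \alpha\, R(D) \right] = 0
\quad \Longrightarrow \quad
R'(D) = -\frac{1}{\alpha}.
% Hence the optimum sits where the curve has tangent slope -1/alpha, and a slope
% discontinuity (the gap above) attracts the optimum for a whole range of alpha.
```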
Taken together, the results presented in this section provide strong evidence that our model can both learn temporally-abstract actions and how to sequence them appropriately, all in a self-supervised manner. We will see next how this model can be leveraged to speed up exploration in new, harder tasks by many orders of magnitude, enabling sparse-reward RL to succeed.
Figure 7 | A rate-distortion analysis reveals the importance of the controlled, pretrained autoregressive model being frozen for the discovery of temporally-abstract actions. We compare our standard metacontroller, which steers a frozen base model (left column; a, c), with a metacontroller that is co-trained with the base model it is steering (right column; b, d). The x-axis represents the action prediction loss (the distortion, or negative log-likelihood; NLL) and the y-axis represents the KL divergence to the prior (the rate). As the trade-off hyperparameter α in Eq. 3 is swept over to trace the rate-distortion curve, it reveals a range of values for which correct subgoal switching representations develop (marked in the figure) when the base model is frozen, but not for the co-training regime. This holds similarly for grid world (top row; a, b) and the ant environment (bottom row; c, d).

Finally, we consider the question of how to leverage our model to learn harder tasks through hierarchical RL. We study only the challenging sparse-reward setting, where a single positive success reward is provided per trajectory, and only when an entire sequence of subgoals is correctly completed.
We begin this section by establishing that our tasks (described in Fig. 2) are difficult for standard RL approaches to post-training. We first study an adapted version of the GRPO algorithm [3], which is a strong baseline in the sparse-reward setting. The details of our GRPO implementation can be found in Appendix C.5.2. For the tasks considered here, training an agent from scratch directly with RL has, for all practical purposes, no chance of succeeding. Thus, to make the comparison fair, we instead apply GRPO to the pretrained autoregressive sequence model, as is now routinely done with LLMs. However, even with a pretrained sequence model that has been trained on action sequences related to the subgoals, there is only a minuscule chance (on the order of one in a million) of producing successful trajectories by random sampling at the output token-level. This causes GRPO training to fail, as the model does not receive enough signal to learn, see Fig. 8. An inspection of the action sequences generated by the autoregressive sequence model reveals that while the model reproduces action sequences seen in the training data, it fails to explore at a higher level of temporal abstraction, which would be required to solve these sparse reward RL tasks. In other words, simply training the sequence model with policy gradients does not lead the system to explore novel combinations of subgoals.
Having shown that standard post-training RL fails, we now introduce internal RL. The key step in internal RL is to treat the autoregressive sequence model as part of the environment; actions then correspond to residual stream interventions, u_t, and observations correspond to residual stream activations, h_{t,l}. We note that performing RL at the residual stream level is a priori challenging. Consider the problem of learning from scratch a policy π(u_t | h_{1:t}) whose outputs u_t ∈ ℝ^{n_h} additively control the residual stream, h_{t,l} ← h_{t,l} + u_t, without relying on error backpropagation to differentiate through the base model that is being controlled. This is a high-dimensional continuous control problem, an exceedingly difficult setting for RL [47].
Instead of directly attempting to learn a residual stream control policy, internal RL consists of doing RL in the controller code space of z, after the metacontroller is trained in a self-supervised manner, as described in the previous section. This approach assumes that the metacontroller has learned a meaningful switching unit, and a controller code space such that z_t ∼ N(0, I) is a meaningful prior for sampling abstract actions. Intuitively, the metacontroller does not suffer from the drawbacks of directly doing RL in the residual stream for two reasons: (i) the action space dimension is reduced (n_z < n_h), and (ii) the metacontroller operates on an abstract timescale, dramatically reducing the time horizon for difficult environments. The latter is the key property that can enable internal RL to be more efficient and succeed on hierarchical, sparse reward tasks where standard RL methods fail.

Figure 8 | RL curves for various methods that leverage a pretrained autoregressive sequence model for (a) the discrete grid world environment, and (b) the ant continuous control environment. We compare our full-blown internal RL algorithm to a number of baselines: standard (raw action) RL finetuning; CompILE [17], a hierarchical RL method that also learns from unlabeled demonstrations, like ours; internal RL applied to a metacontroller that has been trained without a temporal integration unit (forced switching at every timestep, β_t = 1 for all t); and internal RL applied to a metacontroller that has been co-trained from scratch with an autoregressive action model, sidestepping the pretraining phase (see main text for more details). All baselines fail to learn within a million episodes. Lines and shaded areas report the median and the spread between the 25th and 75th quantiles, respectively, computed over 30 runs (3 metacontrollers trained for each of 10 pretrained models). We provide this figure in log-scale in Appendix Fig. A5 for a more detailed analysis of the failure modes of the baselines.
In more detail, internal RL consists in replacing the unsupervised controller encoder, which uses privileged future information e(h_{1:T}), by a causal abstract action policy π(z_t | h_{1:t}), and then training it through RL, while keeping all other modules and their parameters fixed. Conceptually, this amounts to subsuming the autoregressive model, as well as part of the metacontroller, into the environment (cf. Fig. 1). To generate discrete switching events, we further apply a threshold to binarize the switching rate, i.e., we replace β_t in Eq. 2 by H(β_t − β_threshold), with H the Heaviside step function and β_threshold ∈ ℝ a hyperparameter. This way, until a switch signal (β_t = 1) is emitted by the metacontroller, the same abstract action is applied, thus allowing π to operate on a temporally-abstract timescale. Pseudocode for the internal RL environment and algorithm is provided in Appendix C.5.1.
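The resulting interface can be sketched as follows. This is a schematic illustration rather than the pseudocode of Appendix C.5.1; all module interfaces (base_model, decoder, switch_unit, task_env) are hypothetical stand-ins for the frozen components described above.

```python
class InternalRLEnv:
    """Internal-RL interface: the frozen base model, controller decoder and
    switching unit are folded into the environment; the RL policy only emits
    a controller code z_t whenever the switching unit fires."""

    def __init__(self, base_model, decoder, switch_unit, task_env, beta_threshold=0.5):
        self.base_model = base_model          # frozen autoregressive action model
        self.decoder = decoder                # code z -> residual-stream controller U
        self.switch_unit = switch_unit        # emits beta_t from internal activations
        self.task_env = task_env              # the actual sparse-reward task
        self.beta_threshold = beta_threshold  # binarization threshold for beta_t

    def step(self, z_t):
        """Apply the abstract action z_t until a switch fires or the episode ends."""
        U = self.decoder(z_t)
        total_reward, done = 0.0, False
        while not done:
            h = self.base_model.residual_activation()  # observation for the internal policy
            a = self.base_model.act(controller=U)      # low-level action of the controlled model
            _, r, done, _ = self.task_env.step(a)      # advance the real environment
            total_reward += r
            if self.switch_unit(h) >= self.beta_threshold:  # H(beta_t - beta_threshold) = 1
                break                                       # hand control back to the policy
        return h, total_reward, done
```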
Fig. 8 shows that internal RL achieves a high success rate on the post-training task set. Leveraging the temporal abstractions discovered through self-supervised metacontroller learning is crucial for this success, as shown by the failure of a metacontroller for which the temporal integration unit is disabled (β_t = 1 for all t). To give this baseline a fair chance, this ablation is introduced during self-supervised metacontroller learning, not just when performing post-training RL. We note that the β_t = 1 ablation also achieves a high initial success rate; this can be seen when plotting success rates in log-scale (cf. Appendix Fig. A5). However, only our full-blown (temporally-abstract) internal RL both achieves high initial success rates and performs efficient credit assignment, such that RL succeeds. In Appendix E.2 we present a mathematical argument for the efficiency of credit assignment in internal RL, comparing the variance of the resulting policy gradients of internal RL against RL in raw action space.
Moreover, to evaluate the internal abstractions developed through autoregressive action modeling, we compare again to the co-trained baseline, where both metacontroller and base model are jointly optimized through the minimization of Eq. 3. Consistent with the rate-distortion analysis results (Fig. 7), the success rate of post-training internal RL remains close to zero. The same holds for CompILE [17], a comparable, previously proposed hierarchical RL method that also relies on variational inference to discover temporally-abstract actions from an unlabeled behavioral dataset. These results again confirm the importance of the initial autoregressive foundation model pretraining phase, followed by base model freezing, for enabling efficient hierarchical RL.
In this work, we asked whether the latent representations of autoregressive sequence models could be leveraged to develop RL techniques that overcome the inefficiency of token-by-token exploration and reinforcement. We studied this question using tasks that contain multiple subgoals that can be composed together to create the ultimate goal of the task. We first showed that an autoregressive sequence model trained on action and observation sequences from agents solving simpler versions of the tasks learns representations in its hidden layers that carry information about the subgoals. Next, we demonstrated that these latent representations in the sequence model can be used by a set of internal controllers, provided with the groundtruth subgoals, to solve more complex tasks by compositionally generalizing in time. We then developed a model that uses a metacontroller to select appropriate temporally-abstract actions without receiving the groundtruth subgoal labels. Finally, we showed that directly reinforcing the internal activation controllers generated by the metacontroller enables learning in more complex, hierarchical sparse-reward tasks where other RL techniques fail. Altogether, our results demonstrate that the latent representations of autoregressive sequence models can indeed be leveraged to enable efficient, hierarchical RL.
There is a long-running debate on whether autoregressive next-token predictors can form consistent temporal abstractions and plans [48], with some researchers dismissing them as “stochastic parrots” [49]. Our work adds a positive piece of evidence to this question. We chose to study a set of RL environments that fulfill a few key properties we associate with intelligent agents. For an agent to master these environments, it must be able to (i) recombine previous behaviors in novel meaningful ways, (ii) learn from sparse rewards, and (iii) overcome reward sparsity by leveraging imitation learning to infer and repurpose the goal-directed behaviors of other agents. Learning from sparse rewards is arguably the ultimate setting for reinforcement learning, encompassing problem domains ranging from mathematical reasoning and robotic manipulation to scientific discovery in their most ambitious forms. Solving such tasks without reliance on manual reward shaping is a critical step toward autonomous agents capable of navigating complex, open-ended search spaces where the definition of intermediate progress is often unknown.
Despite their simplicity, the environments are challenging enough for standard RL methods to fail, including GRPO (a recent but by now standard method for sparse-reward tasks), as well as CompILE, a previous hierarchical RL algorithm [17] that attempts to discover abstract actions from raw unlabeled data, instead of the internal representations of an autoregressive sequence model. The overwhelming success of internal RL over baseline RL algorithms reported here must still be taken with care, however, given the controlled nature of our experimental setup. Investigating and adapting internal RL to larger-scale models and tasks is an important direction of future work.
A number of prior analyses have probed the internal representations of autoregressive models, looking for temporal abstractions and plans. A recent exciting study provided compelling evidence for planning in LLMs asked to write rhyming poems [50], and earlier probing work found that hidden LLM states have some predictive power over a small number (four) of future tokens [51]. Another line of prior work has focused on models trained from scratch in controlled environments, as we do here, notably in games such as Othello [31,52] or chess [53,54]. To the best of our knowledge, we are the first to consider continuous environments with a hidden, discrete, hierarchical task structure. Despite being trained by gradient descent and only employing continuous units (both within the base SSM next-token predictor and the metacontroller), the models nonetheless discovered the underlying discrete latent task structure. In particular, the metacontroller developed sparse, quasi-binary switching units. Moreover, our findings complement recent analyses of convolutional LSTM policies trained by end-to-end RL to play the Sokoban game [55,56]. These studies showed that RL led to the development of planning subroutines that unfold over multiple timesteps, like the goal-reaching policies that we found within self-supervised autoregressive models. We complement these studies by focusing on autoregressive transformers and SSMs trained on a next-token prediction objective, the current workhorse of artificial intelligence systems.
Schmidhuber theorized in a seminal paper [57] that a wake-sleep training loop iterating between training a history compressor through self-supervised learning (SSL), and letting a controller use the internal representations of the former to generate new experiences through RL, would lead to the acquisition of ever more complex capabilities, including the ability to form and exploit temporal abstractions and plans. Here, we provide both a concrete neural architecture following this philosophy, and a set of experimental results backing these claims. Interestingly, we begin to see the benefits of alternating between SSL and RL in large-scale models. For instance, DeepSeek-R1 [3] training also involved one iteration of the RL-SSL cycle, albeit with additional human curation involved in the (post-RL) SSL phase, and with RL still done at (raw) output action level.
Our model also displays similarities to LeCun’s joint embedding predictive architecture [JEPA; 58]. In particular, the metacontroller introduced here is similar to the JEPA configurator module, as both are in charge of modulating a general world model and policy in service of a given goal or task. However, JEPA is a proposal for learning abstract observation and action representations without an autoregressive predictive model, whereas next-action prediction is precisely at the center of our approach. In fact, we show that learning a (raw) action predictor is partly what enables discovering how to decompose a task into a sequence of subgoals, one of the open problems in the JEPA proposal.
The overwhelming advantage of internal RL over standard RL finetuning reported in this paper deserves further investigation in real-world environments. A direction that seems particularly worthy of pursuing is LLM reasoning. There is growing interest in reasoning methods that leverage the internal representations of LLMs, mainly exploring recurrent iteration in neural activation space [e.g., 59-62]. The metacontroller model presented in our paper is complementary to these efforts, and may itself benefit from additional recurrence. Its key innovation, however, lies in the discovery of latent variables that compress time dynamically. This has the potential to cut the search space in a reasoning problem and thereby increase RL efficiency, as it did in a dramatic way in the problems considered here. A first step in this direction was taken by Kong et al. [63], who pretrained through variational methods a language model with a stochastic latent variable, and already saw promising results on reasoning benchmarks.
Finally, our results open a new avenue for model interpretability and control at scale. Similarly to sparse autoencoders (SAEs), a popular method for model interpretability and steering, the metacontrollers introduced in this work can be trained through scalable self-supervised learning and employ an encoder-decoder-type architecture. However, the two models otherwise have significant differences. While SAEs are trained on instantaneous internal activation reconstruction, metacontrollers are predictive and interventive, trained to directly lower output next-token prediction error by intervening on the residual stream. Moreover, they maintain internal state, whereas SAEs are instantaneous. Metacontrollers are thus by design likely better suited if the goal is foundation model control, and they offer the possibility of discovering interpretable interventions that run over an extended period of time. We are excited about the prospect of investigating whether these capabilities translate to larger-scale models such as LLMs.
Our grid world environment, referred to as gridworld-pinpad in the Appendix, is inspired by the previously proposed visual Pin Pad benchmark [64]. In our version, an agent is located in a grid world, together with uniquely colored cells (also referred to as objects). Within a task, the agent needs to step on a sequence of colored cells in a task-specific order.
• Task: A task is specified by a sequence of colored cells to visit.
• State: The world is a 2D grid of size G-by-G. There are N unique colored cells placed on the grid, as well as W walls. At any given moment, the agent occupies one of the G^2 − W cells that are not wall cells. Finally, the environment state also keeps track of which colored cells the agent has visited so far in the episode.
• Action: There are 4 actions corresponding to the 4 cardinal directions.
• Dynamics: Given the action and the agent position, the agent moves in the corresponding direction, except when it is moving towards a wall cell or outside of the grid, in which case the action results in a no-op. A colored cell is considered visited when the agent moves onto the cell from a different cell. If the agent successfully visits all colored cells in the right order, or if the agent visits a colored cell that is not the next cell specified by the task, or if the episode lasts longer than T steps, the episode ends (a minimal sketch of this transition rule is given after this list).
• Initial state: At the beginning of every episode, the colored cells and walls, as well as the initial agent position, are randomly sampled on the grid, ensuring there is no overlap.
• Observation: The agent's observation is the one-hot encoding of which object/wall is present in each cell, as well as the one-hot vector corresponding to the position of the agent, resulting in a G^2(N + W + 1)-dimensional vector.
• Reward: The agent gets a reward of 1 when successfully completing the task, and 0 otherwise.
For both pretraining and post-training tasks, we use G = 7, N = 8, W = 4, and T = 100.
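Below is a minimal, self-contained sketch of the transition rule described above; the grid encoding (0 = empty, -1 = wall, 1..N = colored cells) and the function signature are our own conventions for illustration.

```python
MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}  # the 4 cardinal actions

def gridworld_step(grid, pos, action, task, visited, t, T=100):
    """One transition of the gridworld-pinpad rules; `task` lists the colors to visit."""
    G = len(grid)
    dr, dc = MOVES[action]
    r, c = pos[0] + dr, pos[1] + dc
    if not (0 <= r < G and 0 <= c < G) or grid[r][c] == -1:
        r, c = pos                                   # wall or off-grid: no-op
    color = grid[r][c]
    reward, done = 0.0, False
    if color > 0 and (r, c) != tuple(pos):           # moved onto a colored cell
        if color == task[len(visited)]:              # it is the next required subgoal
            visited = visited + [color]
            if len(visited) == len(task):
                reward, done = 1.0, True             # whole task completed
        else:
            done = True                              # wrong colored cell ends the episode
    if t + 1 >= T:
        done = True                                  # time limit reached
    return (r, c), visited, reward, done
```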
Numbering the colors from 0 to 7, the list of pretraining tasks can be found in Table A1. In this setup, the abstract subgoals that are combined to form the compositional final tasks are given by 0-1, 2-3, 4-5, and 6-7.

Table A1 | List of pretraining tasks for gridworld-pinpad.
0-1-4-5-0-1  0-1-4-5-2-3  0-1-6-7-2-3  2-3-0-1-4-5
2-3-6-7-2-3  2-3-6-7-4-5  4-5-0-1-4-5  4-5-0-1-6-7
4-5-2-3-6-7  6-7-2-3-0-1  6-7-2-3-6-7  6-7-4-5-0-1
0-1-6-7-4-5  2-3-0-1-6-7  4-5-2-3-0-1  6-7-4-5-2-3

We choose the post-training task to be 0 -

Ant-pinpad is a continuous control counterpart of the aforementioned gridworld-pinpad. The agent controls the classic MuJoCo ant [20], with the goal of stepping on a sequence of colored cells in a task-specific order.
• Task: A task is specified by a sequence of colored cells to visit.
• State: The state is a 2D plane, divided into grid cells. The grid is organized identically to that of gridworld-pinpad, and also includes colored cells and walls. The state is further augmented by the proprioceptive state of the ant, as well as the precise coordinate of the center of the ant in the grid. Finally, the environment state also keeps track of which colored cells the agent has visited so far in the episode.
• Action: The action is an 8-dimensional continuous vector representing the torques applied to the ant's eight joints.
• Dynamics: Given the action, the ant moves on the 2D plane as usual. When the center of the ant enters a wall cell, or whenever the vertical position of the ant's torso falls outside the valid operational range of [0.2, 1.0], the episode is instantly terminated. A colored cell is considered visited when the ant enters the cell from a different cell. If the agent successfully visits all colored cells in the right order, or if the agent visits a colored cell that is not the next cell specified by the task, or when the episode lasts longer than T timesteps, the episode ends.
• Initial state: At the beginning of every episode, the colored cells and walls, as well as the initial agent position, are randomly sampled on the grid, ensuring there is no overlap. We initialize the agent's full MuJoCo state by first setting the torso's x, y position in the plane to the center of the sampled grid cell. Then we add uniform noise that positions the agent in the simulation anywhere within the boundaries of the initial grid cell. We furthermore sample a random yaw rotation and turn the agent correspondingly. Finally, the initial angles for all joints and initial velocities are sampled uniformly at random within a small range of 0.1 units around zero.
• Observation: The observation consists of the usual proprioceptive senses of the ant (to which the symlog function was applied, to ensure no excessively large values occur), concatenated with the global x, y ant coordinate (normalized to be between -1 and 1), as well as the relative positions of the various colored cells and walls w.r.t. the ant, and the local coordinate of the ant within the current cell.
• Reward: The agent gets a reward of 1 when the task is successfully completed, and 0 otherwise.

For both pretraining and post-training tasks, we use G = 4, N = 4, W = 1, and T = 500. The set of pretraining tasks can be found in Table A2. We choose the post-training task to be 0-1-2-3.
Fig. A1 displays the performance of linear probes predicting latent subgoals from the residual stream activations of a pretrained and subsequently frozen transformer in the gridworld-pinpad environment. These linear probes are obtained by following the procedure detailed in Section C.2. Importantly, these subgoals are not explicitly encoded in the data forced through the sequence model. Nonetheless, throughout training on a large corpus of unannotated goal-directed behavior, the sequence model develops internal representations of the subgoals. These internal representations become linearly decodable deep in the network (with the accuracy jumping from 30% to about 50% roughly in the middle of the model). Interestingly, close to the output layer of the model (layer 6 in Fig. A1), the performance of linear probes deteriorates when plugged into a backbone trained beyond 100K steps.
In this section, we investigate the effect of various hyperparameter choices during sequence model training on the internal abstract action representation of the base autoregressive model. For all experiments, we use the gridworld-pinpad environment. We measure the quality of the abstract action representation following the procedure outlined in Section C.3, and by evaluating the compositional generalization of the obtained controllers on post-training tasks. For all experiments, we use the same hyperparameters as detailed in Section C.3, unless specified otherwise. The results are presented in Fig. A2.
Sequence model training steps. For all base autoregressive model depths (4, 6, and 8), we notice that longer sequence model training generally leads to better internal abstract action representation, such that the controllers generalize better to the post-training task set.
Weight decay. For all base autoregressive model depths, we notice that weight decay during sequence model training is beneficial for the internal representation. Interestingly, too much weight decay also degrades the representation, which points to a critical regularization trade-off that has been previously reported in foundation models [65].
Observation auxiliary loss. Next, we observe that some amount of auxiliary loss (i.e. training to predict the next observation as well as the action) is beneficial to building internal abstract action representations. With a very low coefficient for the auxiliary loss, we noticed that some models completely failed to learn the representation; however, we suspect this behavior is an artifact of our particular environment rather than a general trend.
Expert suboptimality. Finally, we investigate the effect of the suboptimality of the demonstrations used during pretraining on the resulting abstract action representation. We achieve this by replacing the expert policy with an ε-noisy one, where at every timestep, with probability ε, a random (non-terminating) action is taken. We see that the abstract action representation is robust against such suboptimality.
In Fig. A3, we analyze the temporal abstraction discovered in the gridworld-pinpad setting (cf. Fig. 6 for the respective ant-pinpad results), by plotting the switching gate values β_t against the groundtruth abstract actions c_t. Similarly to the ant-pinpad setting, we find that the metacontroller essentially recovers the groundtruth abstract actions, with the switch gate learning to behave in a quasi-binary fashion.
Figures A3 and 6 reveal that the temporal abstractions discovered by the metacontroller during self-supervised learning reflect the ground truth structure of the underlying task. In particular, the switching unit aligns with the compositional abstract subgoals governing the observed data in a quasi-discrete fashion.
In this section, we focus instead on the controller latent code z, and provide evidence that the latent space encodes the actual subgoal-seeking abstract actions that constitute the compositional task, in a context-agnostic manner. To achieve this, we focus on the ant-pinpad environment, and follow this procedure:
• For a handful of grid configurations, we first perform an unconditioned rollout, i.e. a rollout in the environment using the sequence model and the trained metacontroller while sampling z from the Gaussian prior, instead of the variational distribution.
• Next, for each object, we consider unconditioned rollout trajectories that correspond to the agent visiting that object (and nothing else), and collect the latent codes z that were active at the time of the visit. We hypothesize that these latent codes encode the subgoal-seeking abstract action towards the corresponding object.
• Finally, we use those latent codes in different scenarios, and demonstrate that the same latent code's subgoal-seeking property generalizes to other situations.
Here, we investigate the ability of the latent code to generalize to new grid configurations and unseen switching times. The metacontroller is trained on successful, nearly-optimal trajectories where agents rarely demonstrate "backtracking", i.e., behavior where an agent turns away from one object to seek another. Consequently, it is non-trivial whether a latent code injected mid-rollout can override the base model's current trajectory. As shown in Fig. A4, injecting a "go to blue" latent code at timestep 30 causes the agent to immediately correct its course, even if it was previously moving toward a different object. This intervention increases the goal-reaching success rate from 23% in the uncontrolled baseline to 36%. This is significant, considering that the latent codes were generated for different configurations and are the result of noisy sampling.
We further test whether these codes can force behavior that is explicitly absent from the training data. In the ant-pinpad environment, the agent is never trained to seek object 1 immediately after object 0 (cf. Appendix A). By manually activating the latent code for object 1 after the agent reaches object 0, we find that the success rate for this OOD transition rises from 10% (baseline) to 24%. Note that this also tests whether the same latent codes can generalize to a new position in the sequence, since they were collected from trajectories where the ant visited the corresponding object as the first object.
Ultimately, these results indicate that the metacontroller does not merely learn to segment time, but successfully discovers a compact, steering-capable representation of functional intent, providing the necessary 'options' for internal RL to perform efficient credit assignment in complex, hierarchical tasks.
We complement the main text Fig. 8 by showing the same plot in log scale in Fig. A5. First, we notice that our internal RL method achieves the highest success rate at the beginning of training, indicating that noise injection in the residual stream is useful for exploration, compared to exploration done by sampling raw actions alone. At the same time, despite their high initial success rates, the baselines completely fail to exploit this experience to reinforce their successes. This indicates that proper temporal abstraction is beneficial not only for exploration, but also for exploitation.
Here, we describe the details of training the sequence models (cf. Appendix D.1) as well as several variations of the controller (cf. Appendix D.2) acting on the then fixed sequence models to produce the results demonstrated in this paper. Moreover, in the final subsection we report the hyperparameters for each experiment.
As a prerequisite for the main experiments, which all involve tampering with the activations of a pretrained model, the base sequence models p_θ are first trained to autoregressively predict the next actions a_t and next observations o_{t+1} given a sequence of observations o_{1:t}, on a set of meaningful expert trajectories. The details of the expert trajectory generation and sequence model training are given in the following.
Given an environment and a set of pretraining tasks, expert trajectories are a set of corresponding successful trajectories. The trajectories are not necessarily optimal.
For the gridworld-pinpad environment, we analytically solve (via dynamic programming) for the stochastic policy that solves the shortest path problem, and, at every timestep, replace the action by a random (non-terminating) action with probability ε. For all results, unless explicitly specified, we chose ε = 0, but we show the robustness against such noise in Section B.2.
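A sketch of such an expert is given below, using breadth-first search as a stand-in for the dynamic-programming solution and ignoring, for brevity, the constraint that the random action must be non-terminating; names and the grid encoding follow the gridworld sketch above.

```python
import numpy as np
from collections import deque

MOVES = [(-1, 0), (1, 0), (0, -1), (0, 1)]

def distances_to(grid, goal):
    """Shortest-path distance from every free cell to the goal cell (walls = -1)."""
    G = len(grid)
    dist = np.full((G, G), np.inf)
    dist[goal] = 0
    queue = deque([goal])
    while queue:
        r, c = queue.popleft()
        for dr, dc in MOVES:
            nr, nc = r + dr, c + dc
            if 0 <= nr < G and 0 <= nc < G and grid[nr][nc] != -1 and np.isinf(dist[nr, nc]):
                dist[nr, nc] = dist[r, c] + 1
                queue.append((nr, nc))
    return dist

def expert_action(dist, pos, eps, rng):
    """Greedy step along the distance field, replaced by a random action w.p. eps."""
    if rng.random() < eps:
        return int(rng.integers(4))
    costs = []
    for dr, dc in MOVES:
        nr, nc = pos[0] + dr, pos[1] + dc
        inside = 0 <= nr < dist.shape[0] and 0 <= nc < dist.shape[1]
        costs.append(dist[nr, nc] if inside else np.inf)
    return int(np.argmin(costs))
```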
For the ant-pinpad environment, we obtain the expert trajectories by training an RL agent. In order to train a single agent for the different tasks, we augment the observation in the following way: for each cell of the grid, the agent is given an additional 4-dimensional one-hot vector, indicating which of the 4 cardinal directions the agent must move towards to follow the shortest path at the grid level. Furthermore, an additional intrinsic reward corresponding to the dot product between the agent's velocity and this direction is given. The agent is trained with PPO [66]; see Table A3 for the hyperparameters used.
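The shaping term can be sketched as the following function; the ordering of the cardinal directions and the argument names are assumptions made for illustration.

```python
import numpy as np

# Assumed ordering of the cardinal directions matching the 4-dimensional one-hot.
DIRECTIONS = np.array([[0.0, 1.0], [0.0, -1.0], [-1.0, 0.0], [1.0, 0.0]])

def intrinsic_reward(ant_velocity_xy, direction_onehot):
    """Dot product between the ant's planar velocity and the cell's shortest-path direction."""
    d = DIRECTIONS[int(np.argmax(direction_onehot))]
    return float(np.dot(ant_velocity_xy, d))
```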
Note that the resulting expert trajectories are not always successful. The success rate is 0.8 for the pretraining tasks and 0.7 for the post-training tasks. The success rate in Figure 8
Given a dataset $D$ of expert trajectories, the sequence models are trained to maximize the log-likelihood of the data,
$$\max_\theta \; \mathbb{E}_{(o_{1:T+1},\, a_{1:T}) \sim D} \sum_{t=1}^{T} \log p_\theta\!\left(a_t, o_{t+1} \mid o_{1:t}\right). \tag{4}$$
Switching the sign and reweighting the observation component with a coefficient $\lambda$ yields the loss function presented in the main text (Equation 5).
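A minimal sketch of this reweighted loss is shown below. The categorical action head and the Gaussian (MSE) observation term are assumptions made here for illustration; the shapes and names are ours.

```python
import torch
import torch.nn.functional as F

def sequence_model_loss(action_logits, obs_pred, actions, next_obs, lam=1.0):
    """Sketch of the pretraining loss: next-action negative log-likelihood plus a
    lambda-weighted next-observation prediction term.
    Assumed shapes: action_logits [B, T, n_actions], obs_pred / next_obs [B, T, d_obs],
    actions [B, T] (integer-coded)."""
    action_nll = F.cross_entropy(
        action_logits.reshape(-1, action_logits.shape[-1]),
        actions.reshape(-1),
    )
    # MSE is proportional to a fixed-variance Gaussian NLL (an assumption here).
    obs_nll = F.mse_loss(obs_pred, next_obs)
    return action_nll + lam * obs_nll
```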
Table A4 summarizes the hyperparameter choices for training sequence models on the discrete gridworld and Table A5 those for the ant. For the sequence models, we use the hyperparameters specified in Table A15 for SSMs and those specified in Table A16 for transformers. We use SSMs for ant-pinpad and transformers for gridworld-pinpad.
For both environments, we pretrain 10 such sequence models with different seeds.
Figure A4: Here, we illustrate the effect of a latent controller code implementing the abstract action “go to blue” in ant-pinpad, when forcing a switch at an arbitrary time. The three pairs display a trajectory without intervention by the metacontroller (left) vs. one with the metacontroller running on a latent code corresponding to “go to blue” (right), respectively. The same controller latent code successfully steers the ant towards the desired color in different contexts, and regardless of the timing at which it is activated. Some trajectories demonstrate backtracking behavior when the control is applied.

Given a pretrained sequence model $p_\theta$ optimized to maximize Equation 5, we train linear probes to predict the latent subgoals governing a sequence at hand. More formally, given a sensori-action sequence $(o_{1:T+1}, a_{1:T})$, we train a linear probe $W_l \in \mathbb{R}^{n_h \times n_g}$ to predict the latent subgoal $g_t$ from the residual activation $h_{t,l}$ at layer $l \in \{0, \dots, L\}$ at every timestep $t$. Here, $n_h$ and $n_g$ denote the residual stream dimension and the total number of subgoals in the dataset, respectively, and $L$ is the number of layers of $p_\theta$. The belief distribution over the subgoals at timestep $t$ is parameterized as a softmax over the probe outputs, $q(g_t \mid h_{t,l}) = \operatorname{softmax}\!\left(W_l^\top h_{t,l}\right)$.
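For concreteness, a minimal sketch of fitting such a linear probe with a cross-entropy objective (described in the next paragraph) is given below; the variable names, optimizer, and hyperparameters are illustrative, not the paper's exact setup.

```python
import torch
import torch.nn.functional as F

def train_linear_probe(residual_acts, subgoal_labels, n_subgoals,
                       lr=1e-3, n_steps=1000):
    """Fit a linear probe on residual stream activations at a fixed layer l.
    residual_acts: [N, n_h] activations, subgoal_labels: [N] integer subgoal ids."""
    n_h = residual_acts.shape[-1]
    W = torch.zeros(n_subgoals, n_h, requires_grad=True)
    opt = torch.optim.Adam([W], lr=lr)
    for _ in range(n_steps):
        logits = residual_acts @ W.T            # pre-softmax belief over subgoals
        loss = F.cross_entropy(logits, subgoal_labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return W
```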
The parameters $W_l$ are trained to minimize the cross-entropy loss between the predicted belief and the ground-truth subgoal.

This setting differs from the metacontroller in how the latent codes $z_t$ are computed. Instead of sampling $\tilde{z}_t$ from a normal distribution and then temporally integrating, $\beta_t = 1$ is forced for all $t$. Moreover, in these experiments, the ground-truth information about the abstract behaviour is injected via $\tilde{z}_t$. In particular, the expert trajectories $(o_{1:T+1}, a_{1:T})$
We adapted CompILE [17,18] as closely as possible to our setting. At a high level, CompILE is very similar to our cotraining baselines: it is a latent variable model (albeit with a different set of latent variables) which takes a sequence of observations and outputs, for each timestep, a continuous latent variable $z$ drawn from a Gaussian, which then conditions a policy trained to imitate the actions in the trajectory. Like our method, it is a variational inference approach to discovering abstract actions, except that it does not leverage the internal representations of a pretrained model. CompILE also infers the switching latent variables $\beta$, and requires a prior distribution over the switching rate as well as the maximum number of abstract actions (or segments) per sequence.
To make things comparable, we adapt CompILE to our architecture by drop-in replacing the metacontroller with the CompILE module, which generates the latent code $z$, while keeping everything else identical. In particular, the same sequence model architecture is used and internally controlled by $z$. Nevertheless, to remain close to the original CompILE architecture, the module generating the latent code takes the raw input, rather than the residual stream activation, as input. To compensate for the potential loss in expressivity, we use as the module's recurrent encoder the same architecture as the first half of the sequence model. The parameters of the latent variable distributions are then generated by one-hidden-layer MLPs. Due to the difficulty of performing parallel inference with this algorithm, the sequence model parameters $\theta$ are initialized to their pretrained values, so that training can be shortened.
Table A11 (resp. A12) shows the hyperparameters used for gridworld (resp. ant-pinpad).
For all RL experiments, we use an RL algorithm suited to the sparse, single-final-reward setting. The algorithm is related to GRPO, except for the notion of group, which is absent in our setting. As in GRPO, we modify the standard Proximal Policy Optimization [PPO; 66] framework by replacing the learned value function (critic) with an empirical advantage estimate. The objective is the clipped surrogate
$$\mathcal{L}(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta_{\text{old}}}}\!\left[ \min\!\left( r_\theta(\tau)\, A_\tau,\; \operatorname{clip}\!\big(r_\theta(\tau),\, 1-\epsilon,\, 1+\epsilon\big)\, A_\tau \right) \right], \qquad r_\theta(\tau) = \frac{\pi_\theta(\tau)}{\pi_{\theta_{\text{old}}}(\tau)}, \tag{10}$$
where $\pi_\theta$ is the current policy, $\pi_{\theta_{\text{old}}}$ is the previous policy, and $A_\tau$ is the relative advantage of the trajectory $\tau$.
Relative advantage estimation. We adopt a critic-free approach to estimating the advantage. The relative advantage $A_\tau$ measures how much better (or worse) a specific trajectory $\tau$ is compared to the average quality of the entire batch of size $B$. It is calculated by normalizing the reward $R(\tau)$ relative to the batch mean $\bar{R}$ and standard deviation $\sigma_R$:
$$\bar{R} = \frac{1}{B} \sum_{i=1}^{B} R(\tau_i), \qquad \sigma_R = \sqrt{\frac{1}{B} \sum_{i=1}^{B} \big(R(\tau_i) - \bar{R}\big)^2}, \qquad A_\tau = \frac{R(\tau) - \bar{R}}{\sigma_R + \delta},$$
where $\delta$ is a small constant (e.g., $10^{-3}$) to ensure numerical stability and prevent division by zero.
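The following sketch illustrates the two components above: the batch-normalized, critic-free advantage and a clipped surrogate of the form of Equation 10. The per-trajectory granularity of the ratio and the variable names are assumptions made here for illustration.

```python
import torch

def relative_advantages(rewards, delta=1e-3):
    """Critic-free, batch-normalized advantages. rewards: [B] final trajectory rewards."""
    mean, std = rewards.mean(), rewards.std(unbiased=False)
    return (rewards - mean) / (std + delta)

def clipped_surrogate(logp_new, logp_old, advantages, eps=0.2):
    """PPO-style clipped surrogate with trajectory-level advantages (illustrative)."""
    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    # Return the negative objective so it can be minimized with a standard optimizer.
    return -torch.min(unclipped, clipped).mean()
```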
For the raw-action RL baseline, we simply use the pretrained sequence model without any metacontroller, and finetune it on the post-training task in raw action space using the RL algorithm described in Section C.5.2. Since this baseline has no unsupervised abstract action discovery phase, we instead run the RL over 3 seeds.
For all other baselines, we simply perform internal RL with the respective metacontrollers obtained during the unsupervised abstract action discovery, cf. Section C.4.
For each method, we scan over learning rates (0.000003, 0.00001, 0.00003, 0.0001, 0.0003) and pick the learning rate and the hyperparameter configuration from the unsupervised abstract action discovery with the best median RL performance over the 10 pretrained models and 3 seeds.
The rate-distortion curve is plotted after performing the unsupervised abstract action discovery with the same hyperparameters as described in Section C.4, except that, for gridworld-pinpad, we use the Gumbel-Sigmoid trick for the switching units, as derived in E.1, since it resulted in a cleaner Pareto frontier. Since each pretrained sequence model has a different Pareto frontier, which would obscure the overall structure, we pick one sequence model at random and instead run the unsupervised abstract action training with 10 seeds for each KL strength. We do this for our method as well as for the metacontroller cotraining baseline.
The sum of switching values is computed by hard-thresholding the continuous switching value with $\beta_{\text{threshold}} = 0.5$, as we do in the internal RL (cf. Algorithm 1), and taking the average sum over trajectories. We manually checked the forget patterns to label whether the switching patterns aligned with the subgoal changes.
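A minimal sketch of this metric is given below; the function name and input format are ours, for illustration only.

```python
import numpy as np

def switching_sum(beta_trajectories, beta_threshold=0.5):
    """Hard-threshold the continuous switching values and average the number of
    switches per trajectory. beta_trajectories: list of [T] arrays of beta values."""
    counts = [np.sum(np.asarray(beta) > beta_threshold) for beta in beta_trajectories]
    return float(np.mean(counts))
```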
We parametrize the base model as an autoregressively-trained multi-layer sequence model $p_\theta$. The specific instantiations of $p_\theta$ detailed below use either standard transformer [1] or recurrent neural network (also commonly referred to as state-space model, SSM) layers. From the latter family ([22, 67-73], and others), we choose Hawk [22] due to its simplicity and computational efficiency.
For SSM-based sequence models, we employ a standard pre-normalization layer architecture. Inputs are normalized before being fed into the recurrent Hawk sequence mixing block [22], whose output is added back to the residual stream. This is followed by an MLP channel-mixing block that similarly applies normalization to its input before adding its output back to the residual stream.
For transformer-based models, we employ a standard pre-normalization layer architecture. We first compute relative position embeddings to serve as attention biases. Inputs are then normalized and fed into the Multi-Head Attention sequence mixing block (incorporating these biases), whose output is added back to the residual stream. This is followed by an MLP channel-mixing block that applies normalization to its input before adding its output back to the residual stream.
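The block structure shared by both variants is illustrated in the following sketch. Multi-head self-attention stands in for the sequence mixer here; in the SSM variant it would be replaced by a Hawk recurrent block, which we do not reproduce. Module names and widths are ours, for illustration.

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Pre-normalization residual block: a sequence-mixing module followed by an MLP
    channel-mixing module, each normalizing its input and adding back to the residual
    stream (illustrative sketch, not the paper's exact implementation)."""

    def __init__(self, d_model, n_heads=4, d_mlp=None):
        super().__init__()
        d_mlp = d_mlp or 4 * d_model
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_mlp), nn.GELU(), nn.Linear(d_mlp, d_model))

    def forward(self, h, attn_bias=None):
        # Sequence mixing: normalize, mix across time (with optional additive
        # attention bias, e.g., relative position embeddings), add to residual stream.
        x = self.norm1(h)
        mixed, _ = self.attn(x, x, x, attn_mask=attn_bias, need_weights=False)
        h = h + mixed
        # Channel mixing: normalize, MLP, add back to residual stream.
        h = h + self.mlp(self.norm2(h))
        return h
```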
The latent code proposals are sampled as $\tilde{z}_t \sim \mathcal{N}(\mu_t, \Sigma_t)$, where $\Sigma_t$ is chosen to be diagonal for computational efficiency. The parameters for mean and variance are produced by the controller encoder, $\mu_t, \Sigma_t = f_{\text{enc}}\big(h_{t,l},\, \hbar_{t-1},\, h_{1:T,l}\big)$.
Crucially, if this were the final parameterization of the approximate posterior $q(z_t \mid h_{1:T,l})$, it would not provide a handle on the factorization of the subgoals the agent composed in time when generating its behaviour. As a first step towards obtaining such a factorization, the metacontroller implements a switching unit producing the temporal integration rate $\beta_t \in [0, 1]$.
This integration rate is passed to the temporal integration unit, which uses it to combine the latent code proposals $\tilde{z}_t$ sparsely in time. In particular, given $\beta_t$, $\tilde{z}_t$, and $z_{t-1}$, the updated latent code is given by the convex combination $z_t = \beta_t\, \tilde{z}_t + (1 - \beta_t)\, z_{t-1}$.
Observe that, since the $\tilde{z}_t$ are stochastically generated, so is $z_t$. Moreover, the $\beta_t$, which rely only on causal information and hence can be generated at test time, provide a direct handle on the subgoals. When $\beta_t \to 1$, a new subgoal $\tilde{z}_t$ takes over, while $\beta_t \to 0$ indicates that the previous subgoal $z_{t-1}$ remains a valid explanation for the intent of the agent. This latent controller code $z_t$ is then sent through the controller decoder. The controller decoder is a hypernetwork emitting the internal controller.
As detailed above, this linear controller is applied to the residual stream to control the backbone, thereby affecting the predicted data log-likelihood computed at the output of the sequence model. Crucially, the described mechanism allows the metacontroller to act on extended timescales by maintaining the latent code $z_s$ computed at timestep $s$ for some $k$ timesteps (by setting $\beta_{s+1:s+k-1} = 0$). Since the computation of the hypernetwork is deterministic, the same instantaneous controller $C_s$ can then be applied for $k$ timesteps and corresponds to a temporally abstract action.
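The single-timestep sketch below illustrates the mechanism: a switching rate gates the proposal into the integrated code, and a hypernetwork maps the code to a linear controller acting on the residual stream. The module choices (a linear switching unit, an affine controller applied additively to the activation) and all dimensions are assumptions made here for illustration.

```python
import torch
import torch.nn as nn

class MetacontrollerStep(nn.Module):
    """One timestep of switching, temporal integration, and internal control
    (illustrative sketch under our own parameterization assumptions)."""

    def __init__(self, d_resid, d_latent):
        super().__init__()
        self.switch = nn.Linear(d_resid, 1)                             # beta_t from causal info
        self.hyper = nn.Linear(d_latent, d_resid * d_resid + d_resid)   # emits (W_t, b_t)
        self.d_resid = d_resid

    def forward(self, h_t, z_prop, z_prev):
        # Switching unit: temporal integration rate in [0, 1].
        beta_t = torch.sigmoid(self.switch(h_t))
        # Temporal integration: convex combination of proposal and previous code.
        z_t = beta_t * z_prop + (1.0 - beta_t) * z_prev
        # Hypernetwork: deterministic map from z_t to a linear internal controller.
        params = self.hyper(z_t)
        W = params[..., : self.d_resid ** 2].reshape(-1, self.d_resid, self.d_resid)
        b = params[..., self.d_resid ** 2:]
        # Apply the controller to the residual stream activation (additively here;
        # the exact form of the controller application is an assumption).
        h_controlled = h_t + torch.einsum("bij,bj->bi", W, h_t) + b
        return h_controlled, z_t, beta_t
```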
Since the residual activation $h_t$ for a single layer does not necessarily contain all information about the raw input history, we use a recurrent policy. A simple 1-layer SSM, as described in Section D.1.1, is used. See Tables A13 and A14 for more details on the architecture.
In order to improve stability during training, we use a continuous relaxation of the sampled latent variable $\beta$. In principle, this can be done with the Gumbel-Sigmoid trick, but in our experiments we simply use the probability itself as the latent variable. We modify the prior and variational distributions on $z$ to the corresponding continuous relaxation. This recovers the previous behavior when $\beta_t$ equals 0 or 1.
In the continuous case, the KL divergence can be computed in closed form.

Further assumptions. In our experiments, we further modify the variational distribution on $\beta$ such that $q(\beta_t \mid h_{1:T,l}) \approx q(\beta_t \mid h_{1:t,l})$, i.e., we drop the conditioning on the future. This is done so that during internal RL the switching signal can be emitted causally, and it eliminates the prior-matching term for the switching module. We made this assumption because, in the environments considered, we expected the residual activation to be highly informative of when switches should occur. In general, however, we can relax this assumption by keeping the future conditioning but distilling the switches into an unconditioned module.
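For reference, the sketch below shows a standard Gumbel-Sigmoid (binary Concrete) sample of the kind mentioned above, i.e., the alternative to simply using the sigmoid probability; the temperature and straight-through option are our own illustrative choices.

```python
import torch

def gumbel_sigmoid(logits, tau=1.0, hard=False):
    """Differentiable, approximately-binary sample of a switching unit."""
    u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
    logistic_noise = torch.log(u) - torch.log1p(-u)
    beta_soft = torch.sigmoid((logits + logistic_noise) / tau)
    if hard:
        # Straight-through estimator: discrete forward pass, soft gradient backward.
        beta_hard = (beta_soft > 0.5).float()
        return beta_hard + (beta_soft - beta_soft.detach())
    return beta_soft
```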
As explained in Section C.5.1, internal RL learns a policy over the discovered abstract action space of $z$ by treating the rest of the architecture as part of the environment and applying reinforcement learning directly to $z$, with temporal abstraction. However, there are other ways to use the discovered abstract actions than the proposed internal RL. A perhaps more straightforward way to use the metacontroller is to treat this policy as a noise-injecting submodule of the overall architecture, which is still trained by reinforcement learning in raw action space by backpropagating through the base autoregressive model to the policy, using, e.g., the reparametrization trick. In this section, we analytically contrast these two options, discuss their respective advantages, and motivate why we believe internal RL is interesting in general.
To simplify the analyses, we make a few assumptions:
• We remain in the outcome-supervision setting: a single reward $R_T$ is provided at the last timestep $T$.
• The switching happens $N$ times, at $(t_i)_{1 \le i \le N}$.
• The abstract action policy has a fixed variance, i.e., it outputs $z_t = \mu(h_t) + \epsilon_t$, where $h_t$ is the history of observations up to $t$ and $\epsilon_t \sim \mathcal{N}(0, 1)$.
We now contrast the policy gradient update of the abstract action policy between the two options discussed above.
Raw action space RL. Performing RL in raw action space, treating the abstract action policy as a model layer, would result in the following expected policy gradient update:
where $\mathrm{PG}_z(z) = R_T\, \epsilon_0$ is the policy gradient Monte Carlo estimator of a bandit problem. We see that the two expressions differ only in $\mathrm{PG}(z)$. The tradeoffs are evident:
• The expectation of $\mathrm{PG}_{\text{raw}}$ is more structured than that of $\mathrm{PG}_z$. In particular, its variance w.r.t. $\epsilon$ could even be 0 in the first case, whereas it scales with the dimension of $\epsilon$ in the second.
• However, the variance of $\mathrm{PG}_{\text{raw}}$ scales with the number of timesteps and with the raw action space dimension, since noise is accumulated at every timestep. On the other hand, the variance of $\mathrm{PG}_z$ does not scale with either: it is the variance of the return, i.e., $O(1)$. Therefore, if the abstract action discovery was successful, such that a compact space of $z$ with long-horizon abstract actions was identified, the variance of the policy gradient estimator, and hence credit assignment, can be dramatically improved, especially for very long-horizon tasks.
Design principles. The metacontroller is designed to act inside a frozen, autoregressive sequence model backbone. It does so by modulating the residual stream activations at some backbone layer via simple, internal controllers. Manipulating the residual stream allows the metacontroller to implement temporally abstract actions that turn the latent codes
E.1. Graphical model and ELBO derivation