JADAI: Jointly Amortizing Adaptive Design and Bayesian Inference

Reading time: 42 minutes

📝 Original Info

  • Title: JADAI: Jointly Amortizing Adaptive Design and Bayesian Inference
  • ArXiv ID: 2512.22999
  • Date: 2025-12-28
  • Authors: Niels Bracher, Lars Kühmichel, Desi R. Ivanova, Xavier Intes, Paul-Christian Bürkner, Stefan T. Radev

📝 Abstract

We consider problems of parameter estimation where design variables can be actively optimized to maximize information gain. To this end, we introduce JADAI, a framework that jointly amortizes Bayesian adaptive design and inference by training a policy, a history network, and an inference network end-to-end. The networks minimize a generic loss that aggregates incremental reductions in posterior error along experimental sequences. Inference networks are instantiated with diffusion-based posterior estimators that can approximate high-dimensional and multimodal posteriors at every experimental step. Across standard adaptive design benchmarks, JADAI achieves superior or competitive performance.

📄 Full Content

Many scientific and engineering questions concern unobserved properties of real-world systems: cosmological parameters governing large-scale structure (Hahn et al., 2024), biophysical parameters in mechanistic neural models (Gonçalves et al., 2020), or epidemiological parameters driving disease dynamics (Radev et al., 2021). In many such settings, we can design simulators that, given hypothesized parameters θ, can generate synthetic observations x. However, inverting these simulators to recover the parameters from observations is oftentimes a computational ordeal.

Simulation-based inference (SBI) addresses this inverse problem by approximating the posterior over parameters from simulated pairs (θ, x), typically using a neural network (Cranmer et al., 2020; Zammit-Mangion et al., 2025; Deistler et al., 2025). Modern SBI employs generative models, such as flow-matching (Wildberger et al., 2023), diffusion (Sharrock et al., 2022), or consistency models (Schmitt et al., 2024), to represent complex, multimodal posteriors. In many of these applications, however, we are not only handed data, but can also actively control how data is collected or how the system is perturbed. Experimental protocols, stimulus sequences, or public health interventions are often parameterized by design variables ξ that strongly influence how informative the resulting observations x are about θ. This experimental adaptability turns the inference problem into a two-fold question: (i) how to infer θ from a given dataset, and (ii) how to choose designs ξ that make inferences as informative as possible.

Bayesian experimental design (BED) formalizes an answer to the second question by selecting designs that maximize an expected utility, most commonly the expected information gain (EIG) about θ (Lindley, 1956). In Bayesian adaptive design (BAD), this becomes a sequential problem in which a designer (e.g., a neural network) proposes designs based on previous decisions and data acquisitions. This idea underlies recent works on deep Bayesian adaptive design, which learn global history-dependent policies directly from simulations (Foster et al., 2021;Ivanova et al., 2021;Blau et al., 2022;Huang et al., 2025), rather than solving a new optimization problem from scratch for every experiment. Despite this progress, BAD and SBI have largely evolved in parallel. Policy-based BAD approaches typically focus on proposing good designs, delegating posterior inference to slow (non-amortized) methods or restricting it to relatively simple parametric families in low-dimensional settings. Conversely, SBI methods usually work with fixed designs and do not directly optimize how observations are acquired. As a result, design and inference are typically treated as separate tasks (see Section 4 for related work).

In this paper, we introduce JADAI, a framework for Jointly Adaptive Design and Amortized Bayesian Inference. Our framework makes the following contributions:

• It jointly amortizes adaptive experimental design and posterior inference via a general utility that aggregates incremental improvements end-to-end.

• It incorporates modern diffusion-based generative models, providing amortized high-dimensional and multimodal posteriors at every experimental step.

• It achieves superior or competitive performance across a range of adaptive design benchmarks.

For the purpose of our discussion, a simulator generates observables x ∈ X as a function of unknown parameters θ ∈ Θ, design variables ξ ∈ Ξ, and random program states z ∈ Z:

x = Sim(θ, ξ, z).    (1)

The forward problem in (1) is typically well-understood through a mathematical model. The inverse problem, however, is typically much harder, and forms the crux of Bayesian inference: for a given design ξ, estimate the unknowns θ from observables x via the posterior distribution

p(θ | x; ξ) ∝ p(x | θ, ξ) p(θ).

However, when working with complex simulators, the likelihood p(x | θ, ξ) needed for proper statistical inference cannot be directly evaluated. In these cases, a posterior estimator q(θ | x; ξ) needs to be constructed from simulations of (1) alone. This is the gist of simulation-based inference (SBI; Cranmer et al., 2020; Diggle & Gratton, 1984).

In neural SBI, synthetic pairs (θ n , x n ) obtained via (1) are used to train a (generative) neural network. The network can approximate either the intractable posterior, likelihood, or both. The network is then applied to unlabeled real observations x o , essentially solving a sim2real problem.

Neural SBI is compatible with any off-the-shelf generative model, such as normalizing flows (Rezende & Mohamed, 2015), flow matching (Lipman et al., 2023), or diffusion models (Song & Ermon, 2019). In this work, we focus specifically on amortized methods, which train a single (global) posterior estimator q(θ | x; ξ) that remains valid for any x o ∼ p * (x o ) (see Figure 1, left).

SBI methods are typically developed and applied for a fixed design ξ. Ideally, we want to choose the design that maximizes the amount of information we expect to gain from our subsequent observations. This is the goal of Bayesian experimental design (BED; Lindley, 1956;Chaloner & Verdinelli, 1995;Rainforth et al., 2024).

Realized information gain. Suppose we run an experiment with design ξ and obtain an actual observation x_o. The realized information gain (IG) is defined as the reduction in entropy from the prior p(θ) to the posterior p(θ | x_o; ξ):

IG(ξ, x_o) = H[p(θ)] − H[p(θ | x_o; ξ)],    (3)

where H[p] denotes the Shannon entropy of p. Crucially, the IG can be evaluated only after observing the outcome x o and fitting a posterior p(θ | x o ; ξ).

Expected information gain. Before running the experiment, however, the outcome x is unknown. Thus, the realized IG (3) cannot be used to select a design ξ. Instead, BED aims to select the design that maximizes the expected information gain (EIG), where the expectation is taken with respect to the prior predictive distribution of the outcome, p(x | ξ) = E_{p(θ)}[p(x | θ, ξ)]:

EIG(ξ) = E_{p(x | ξ)}[ H[p(θ)] − H[p(θ | x; ξ)] ] = E_{p(θ) p(x | θ, ξ)}[ log p(θ | x; ξ) − log p(θ) ].    (5)

The second expression shows that EIG(ξ) is the mutual information between θ and x under design ξ. Maximizing the EIG, therefore, selects designs that, on average, are expected to yield the most informative observations about θ.
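To make this concrete, the following sketch shows a standard nested Monte Carlo approximation of EIG(ξ) for the case where the likelihood can be evaluated. The callables `sample_prior`, `simulate`, and `log_likelihood` are hypothetical user-supplied functions, not part of the paper's code, and `log_likelihood` is assumed to broadcast over leading batch dimensions.

```python
import math
import torch

def eig_nested_mc(xi, sample_prior, simulate, log_likelihood, n_outer=256, n_inner=256):
    """Nested Monte Carlo estimate of EIG(xi) = E[ log p(x | theta, xi) - log p(x | xi) ]."""
    theta = sample_prior(n_outer)                        # theta_i ~ p(theta)
    x = simulate(theta, xi)                              # x_i ~ p(x | theta_i, xi)
    log_num = log_likelihood(x, theta, xi)               # log p(x_i | theta_i, xi), shape (n_outer,)

    theta_c = sample_prior(n_inner)                      # contrastive samples for the marginal
    log_all = log_likelihood(x.unsqueeze(1), theta_c.unsqueeze(0), xi)  # (n_outer, n_inner)
    log_den = torch.logsumexp(log_all, dim=1) - math.log(n_inner)       # log p(x_i | xi)

    return (log_num - log_den).mean()                    # Monte Carlo EIG estimate
```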

BED is particularly attractive when experiments entail a sequence of design decisions ξ 1 , . . . , ξ T , where each decision can make use of previous observations (Rainforth et al., 2024).

Let H_t = {(ξ_k, x_k)}_{k=1}^{t} denote the raw experimental history up to step t. At each such step, we can update our beliefs p(θ | H_{t−1}) based on all information gathered so far, and then choose the next design ξ_t by maximizing the incremental expected information gain

EIG_t(ξ_t) = E_{p(x_t | ξ_t, H_{t−1})}[ H[p(θ | H_{t−1})] − H[p(θ | H_{t−1}, ξ_t, x_t)] ].    (6)

This expression is simply the standard EIG (5), but evaluated using the current posterior as the prior for the next step. Thus, BAD selects designs by maximizing (6) at each step t, with H_0 = ∅ (see Figure 1, middle).

Foster et al. (2021) introduced a policy network π_ϕ that predicts the next design from history, π_ϕ(H_{t−1}) = ξ_t, and showed that the decision process can be amortized over all T steps by maximizing the total EIG (TEIG),

TEIG(π_ϕ) = E[ log p(H_T | θ, π_ϕ) − log p(H_T | π_ϕ) ],

where all expectations are under p(θ) p(H T | θ, π ϕ ). In the next section, we extend this framework to the typical SBI setting, and jointly amortize both the design policy and posterior inference across experimental time.

Our framework builds on the SBI and BAD setting from Section 2 (see also Figure 1, right). To represent the history in a fixed-dimensional form, we introduce a summary network η_ω that maps the history sequence to a summary state,

h_t = η_ω(H_t),

with h 0 = 0 corresponding to the empty history state. As in amortized policy-based BAD (Foster et al., 2021;Ivanova et al., 2021), a deterministic policy network π ϕ selects the next design based on the current summary, ξ t = π ϕ (h t-1 ) at each step t ≥ 1.

Inference is performed by a neural posterior estimator q_ψ(θ | h_t), which takes the summary h_t as input and approximates the intractable posterior p(θ | h_t). For a simulated pair (θ, h_t) we define the per-step training loss

ℓ_t(θ, h_t; ϕ, ω, ψ) := L_post(θ, h_t; ϕ, ω, ψ),

where L_post is the posterior loss induced by the posterior estimator (e.g., score matching or flow matching; see Section A.1 and Section A.2 for details). Although q_ψ is the only density estimator, ℓ_t depends on all parameters (ϕ, ω, ψ) through the summary h_t = η_ω(H_t) and the designs ξ_k = π_ϕ(h_{k−1}). We treat ℓ_t as a shared training signal to update the policy π_ϕ, the summary network η_ω, and the posterior estimator q_ψ. For notational convenience, we will usually suppress the explicit dependence on (ϕ, ω, ψ) and write ℓ_t instead of ℓ_t(θ, h_t; ϕ, ω, ψ).

To obtain a tractable training objective, we use the Barber-Agakov variational lower bound (Barber & Agakov, 2003; Foster et al., 2019; Blau et al., 2023) and replace the true posterior p(θ | h_T) by a neural estimator q_ψ(θ | h_T) in (9), which yields

TEIG(π_ϕ) ≥ E_{p(θ) p(H_T | θ, π_ϕ)}[ log q_ψ(θ | h_T) − log p(θ) ],    (12)

with equality if and only if q_ψ(θ | h_T) = p(θ | h_T).

For intuition, we first consider the case when q_ψ is a normalized density model, i.e., a normalizing flow with ℓ_t^flow(θ, h_t) := −log q_ψ(θ | h_t), as in Blau et al. (2023).

Then the logarithm in (12) admits the decomposition

log q_ψ(θ | h_T) − log p(θ) ≈ Σ_{t=1}^{T} [ log q_ψ(θ | h_t) − log q_ψ(θ | h_{t−1}) ] = Σ_{t=1}^{T} ( ℓ_{t−1}^flow − ℓ_t^flow ),    (14)

where the approximation amounts to replacing p(θ) by the learned prior approximation q_ψ(θ | h_0).

Hence, for a normalized density model, the variational TEIG bound can be written as a telescoping sum of per-step differences in information content. However, the use of normalizing flows can be restrictive in practice (Chen et al., 2025), and we want our framework to generalize beyond inference models with exact log-density computation.

Thus, we suggest using the generic quantity ℓ_{t−1} − ℓ_t as a proxy for the incremental information gained from step t − 1 to t, even when q_ψ is approximated with an implicit model for which the normalized log-posterior is not directly available. For example, a diffusion model formulation of (14) entails the difference in posterior scores as a summand:

where τ denotes diffusion time (see Section A.1). In that case, the Barber-Agakov interpretation (Barber & Agakov, 2003) is no longer the same, but the resulting objective still encourages trajectories along which the posterior loss decreases, as demonstrated by our experiments in Section 5.

This motivates the definition of the general scalar utility

u_T(θ, h_{0:T}) = −ℓ_0(θ, h_0) + Σ_{t=1}^{T} [ ℓ*_{t−1}(θ, h_{t−1}) − ℓ_t(θ, h_t) ],    (16)

where ℓ*_t = detach(ℓ_t) denotes the same loss value with gradients stopped. Our final training objective minimizes the negative expected utility

L(ϕ, ω, ψ) = −E_{p(θ) p(h_{0:T} | θ, π_ϕ, η_ω)}[ u_T(θ, h_{0:T}) ],    (17)

using mini-batch gradient descent on L. Since the sum in (16) reduces to u_T(θ, h_{0:T}) = −ℓ_T(θ, h_T), maximizing u_T is equivalent to minimizing the final posterior loss. However, because the baseline terms ℓ*_{t−1} are treated as constants during backpropagation, telescoping does not apply to gradients. For any parameter block, say ϕ, we obtain

∇_ϕ L = E_{p(θ) p(h_{0:T} | θ, π_ϕ, η_ω)}[ Σ_{t=0}^{T} ∇_ϕ ℓ_t(θ, h_t) ].

Thus, the gradient direction aggregates contributions from all intermediate losses ℓ t (θ, h t ) and pushes the networks (π ϕ , η ω , q ψ ) towards improving the posterior approximation at every step along the experimental history sequence for t ≥ 1 and an initial approximation to the prior at t = 0.

In practice, the expectation in (17) is approximated via Monte Carlo. For each training instance in the mini-batch, we sample a parameter θ ∼ p(θ), initialize the empty history state h_0, and evaluate the initial loss ℓ_0(θ, h_0) so that q_ψ learns to approximate the prior at t = 0. We then iteratively generate the history rollout to eventually obtain −u_T(θ) and update (ϕ, ω, ψ) using its averaged gradient.
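A minimal PyTorch-style sketch of one such training iteration, written against the utility form in (16), is shown below. Here `prior`, `simulate`, `summary_net`, `policy_net`, and `posterior_loss` are placeholders for the actual JADAI components, not their real interfaces.

```python
import torch

def jadai_training_step(prior, simulate, summary_net, policy_net, posterior_loss,
                        T, batch_size, optimizer):
    """One training iteration over sampled rollouts (simplified sketch).

    `posterior_loss(theta, h)` stands for the per-step loss ell_t, e.g. the
    diffusion score-matching loss evaluated on the summary h_t.
    """
    theta = prior.sample((batch_size,))                  # theta ~ p(theta)
    h = summary_net.initial_state(batch_size)            # empty summary h_0 = 0
    prev = posterior_loss(theta, h)                      # ell_0: fit the prior at t = 0
    utility = -prev

    for t in range(1, T + 1):
        xi = policy_net(h.detach())                      # design from detached summary
        x = simulate(theta, xi)                          # simulated observation x_t
        h = summary_net.update(h, xi, x)                 # new summary h_t
        cur = posterior_loss(theta, h)                   # ell_t(theta, h_t)
        utility = utility + prev.detach() - cur          # incremental utility term
        prev = cur

    loss = -utility.mean()                               # L = -E[u_T]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```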

The full training procedure is outlined in Algorithm 1.

At test time, we freeze (ϕ, ω, ψ) and use the learned triple (π_ϕ, η_ω, q_ψ) by mirroring the rollout described above, but without loss evaluation or gradient updates. We initialize the empty history sequence and summary state h_0 = 0 and, for t = 1, . . . , T, repeatedly predict the next design ξ_t = π_ϕ(h_{t−1}) and make a new observation x^o_t ∼ Experiment(ξ_t) by running the experiment. At any step, we could query the approximate posterior q_ψ(θ | h_t); in the diffusion case, this corresponds to the conditional sampling described in Section A.1 with h_t as conditioning input. After the final step T, we obtain our best posterior approximation q_ψ(θ | h_T) given all measurements collected under the learned design policy.
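The deployment loop can be sketched as follows; `run_experiment` stands in for the real data-acquisition step, and all other names are placeholders for the trained components.

```python
import torch

@torch.no_grad()
def deploy(summary_net, policy_net, posterior, run_experiment, T, n_samples=10_000):
    """Test-time rollout with frozen networks (illustrative sketch)."""
    h = summary_net.initial_state(1)                     # h_0 = 0
    for t in range(1, T + 1):
        xi = policy_net(h)                               # xi_t = pi_phi(h_{t-1})
        x_obs = run_experiment(xi)                       # new real observation x^o_t
        h = summary_net.update(h, xi, x_obs)             # updated summary h_t
    return posterior.sample(h, n_samples)                # draws from q_psi(theta | h_T)
```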

Algorithm 1 (JADAI training). Input: prior p(θ); simulator Sim(θ, ξ); networks: summary η_ω, policy π_ϕ, posterior q_ψ; maximum horizon T; schedules R(n; T), ρ(n); window W. Output: trained parameters (ϕ, ω, ψ). The main loop iterates over training iterations n = 1, 2, . . ., sampling each design either from the prior mix-in or from the policy.

Avoiding nested backpropagation through time. In the naive implementation of (17), each design is generated from the previous summary, ξ_t = π_ϕ(h_{t−1}), and each summary is computed from all past tokens, h_t = η_ω((ξ_1, x_1), . . . , (ξ_t, x_t)). As a result, the forward graph for ξ_t contains an increasingly long chain over time: ξ_1 depends only on h_0, ξ_2 depends on (h_0, ξ_1, x_1, h_1), ξ_3 depends on (h_0, ξ_1, x_1, h_1, ξ_2, x_2, h_2), and so on. When backpropagating from a final loss ℓ_T(θ, h_T), gradients therefore have to traverse both the rollout over time and, for each token (ξ_t, x_t), an internally time-unrolled subgraph through earlier summaries and policy evaluations. In effect, this yields a nested form of backpropagation through time (BPTT) whose depth grows on the order of 1 + 2 + · · · + T, and tightly couples the history representation to the policy via a cyclic gradient path (history state → policy → history state).

To avoid this feedback loop, we generate designs from a detached history state, ξ_t = π_ϕ(h*_{t−1}). Forward passes are unchanged, but gradients from the posterior losses can no longer flow back into the history through the policy inputs. In other words, the summary network η_ω is trained only via its influence on the posterior losses ℓ_t(θ, h_t), while the policy π_ϕ is trained via the effect of its designs on future losses, through the tokens (ξ_t, x_t) that enter the summaries. Conceptually, this enforces the role of h_t as a learned summary statistic that is optimized to be maximally informative about the design-aware posterior p(θ | h_t) (Radev et al., 2020; Chen et al., 2023). The policy processes these approximately sufficient summaries to propose useful designs. This removes the nested BPTT structure and leaves a single, simpler gradient path through the history, similar to standard BPTT in recurrent models. On top of this, we can optionally limit gradient propagation through time by detaching tokens older than a fixed window size W, so that losses at step t only backpropagate through the most recent W (ξ, x) pairs. This is analogous to truncated BPTT in RNNs and provides a simple mechanism to keep memory and compute manageable for longer rollouts or larger architectures (see Section 5.1 and Appendix B).
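A sketch of this windowed detachment is given below, assuming the summary network can recompute h_t from a list of (ξ, x) tokens; the function name and interface are illustrative, not the paper's implementation.

```python
def summarize_with_window(tokens, summary_net, window):
    """Recompute the summary from design-observation tokens, detaching tokens older
    than `window` so that gradients only flow through the most recent W pairs
    (analogous to truncated BPTT)."""
    cutoff = max(0, len(tokens) - window)
    kept = [(xi.detach(), x.detach()) if i < cutoff else (xi, x)
            for i, (xi, x) in enumerate(tokens)]
    return summary_net(kept)
```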

Scheduling and sampling the rollout length. Because the policy, summary, and posterior networks are mutually dependent, very long rollouts early in training can hinder learning: even the t = 0 posterior q_ψ(θ | h_0) must first learn to match the prior p(θ) before later decisions become meaningful. We therefore use a curriculum on the rollout length via a monotone schedule on the current maximum rollout length R(n) ≤ T at training iteration n that gradually increases towards T, and then sample the actual rollout length as r ∼ U(1, R(n)). For that iteration, the corresponding utility u_r (16) is truncated at t = r. Sampling r rather than always using R(n) ensures that, at any point in training, the networks see a mixture of short and long decision sequences.
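One possible implementation of this curriculum, assuming a linear warm-up of R(n) over a fixed number of iterations, is sketched below.

```python
import random

def sample_rollout_length(n, n_warmup, T):
    """Sample the rollout length r at training iteration n.

    R(n) grows linearly from 1 to T over `n_warmup` iterations (one possible
    schedule); the realized length is drawn uniformly from {1, ..., R(n)}.
    """
    R_n = min(T, 1 + int((T - 1) * min(1.0, n / n_warmup)))
    return random.randint(1, R_n)
```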

Exploration via design prior mix-in. Complementary to the rollout-length curriculum, we also use a curriculum on how designs are chosen during training. Early in training, the policy network tends to propose designs in a narrow region of the design space, so relying on ξ_t = π_ϕ(h_{t−1}) alone would yield little diversity in the observations. To expose the posterior and summary networks to more varied data initially, we treat the design prior p(ξ) as a generic source of exploratory designs and combine it with the learned policy at every rollout step according to

ξ_t ∼ p(ξ) with probability ρ_n,    ξ_t = π_ϕ(h_{t−1}) with probability 1 − ρ_n,

where ρ n ∈ [0, 1] is a scheduled exploration probability that is typically decreased over the course of training. This framework covers several useful regimes: an initial phase of pure prior sampling (ρ n = 1), a fully policy-driven regime (ρ n = 0), and annealed schedules that gradually shift from prior-driven exploration to purely policy-based design as ρ n decreases from 1 to 0. Concrete choices of ρ n for each experiment are described in Section 5.
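The mix-in itself amounts to a single stochastic branch per rollout step, as in the following sketch, where `design_prior` stands in for p(ξ) as a torch.distributions object and `policy_net` is a placeholder for the learned policy.

```python
import torch

def propose_design(h, policy_net, design_prior, rho_n):
    """Design prior mix-in: with probability rho_n draw an exploratory design
    from p(xi), otherwise use the learned policy on the detached summary."""
    if torch.rand(()).item() < rho_n:
        return design_prior.sample()            # exploratory design xi_t ~ p(xi)
    return policy_net(h.detach())               # policy design xi_t = pi_phi(h_{t-1})
```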

Amortized simulation-based inference. Amortized SBI methods typically train a global posterior functional, x → q ψ (θ | x) for a fixed design ξ. The functional can be realized by any generative model, such as normalizing flows (Ardizzone et al., 2018), flow matching (Wildberger et al., 2023), diffusion (Sharrock et al., 2022), or consistency models (Schmitt et al., 2024). A common architectural choice is to separate the model into a summary network η ω (x), which embeds observation sequences into a fixed-dimensional representation, and an inference network (i.e., a generative backbone) which can sample from q ψ (θ | η ω (x)) (Radev et al., 2020;Chen et al., 2021;2023). Our method uses design-aware summary networks combined with flexible diffusion-based inference backbones (Arruda et al., 2025).

Variational Bayesian experimental design. Neural variational formulations replace unknown densities (posterior, marginal, and/or likelihood) with flexible parametric approximations (Foster et al., 2019). When the posterior p(θ | x, ξ) is replaced by a variational approximation q_ϕ(θ | x, ξ) in (5), this yields the Barber-Agakov lower bound on the EIG (Barber & Agakov, 2003). More recent work has employed flexible neural models, such as conditional normalizing flows, to parameterize the posterior (Orozco et al., 2024; Dong et al., 2025) or likelihood (Zaballa & Hui, 2025). In all these approaches, however, the variational approximations are used primarily as surrogates for efficient, differentiable EIG estimation in static (i.e., non-adaptive) design problems, rather than as final amortized inferential objects.

Bayesian adaptive design. Traditionally, BAD has been formulated as a greedy two-step sequential procedure. At each experiment iteration t, one first optimizes the incremental, one-step ahead EIG (6) using gradient-free (von Kügelgen et al., 2019;Hamada et al., 2001;Price et al., 2018) or, more recently, gradient-based surrogates (Foster et al., 2020;Kleinegesse & Gutmann, 2020). Most recently, Iollo et al. (2025) extended this gradient-based strategy to high-dimensional tasks using diffusion models, leveraging a pooled posterior proxy to estimate the gradients of the EIG. After observing the experimental outcome, a separate Bayesian inference step is performed, typically using (asymptotically) exact methods such as MCMC or SMC (Drovandi et al., 2014;Kuck et al., 2006;Vincent & Rainforth, 2017), and resorting to simulation-based methods only when the likelihood is intractable (Huan & Marzouk, 2013;Lintusaari et al., 2017;Sisson et al., 2018). This pipeline blueprint separates design optimization from posterior inference, but has to solve the inference problem from scratch at each decision step.

Amortized Bayesian adaptive design. The idea of fully amortizing the adaptive design process is to avoid intermediate posterior calculations by directly mapping past experimental data to future design decisions. Foster et al. (2021) were the first to derive the TEIG objective and, using a lower bound on its likelihood-based form (8), train an amortized Deep Adaptive Design (DAD) policy network. The idea has subsequently been extended to differentiable implicit models (Ivanova et al., 2021), to objectives that directly target downstream decision-making utilities (Huang et al., 2024), to semi-amortized settings that introduce local policy updates (Hedman et al., 2025), and to reinforcement learning (RL) based approaches suitable for non-differentiable simulators (Blau et al., 2022; Lim et al., 2022). All of these approaches optimize the design policy in isolation, often leaving accurate posterior estimation as a post-hoc task.

Figure 2. Rollout process for Location Finding. Panels: posterior samples and chosen designs over time t, with crosses marking the true source locations. The second posterior mode is typically uncovered around t = 10 measurements. Bottom right: corner plot of the learned posterior over the two sources at t = 10 shows nearly identical densities at (θ_11, θ_12) and (θ_21, θ_22), indicating that the model correctly captures exchangeability of the two modes, that is, invariance of the posterior under swapping the two source labels.

Unified amortized design and inference. A recent wave of methods aims to jointly amortize adaptive design and Bayesian inference. Three closely related approaches, RL-sCEE (Blau et al., 2023), vsOED (Shen et al., 2025), and ALINE (Huang et al., 2025), optimize BA-style variational lower bounds on the TEIG (12) by explicitly casting the design problem within an RL framework. All three methods consequently rely on high-variance REINFORCE estimators (in the case of ALINE) or actor-critic algorithms that require training additional value networks (in the case of RL-sCEE and vsOED).

Furthermore, their dependence on explicit density estimators of the posterior q_ϕ necessitates estimators with tractable likelihoods, restricting them to architectures such as normalizing flows or much less expressive Gaussian mixture models (GMMs). In contrast, JADAI frames the problem as optimization over sampled rollouts, demonstrating that the heavy RL machinery is not strictly necessary for effective amortized design and inference. Finally, unlike prior methods, JADAI incorporates implicit generative models (e.g., diffusion models) that afford scalable and flexible inference.

We evaluate our method on three benchmarks that illustrate a progression in both posterior and policy complexity. The first two, Location Finding (LF) and Constant Elasticity of Substitution (CES), are standard benchmarks: LF requires only a simple policy but yields a multimodal posterior, whereas CES typically leads to a simple, approximately unimodal posterior but requires a more complex policy. Recently, Iollo et al. (2025) proposed the MNIST Image Discovery (ID) task, which combines both challenges in a high-dimensional observation space and requires a sophisticated policy together with a flexible multimodal posterior.

For LF and CES, we assess policy quality using the sequential prior contrastive estimation (sPCE) (Foster et al., 2021) lower bound on the total expected information gain, while for ID we report the Structural Similarity Index Measure (SSIM) (Wang et al., 2004) and the normalized root-mean-square error (NRMSE) (see Appendix C for details on experiments and metrics).

We first consider the Location Finding benchmark of Foster et al. (2021), where the goal is to infer the locations θ of K = 1 or K = 2 signal-emitting sources (depending on the experimental setting) from noisy measurements of their summed intensity x at adaptively chosen measurement positions ξ (experimental details in Section C.2). Because the optimal policy is relatively simple, we pre-train the summary and posterior networks under a random design policy before joint training, as discussed in Section 3.4.

Qualitatively, the learned policy balances exploration and exploitation: during the initial rollout steps, when posterior uncertainty is high, it explores the design space broadly; as uncertainty decreases, it concentrates measurements in regions of high posterior density, placing most posterior mass close to the true source locations while continuing to explore until both sources (i.e., K = 2) have been identified (see Figure 2 for a typical rollout). Since the policy only observes the current summary state h t , this behavior indicates that h t encodes which regions have already been probed and where posterior mass is concentrated. Furthermore, the corner plot in Figure 2 shows nearly identical densities at (θ 11 , θ 12 ) and (θ 21 , θ 22 ), confirming that q ψ (θ | h t ) learns the full joint posterior and respects exchangeability of the two sources.

Quantitatively, we evaluate the policy using the sPCE lower bound on the total EIG. Our policies are competitive across all settings and outperform prior approaches for both K = 1 and K = 2 sources whenever T > 10 (Table 1). Moreover, training with a longer terminal horizon further improves performance at shorter evaluation horizons: policies trained with T = 30 achieve higher sPCE than those trained with T = 20 when both are evaluated at T = 20. Since the same summary and policy networks are applied at every time step, optimizing per-step posterior losses up to T = 30 includes all losses for t ≤ 20 and additionally trains the networks on later, typically more concentrated posteriors. This extra training signal can refine how the summary network and the policy respond to similar configurations that already occur earlier in the rollout, leading to summary representations that generalize better across rollout lengths.

Table 1. sPCE lower bound on total EIG (↑) for Location Finding (LF) and constant elasticity of substitution (CES) benchmarks. For LF, our posterior-based policies (u10, u20, u30) exceed prior baselines for all cases where T > 10. For CES, our method outperforms all existing baselines at the standard evaluation horizon T = 10. Policies trained with the longest terminal horizon perform best, also at intermediate rollout lengths, e.g., u30 vs. u20 evaluated at T = 20. L is the number of contrastive samples.

| Method | T = 10, K = 2, L = 5·10⁵ | T = 20, K = 2, L = 5·10⁵ | T = 30, K = 2, L = 10⁶ | T = 30, K = 1, L = 10⁶ | T = 10, K = 3, L = 10⁷ |
|---|---|---|---|---|---|
| Random | 4.79 ± 0.04 | 7.00 ± 0.03 | 8.30 ± 0.04 | 5.17 ± 0.05 | 9.05 ± 0.26 |
| SG-BOED (Foster et al., 2020) | 5.55 ± 0.03 | 7.70 ± 0.03 | 8.84 ± 0.04 | 5.25 ± 0.22 | 9.40 ± 0.27 |
| iDAD (Ivanova et al., 2021) | 7.75 ± 0.04 | 10.08 ± 0.03 | — | — | — |
| DAD (Foster et al., 2021) | 7… |  |  |  |  |

Additional ablation results (Appendix B) and examples for a longer terminal horizon (T = 30, Figure 5) are presented in the appendix.

As a more challenging design problem, we next consider the Constant Elasticity of Substitution (CES) benchmark (Arrow et al., 1961; Foster et al., 2019), where an agent rates the difference in subjective utility x between a pair of baskets ξ, each containing K = 3 goods, and the goal is to infer a five-dimensional preference parameter θ. Informative designs lie in a narrow “sweet spot” between nearly identical baskets (indifference) and very different baskets, where noise and sigmoid saturation dominate (Foster et al., 2019).

In practice, this makes random designs largely uninformative: random-policy pretraining led to unstable training or collapsed weights, so for CES we train summary, policy, and posterior jointly from the beginning (see Section C.3).

Our method outperforms all state-of-the-art approaches at the commonly used evaluation horizon T = 10 (see Table 1).

As in Location Finding, training with a longer terminal horizon yields additional gains: policies trained with T = 30 achieve slightly higher sPCE when evaluated at T = 10 than policies trained directly with T = 10.

Finally, we evaluate our method on the high-dimensional MNIST Image Discovery task introduced by “CoDiff” (Iollo et al., 2025). At each step, the policy selects a spatial location ξ, and the simulator reveals a local measurement patch x. The downstream task is to reconstruct the full digit image θ from a sequence of such measurements. We follow CoDiff’s simulator implementation but also consider variants with additive measurement noise (σ > 0) that remove useful signal outside the measurement mask (see Section C.4).

During training, we follow the mixed-policy scheme from Section 3.4, starting with mostly random designs and gradually annealing the probability ρ_n of random actions to zero so that, over time, rollouts are generated entirely by the learned policy. Alongside the diffusion-based posterior estimator, we also train a flow matching variant with the same architecture (see Section A.2 for details), demonstrating that our framework can incorporate score- and flow matching-based objectives equally well.

Intuitively, an effective policy should first gather information that disambiguates the digit class and then refine the digit’s shape. Qualitatively, our learned policies exhibit this behavior: early measurements are placed on non-overlapping, class-discriminative regions, while later measurements refine local details (see Figure 3, and Figure 7 for additional examples). The posterior typically converges to the correct digit shape in fewer than T = 6 measurements.

Quantitatively, we average SSIM and NRMSE results over 30 posterior samples for the whole validation split at each rollout step (Figure 4), achieving the best results across all noise levels (Table 2). Both metrics improve rapidly during the first few measurements, indicating that most information is gained early, with later steps primarily refining the reconstruction. Notably, under a random policy, CoDiff and our posterior network achieve similar SSIM values, indicating that the performance gap is primarily due to the policy rather than differences in sampling or network architectures.

We introduce JADAI, a new framework for jointly amortizing adaptive experimental design and posterior inference via an incremental posterior loss as a proxy for the classical TEIG. Our method enables posterior estimation at any step of the sequential design process, rather than only after a fixed horizon, and thus connects naturally to active data-acquisition use cases. At test time, experts may override or modify the designs proposed by the learned policy while being informed by intermediate approximate posteriors.

In the default setting (without user intervention), full rollouts run in milliseconds, which is comparable to recent methods (Huang et al., 2025) on low-dimensional tasks and roughly an order of magnitude faster on our high-dimensional benchmark. Posterior sampling, however, remains a bottleneck: generating 10,000 samples takes a few seconds in the high-dimensional case, making sub-second deployment challenging. A promising direction is to distill the posterior estimator post hoc into a faster surrogate.

Both qualitatively and quantitatively, JADAI typically matches or improves upon prior work, particularly on high-dimensional inference problems, while approximating multimodal posteriors and maintaining effective policies for more complex design choices such as CES. A natural direction for future work is to investigate the limits of this approach as the design space becomes increasingly complex, for instance, by considering higher-dimensional designs like spatial patterns or time-series stimuli.

Although JADAI is applicable when the likelihood is not available in closed form, our experiments rely on differentiating through the simulator. As differentiable simulators in autodiff frameworks become more common, this setting is increasingly relevant; however, extending JADAI to purely black-box simulators remains an important direction and will likely require gradient-free design optimization.

In this work, we used separate networks for the policy, summary, and posterior. By contrast, the success of ALINE (Huang et al., 2025) and related work (Huang et al., 2024; Zhang et al., 2025; Chang et al., 2024) stems in part from a well-chosen transformer backbone with multiple task-specific heads. Thus, a natural extension of JADAI would be to keep the posterior network separate but let the policy and summary networks share a transformer encoder.

In the following, we provide a brief overview of the diffusion model (Ho et al., 2020; Song & Ermon, 2019; Kingma & Gao, 2023) used to approximate the posterior p(θ | ·) ≈ q_ψ(θ | ·). More details on diffusion models in a general setting can be found in Karras et al. (2022) and, for simulation-based inference, in Arruda et al. (2025).

A diffusion model learns how to gradually denoise a sample from a base distribution z_1 ∼ p(z_1) = N(0, I), typically a standard Gaussian, towards the target data distribution θ ≡ z_0 ∼ p(z_0 | ·). At the core of this learning process is the forward corruption process

dz_τ = f(τ) z_τ dτ + g(τ) dW,    (20)

where f(τ) and g(τ) are drift and diffusion coefficients, respectively, and dW defines a Wiener process. Starting from τ = 0, this forward process gradually adds Gaussian noise to a sample from the target distribution until it approximately follows the base distribution at τ = 1.

This construction allows computing the marginal densities p(z_τ | z_0) analytically for every 0 < τ < 1:

p(z_τ | z_0) = N(z_τ; α_τ z_0, σ_τ² I),    i.e.,    z_τ = α_τ z_0 + σ_τ ε with ε ∼ N(0, I).    (21)

The reverse SDE has the form

dz_τ = f̃(z_τ, τ) dτ + g(τ) dW̄,

where f̃ is a new drift term depending on the score ∇_{z_τ} log p(z_τ | z_0) of the conditional distribution from the forward process (21). A neural network s_ψ is trained to approximate that score via a weighted score-matching objective:

L(ψ) = E_{τ, z_0, x, ε}[ w_τ ∥ s_ψ(z_τ, x, τ) − ∇_{z_τ} log p(z_τ | z_0) ∥² ],    (23)

where x denotes an optional condition variable, w_τ a diffusion time-dependent weighting, and z_τ is computed as in (21) with coefficients defined in the following.

The coefficients in the marginal density (21) define a noise schedule. They control how much noise is added at each step τ ∈ [0, 1] and are related to the SDE coefficients in (20) via f(τ) = α′_τ/α_τ and g(τ)² = dσ_τ²/dτ − 2 (α′_τ/α_τ) σ_τ².

In our experiments, we chose a variance-preserving schedule such that the relation between these two is given by α_τ² + σ_τ² = 1. Although diffusion time τ is sampled uniformly in (23) and controls both coefficients, the noise schedule is often parameterized in terms of the log signal-to-noise ratio λ(τ) = log(α_τ²/σ_τ²), which allows shifting emphasis towards specific regions of the noise spectrum. For all experiments, we used a cosine schedule λ(τ) = −2 log tan(πτ/2) that places more probability mass near intermediate SNR levels (λ(τ) ≈ 0) than the original linear schedule.
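Under the variance-preserving constraint, λ(τ) determines both coefficients, since α_τ² = sigmoid(λ(τ)) and σ_τ² = sigmoid(−λ(τ)); for the cosine schedule this is equivalent to α_τ = cos(πτ/2) and σ_τ = sin(πτ/2). The sketch below computes (α_τ, σ_τ) along these lines; it is an illustrative helper, not the paper's implementation.

```python
import torch

def cosine_schedule(tau):
    """Variance-preserving cosine schedule lambda(tau) = -2 log tan(pi*tau/2).

    Returns (alpha_tau, sigma_tau) with alpha^2 + sigma^2 = 1.
    """
    tau = torch.as_tensor(tau, dtype=torch.float32)
    log_snr = -2.0 * torch.log(torch.tan(torch.pi * tau / 2))
    alpha = torch.sqrt(torch.sigmoid(log_snr))     # alpha^2 = sigmoid(lambda)
    sigma = torch.sqrt(torch.sigmoid(-log_snr))    # sigma^2 = sigmoid(-lambda)
    return alpha, sigma
```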

Instead of directly predicting the conditional score, one can predict the noise ε or an interpolation between data and noise v at each time step τ and replace the score s in (23) accordingly. The relation between noise and the score is

∇_{z_τ} log p(z_τ | z_0) = −ε/σ_τ,    (24)

so predicting ε is equivalent to predicting the score up to a known scaling factor. The noise target is simply ε ∼ N(0, I), while the target for v-prediction is defined as the interpolation between data and noise, which in the variance-preserving case is

v_τ = α_τ ε − σ_τ z_0.    (25)

In our case, the network is parameterized for v-prediction, and its outputs are converted to the noise domain via (25) so that the chosen noise and weighting schedules remain unchanged by the parameterization.
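A small helper for this conversion, following the relations in (24) and (25) for the variance-preserving case, might look as follows; shapes are assumed to broadcast.

```python
def v_to_eps_and_score(v_pred, z_tau, alpha, sigma):
    """Convert a v-prediction into noise and score estimates (alpha^2 + sigma^2 = 1)."""
    eps = alpha * v_pred + sigma * z_tau   # recovered noise estimate from v and z_tau
    score = -eps / sigma                   # conditional score as in Eq. (24)
    return eps, score
```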

After training, we draw approximate posterior samples θ ∼ q_ψ(θ | h_t) by starting from a base sample z_1 ∼ N(0, I) and solving the probability-flow ODE associated with the reverse SDE in Eq. (20). Concretely, we use the deterministic ODE

dz_τ/dτ = f(τ) z_τ − ½ g(τ)² s_ψ(z_τ, h_t, τ),

where s_ψ(z_τ, h_t, τ) denotes the learned conditional score. In practice, the network predicts the velocity v_τ, which is converted to a score estimate using the relations between velocity, noise, and score from (25) and (24). We integrate this ODE numerically from τ = 1 to τ = 0 using an explicit Euler solver with N = 1000 equidistant steps.
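For illustration, the sketch below draws a posterior sample with a deterministic DDIM-style update, which is one common discretization of the probability-flow ODE for a variance-preserving schedule; the paper integrates the ODE with an explicit Euler solver, so this is an illustrative variant rather than the exact procedure. `v_net` and `schedule` are placeholders for the conditional v-prediction network and a function returning (α_τ, σ_τ).

```python
import torch

@torch.no_grad()
def sample_posterior(v_net, h, schedule, n_steps=1000, shape=(2,)):
    """Deterministic sampling from q_psi(theta | h_t) with a v-prediction network."""
    z = torch.randn(shape)                                   # z_1 ~ N(0, I)
    taus = torch.linspace(1.0, 0.0, n_steps + 1)
    for tau, tau_next in zip(taus[:-1], taus[1:]):
        alpha, sigma = schedule(tau)
        v = v_net(z, h, tau)                                 # predicted velocity
        z0_hat = alpha * z - sigma * v                       # data estimate
        eps_hat = sigma * z + alpha * v                      # noise estimate
        alpha_n, sigma_n = schedule(tau_next)
        z = alpha_n * z0_hat + sigma_n * eps_hat             # deterministic step to tau_next
    return z                                                  # approximate theta sample
```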

Posterior diffusion loss in our setting. In our SBI setting, we use a diffusion model to approximate the conditional distribution p(θ | h_t) at each step t, with θ ≡ z_0 and the summary h_t ≡ x playing the role of the conditioning variable. For each training iteration, we sample τ ∼ U(0, 1) and ε ∼ N(0, I) once, construct z_τ from (21), and reuse the same (τ, ε) at all rollout steps along the trajectory. The network takes (z_τ, h_t, τ) as input and predicts v_τ, which we convert to the noise domain to obtain ε_ψ using (25). For a fixed sampled pair (τ, ε), the resulting per-step diffusion posterior loss is

L^diff_post(θ, h_t; ϕ, ω, ψ, τ, ε) := w_τ ∥ ε_ψ(z_τ, h_t, τ; π_ϕ, η_ω) − ε ∥²,

where h_t depends on (ϕ, ω) implicitly through the policy and summary networks, h_t = η_ω(H_t) and ξ_k = π_ϕ(h_{k−1}). In the notation of the main text, where the generic per-step posterior loss is ℓ_t(θ, h_t) := L_post(θ, h_t; ϕ, ω, ψ), we simply set

ℓ_t(θ, h_t) := L^diff_post(θ, h_t; ϕ, ω, ψ, τ, ε)

for the diffusion-based posterior estimator. Plugging this choice of ℓ_t into the utility u_T in Eq. (16) and the global objective L in Eq. (17) yields the diffusion-specific training objective used in our experiments:

L^diff(ϕ, ω, ψ) = −E_{θ, h_{0:T}, τ, ε}[ u^diff_T(θ, h_{0:T}, τ, ε) ].

In practice, the joint expectation in L^diff is approximated via Monte Carlo over quadruples (θ, h_{0:T}, τ, ε). Each training iteration draws a minibatch of parameters θ ∼ p(θ), simulates the corresponding rollout summary states h_{0:T} ∼ p(h_{0:T} | θ, π_ϕ, η_ω), and, for each rollout in the minibatch, samples a single pair (τ, ε) with τ ∼ U(0, 1) and ε ∼ N(0, I). The same (τ, ε) is reused across all steps t of that trajectory when evaluating L^diff_post(θ, h_t; ϕ, ω, ψ, τ, ε) and aggregating the utility u^diff_T in Eq. (16).
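Per rollout, this amounts to evaluating the same corruption of θ at every step, as in the following sketch, where `v_net` is a placeholder for the conditional v-prediction network and (α_τ, σ_τ) come from the noise schedule.

```python
def diffusion_step_loss(theta, h, v_net, tau, eps, alpha, sigma, w=1.0):
    """Per-step diffusion posterior loss with a (tau, eps) pair shared along the rollout."""
    z_tau = alpha * theta + sigma * eps     # corrupt theta as in the marginal (21)
    v = v_net(z_tau, h, tau)                # predicted velocity
    eps_pred = alpha * v + sigma * z_tau    # convert v-prediction to the noise domain
    return w * ((eps_pred - eps) ** 2).mean(dim=-1)
```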

Flow matching provides an alternative to score-based diffusion models: instead of simulating a stochastic forward process, one specifies interpolation paths {z_τ}_{τ∈[0,1]} and learns the associated velocity field of the probability-flow ODE directly (Liu et al., 2022; Lipman et al., 2023). Conditional applications in the SBI setting are discussed in Wildberger et al. (2023) and Arruda et al. (2025).

As in the diffusion setup above, let θ ≡ z_0 ∼ p(θ) denote a sample from the target distribution and let z_1 ≡ ε ∼ N(0, I) denote Gaussian noise. A simple linear interpolation is

z_τ = (1 − τ) θ + τ ε,

with the flow-matching schedule α_τ = 1 − τ and σ_τ = τ. The associated probability-flow ODE has the velocity field

v(z_τ, h_t, τ) = ε − θ,

which is parameterized by a neural network v_ψ(z_τ, h_t, τ) ≈ v(z_τ, h_t, τ); the target velocity is constant along the path and does not depend on τ explicitly.

Sampling from the approximate posterior θ ∼ q_ψ(θ | h_t) can be done by solving the probability-flow ODE

dz_τ/dτ = v_ψ(z_τ, h_t, τ),

starting from a base sample z_1 = ε ∼ N(0, I) at τ = 1 and integrating backwards to τ = 0 with an explicit Euler solver and N = 1000 steps:

z_{τ−Δτ} = z_τ − Δτ · v_ψ(z_τ, h_t, τ),    Δτ = 1/N.
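Because the learned velocity field is the only ingredient, the sampler reduces to a few lines; the sketch below assumes a placeholder network `v_net(z, h, τ)`.

```python
import torch

@torch.no_grad()
def sample_flow_posterior(v_net, h, n_steps=1000, shape=(2,)):
    """Euler integration of the flow-matching probability-flow ODE from tau=1 to tau=0."""
    z = torch.randn(shape)                            # z_1 = eps ~ N(0, I)
    dt = 1.0 / n_steps
    for i in range(n_steps):
        tau = 1.0 - i * dt
        z = z - dt * v_net(z, h, torch.tensor(tau))   # z_{tau - dt} = z_tau - dt * v_psi
    return z                                           # approximate theta ~ q_psi(theta | h_t)
```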

Posterior flow-matching loss in our setting. As in the diffusion case (see Section A.1), we use the flow-matching objective as the per-step posterior loss for the generic training objective in Section 3. For a simulated pair (θ, h_t) at design step t, and a fixed sampled pair (τ, ε), we define

L^flow_post(θ, h_t; ϕ, ω, ψ, τ, ε) := w_τ ∥ v_ψ(z_τ, h_t, τ; π_ϕ, η_ω) − (ε − θ) ∥².

Using the random policy, our posterior samples do not differ significantly from CoDiff’s in SSIM, indicating that our method’s learned-policy performance gap over CoDiff is primarily due to an improved policy rather than differences in the sampling process, network size, or architecture.

We used the same training settings and hardware as in the location finding experiment (see above, Section C.2), but omitted pretraining and performed only joint training for 400 epochs with fully adaptive rollouts (ρ n = 0). Training took approximately 2.78 hours, and deployment for a full rollout with horizon T = 10 and 10,000 posterior samples took approximately 0.67 seconds.

To extend the previous experiments to higher dimensions, we also consider image discovery as proposed by Iollo et al. (2025). Here, the goal is to reconstruct an image from partial information, with each measurement unveiling only part of the image. One can imagine this experiment like standing in a dark room and trying to observe a large poster on the wall by shining light on its different parts bit by bit.

More formally, consider an unknown ground-truth image θ ∈ R^{C×H×W} with C channels, height H, and width W. At each experimental step, we choose a location ξ ∈ [0, H] × [0, W], which represents the continuous-space center of the measurement. The noise-free signal µ is then given by a smooth analog of a simple masking operation, which Iollo et al. (2025) choose as a convolution with a Gaussian kernel G_s applied to the masked image A_ξ(θ), with smoothness parameter s and (x_1, x_2) the pixel locations. Iollo et al. (2025) further propose replacing the Gaussian kernel with a bivariate logistic distribution, which then simplifies the signal to

µ_{ξ,s}(x_1, x_2) = [S(x_1 − ξ_1 + h; s_1) + S(ξ_1 + h − x_1; s_1) − 1] · [S(x_2 − ξ_2 + h; s_2) + S(ξ_2 + h − x_2; s_2) − 1],    (54)

with h the mask size and S(x; s) = 1/(1 + exp(−x/s)) the sigmoid function with scale parameter s. The full measurement x is noisy, such that the observed value at each pixel becomes

x_{ξ,s}(x_1, x_2) = µ_{ξ,s}(x_1, x_2) + η(x_1, x_2),    (55)

where we choose a uniform noise term η ∼ U(0, σ). We further clamp x to [0, 1] in order to preserve its range of values even in the presence of noise at signal values near 1. In practice, the scale s and noise level σ are small, such that the signal dominates within, and the noise term dominates outside of the masking area. This means useful information can only be extracted from within the masking area. We run multiple experiments with varying σ using s = 0.1 and a mask size of h = 7 for an image size of H = W = 28.
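A sketch of this measurement model is given below; it assumes the smooth mask multiplies the ground-truth image pixel-wise, which is an interpretation of Eq. (54) for a single-channel image rather than a verbatim transcription of the reference implementation.

```python
import torch

def mnist_measurement(theta, xi, h=7.0, s=0.1, sigma=0.0):
    """Smooth-mask measurement for the Image Discovery task (illustrative sketch).

    theta: ground-truth image of shape (H, W); xi: continuous mask center (xi1, xi2).
    """
    H, W = theta.shape
    x1 = torch.arange(H, dtype=theta.dtype).view(-1, 1)     # pixel rows
    x2 = torch.arange(W, dtype=theta.dtype).view(1, -1)     # pixel columns
    S = lambda u: torch.sigmoid(u / s)                       # S(u; s)
    mask1 = S(x1 - xi[0] + h) + S(xi[0] + h - x1) - 1.0      # logistic window in x1
    mask2 = S(x2 - xi[1] + h) + S(xi[1] + h - x2) - 1.0      # logistic window in x2
    mu = mask1 * mask2 * theta                               # assumed noise-free signal
    noise = sigma * torch.rand_like(mu)                      # eta ~ U(0, sigma)
    return torch.clamp(mu + noise, 0.0, 1.0)                 # observed measurement x
```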

Just like Iollo et al. (2025), we first run a noise-free experiment with σ = 0. However, given that the signal has support over the full image, we expect the full ground truth image to be recoverable from very small signal values in a single measurement. Therefore, we run additional experiments with σ = 10⁻³ and σ = 10⁻², which destroy the signal outside of the masking area. The difference in signal support is highlighted on a log scale in Figure 8.

We train for a total of 500 epochs, with the maximum number of measurements T linearly scheduled from T_i = 2 to T_f = 6 within the first 5% of training steps. Similarly, we schedule the probability of design exploration ρ_n with a cosine decay schedule, starting at 100% and decaying to 0% over the first 30% of training steps.

We use the AdamW optimizer (Loshchilov & Hutter, 2019) without weight decay. The learning rate follows a OneCycleLR schedule (Smith & Topin, 2018) with a maximum learning rate of 10⁻⁴, an initial division factor of 10, and a final division factor of 10⁴. The learning rate is ramped up to the maximum over the first 5% of training steps.
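In code, this configuration corresponds roughly to the following sketch, where `model` and `total_steps` are placeholders for the actual training setup.

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR

def configure_optimizer(model, total_steps):
    """AdamW without weight decay plus a OneCycleLR schedule, as described above."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.0)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=1e-4,            # maximum learning rate
        total_steps=total_steps,
        pct_start=0.05,         # warm up over the first 5% of training steps
        div_factor=10,          # initial division factor
        final_div_factor=1e4,   # final division factor
    )
    return optimizer, scheduler
```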

We use a batch size of 48, and make use of automatic mixed precision with PyTorch Lightning’s precision option “16-mixed” (Falcon & The PyTorch Lightning team, 2019). We further clip the gradient norm at a maximum value of 5.0.

The full model is very small at only 417K parameters. Training takes around 48 hours on an Nvidia RTX 4060-Laptop GPU, while inference for a batch of 300 posterior samples typically takes under 3 seconds. The rollout process for all 6 measurements in a batch of one requires around 22ms. JIT-compilation further improves this to just 110µs, which is significantly faster than comparable methods on this hardware. Using larger batches could lead to additional improvements in the per-sample figure.

Apart from the results in Section 5.3, we present results with the posterior approximator and a random policy in Figure 6 and additional rollout examples in Figure 7.

We expect the method to scale well to high-dimensional problems, provided sufficient hardware resources and careful handling of potentially vanishing gradients during training for very large policy networks and long rollouts (e.g., by letting the network learn the design residual instead).

The history encoder for ID is a Simple U-Net. The down path alternates between residual blocks and 2 × 2 spatial downsampling, collecting skip features, while the up path mirrors this structure with 2 × 2 upsampling and additional residual blocks that fuse the corresponding skips. A final group normalization, SiLU activation, and 3 × 3 convolution map the output to 16 channels, yielding the spatial summary state h_t ∈ R^{16×28×28}.

For posterior inference over the underlying digit, we reuse the same Simple U-Net architecture with three modifications compared to the history encoder. First, the stage channel widths are increased from (8, 16) to (16, 32). Second, the input now concatenates the ground truth digit θ ∈ R^{1×28×28}, the current summary state h_t ∈ R^{16×28×28}, and a scalar diffusion time (log-SNR) embedding. The time input is mapped through a 4-dimensional sinusoidal embedding followed by a two-layer MLP, and the resulting scalar feature is broadcast to a 1 × 28 × 28 map. Concatenating θ, h_t, and this time channel yields an 18-channel input on which the U-Net operates, with the same residual blocks, skip connections, and single bottleneck self-attention as in the history network. Finally, the last projection maps back to a single channel, so that the output again lies in R^{1×28×28}. In the diffusion-based ID variant, this conditional U-Net parameterizes a v-prediction model with a cosine noise schedule, whereas in the flow-matching variant, the same backbone is used to parameterize a conditional vector field v. The corresponding objectives and integration schemes are described in Section A.1 and Section A.2, respectively.
