Simulation-based Inference with the Python Package sbijax
Neural simulation-based inference (SBI) describes an emerging family of methods for Bayesian inference with intractable likelihood functions that use neural networks as surrogate models. Here we introduce sbijax, a Python package that implements a wide variety of state-of-the-art methods in neural simulation-based inference using a user-friendly programming interface. sbijax offers high-level functionality to quickly construct SBI estimators, and to compute and visualize posterior distributions, with only a few lines of code. In addition, the package provides functionality for conventional approximate Bayesian computation, for computing model diagnostics, and for automatically estimating summary statistics. By virtue of being written entirely in JAX, sbijax is computationally very efficient, allowing rapid training of neural networks and automatic parallel execution of code on both CPUs and GPUs.
💡 Research Summary
The paper introduces sbijax, a new Python library for simulation‑based inference (SBI) that is built entirely on JAX. The authors position sbijax as a comprehensive, high‑performance alternative to existing packages such as the PyTorch‑based “sbi”. By leveraging JAX’s NumPy‑compatible syntax, just‑in‑time compilation, automatic differentiation, and seamless CPU/GPU/TPU parallelisation, sbijax delivers markedly faster training of neural density estimators while preserving a user‑friendly API.
The manuscript first reviews the theoretical background of SBI and Approximate Bayesian Computation (ABC). In SBI the likelihood π(y|θ) is intractable, but synthetic data pairs (θ, y) can be generated by drawing θ from the prior and running the simulator. The goal is to approximate the posterior π(θ|y_obs) using one of the following four families of methods (three neural, plus classic ABC):
- Neural Likelihood Estimation (NLE) – learns a surrogate for the likelihood π(y|θ) using conditional normalising flows, mixture density networks, or other conditional density estimators. After training, the unnormalised posterior is obtained as q̂(y_obs, θ)π(θ) and sampled with MCMC or variational inference.
- Neural Posterior Estimation (NPE) – directly models the posterior π(θ|y). The paper highlights Flow‑Matching Posterior Estimation (FMPE), which trains a continuous normalising flow (CNF) defined by an ODE‑based vector field v_t(θ; y). FMPE avoids the need for bijectivity in the vector field, uses a least‑squares loss derived from an optimal transport formulation, and enables direct sampling from the learned posterior without an MCMC step, though prior constraints may need additional bijections.
- Neural Ratio Estimation (NRE) – approximates the likelihood‑to‑evidence ratio r(y, θ) = π(y|θ)/π(y) via a binary or multi‑class classifier. The authors discuss contrastive neural ratio estimation, where a classifier distinguishes which of C candidate parameter sets generated a given observation. This multi‑class formulation reduces variance in the ratio estimate and improves posterior accuracy.
- Approximate Bayesian Computation (ABC) – traditional ABC methods are also wrapped, including automatic summary‑statistic learning, SMC‑ABC, and annealing‑based ABC. sbijax thus provides a full toolbox covering both neural and classic simulation‑based inference.
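The NLE recipe above can be made concrete with a toy sketch. Here a closed-form Gaussian stands in for the trained conditional density estimator q̂(y|θ) (in practice a normalising flow or mixture density network), and the unnormalised posterior q̂(y_obs, θ)π(θ) is sampled with a random-walk Metropolis step. All names, the model, and the step size are illustrative choices, not sbijax's API.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a trained conditional density estimator q_hat(y | theta):
# a Gaussian with mean theta, so its log-density is available in closed form.
# In NLE this surrogate would be learned from simulator draws (theta, y).
def surrogate_loglik(y, theta):
    return jnp.sum(jax.scipy.stats.norm.logpdf(y, loc=theta, scale=1.0))

def log_prior(theta):
    # standard-normal prior, an illustrative choice
    return jnp.sum(jax.scipy.stats.norm.logpdf(theta))

def unnorm_log_posterior(theta, y_obs):
    # the MCMC target: log q_hat(y_obs | theta) + log pi(theta)
    return surrogate_loglik(y_obs, theta) + log_prior(theta)

def rw_metropolis_step(key, theta, y_obs, step=0.5):
    # one random-walk Metropolis step on the unnormalised posterior
    key_prop, key_acc = jax.random.split(key)
    prop = theta + step * jax.random.normal(key_prop, theta.shape)
    log_alpha = unnorm_log_posterior(prop, y_obs) - unnorm_log_posterior(theta, y_obs)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_alpha
    return jnp.where(accept, prop, theta)

key = jax.random.PRNGKey(0)
y_obs = jnp.array([1.0, 1.2])
theta = jnp.zeros(2)
for _ in range(200):
    key, sub = jax.random.split(key)
    theta = rw_metropolis_step(sub, theta, y_obs)
```

After a burn-in, the chain states are (approximate) draws from the surrogate posterior; sbijax delegates this sampling stage to proper MCMC or variational-inference backends.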
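The FMPE least-squares objective can be sketched in plain JAX. Along the straight (optimal-transport) path θ_t = (1 − t)θ₀ + tθ₁ between a base draw θ₀ and a target draw θ₁, the target velocity is the constant θ₁ − θ₀, and the network v_t(θ; y) is regressed onto it. The linear `vector_field` below is a hypothetical stand-in for the learned CNF vector field; the dimensions and parameterisation are illustrative.

```python
import jax
import jax.numpy as jnp

def vector_field(params, t, theta_t, y):
    # toy linear model v_t(theta; y) = W [theta_t, y, t] + b,
    # standing in for a neural network
    inp = jnp.concatenate([theta_t, y, jnp.atleast_1d(t)])
    return params["W"] @ inp + params["b"]

def fm_loss(params, key, theta1, y):
    # one Monte-Carlo term of the flow-matching least-squares loss
    k_t, k_0 = jax.random.split(key)
    t = jax.random.uniform(k_t)                      # time point t ~ U(0, 1)
    theta0 = jax.random.normal(k_0, theta1.shape)    # base sample
    theta_t = (1.0 - t) * theta0 + t * theta1        # point on the straight path
    target = theta1 - theta0                         # OT target velocity
    pred = vector_field(params, t, theta_t, y)
    return jnp.mean((pred - target) ** 2)

d, dy = 2, 3
key = jax.random.PRNGKey(1)
params = {"W": jnp.zeros((d, d + dy + 1)), "b": jnp.zeros(d)}
loss, grads = jax.value_and_grad(fm_loss)(params, key, jnp.ones(d), jnp.zeros(dy))
```

Because no Jacobian determinant appears in the loss, the vector field needs no bijective architecture; sampling later integrates the learned ODE from base to posterior.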
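The contrastive NRE classifier can likewise be sketched: for an observation y and C candidate parameter sets, exactly one of which generated y, a scoring function outputs one logit per candidate, and a softmax cross-entropy on the true index trains the score to approximate log r(y, θ) up to a constant. The bilinear `logit_fn` is a toy stand-in for a neural classifier; all shapes are illustrative.

```python
import jax
import jax.numpy as jnp

def logit_fn(params, y, theta):
    # toy bilinear score f(y, theta) standing in for a neural network
    return y @ params @ theta

def contrastive_nre_loss(params, y, thetas, true_idx):
    # logits over the C candidate parameter sets for this observation
    logits = jax.vmap(lambda th: logit_fn(params, y, th))(thetas)
    # multi-class cross-entropy on the index of the generating candidate
    return -jax.nn.log_softmax(logits)[true_idx]

d = 2
key = jax.random.PRNGKey(2)
thetas = jax.random.normal(key, (4, d))   # C = 4 candidate parameter sets
y = thetas[0] + 0.1                       # pretend candidate 0 generated y
params = jnp.eye(d)
loss = contrastive_nre_loss(params, y, thetas, 0)
```

Averaging this loss over many (y, candidate-set) draws trains the classifier; the C-way contrast is what reduces the variance of the ratio estimate relative to the binary formulation.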
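For contrast with the neural methods, the simplest member of the ABC family, rejection ABC, can be sketched in a few lines: draw θ from the prior, simulate, and keep θ whenever a summary of the simulated data lands within a tolerance ε of the observed summary. The Gaussian-mean model, the sample-mean summary, and ε = 0.2 are all illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

def simulator(theta, n=20):
    # toy simulator: n noisy observations around theta
    return theta + rng.normal(size=n)

y_obs = simulator(1.5)
s_obs = y_obs.mean()          # observed summary statistic

accepted = []
for _ in range(5000):
    theta = rng.normal(0.0, 2.0)          # prior draw, theta ~ N(0, 4)
    s_sim = simulator(theta).mean()       # summary of simulated data
    if abs(s_sim - s_obs) < 0.2:          # epsilon-tolerance acceptance
        accepted.append(theta)
posterior_samples = np.array(accepted)
```

The acceptance rate collapses as ε shrinks or the summary dimension grows, which is exactly why sbijax also wraps SMC-ABC, annealing-based ABC, and automatic summary-statistic learning.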
All four families are exposed through a unified high‑level API (sbijax.infer). Users supply a simulator function and select a method; sbijax automatically handles data generation, model training, posterior construction, and visualisation. Diagnostic utilities (sbijax.diagnostics) support posterior predictive checks, simulation‑based calibration, effective sample size calculations, and more.
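Among the diagnostics mentioned, simulation-based calibration (SBC) has a core rank statistic that can be sketched independently of any library. Using a conjugate Gaussian model where the exact posterior is known, the rank of each prior draw among its posterior samples should be uniformly distributed if the posterior samples are correct; this toy model and sample sizes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(3)

def sbc_rank(L=99):
    # SBC for a conjugate pair: prior theta ~ N(0, 1), likelihood y ~ N(theta, 1)
    theta = rng.normal()
    y = theta + rng.normal()
    # exact posterior for this pair is N(y / 2, 1 / 2)
    post = rng.normal(y / 2, np.sqrt(0.5), size=L)
    # rank of the prior draw among the L posterior samples, in {0, ..., L}
    return int((post < theta).sum())

ranks = np.array([sbc_rank() for _ in range(2000)])
# with correct posterior samples the ranks are uniform, so the
# mean rank should sit near L / 2
```

A histogram of `ranks` that deviates from uniformity (e.g. U-shaped or skewed) flags an over- or under-dispersed posterior approximation, which is the check such diagnostics automate.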
The authors benchmark sbijax on two case studies. The first is a nonlinear time‑series model of solar cycles, where NPE and NRE achieve 2–3× higher log‑posterior accuracy and lower KL divergence than ABC under a limited simulation budget (N≈10⁴). The second applies sbijax to real EEG data using a Bayesian neural model. Automatic summary‑statistic learning proves crucial, and sbijax’s MCMC sampler converges 1.8× faster than the equivalent PyTorch‑based workflow. Across experiments, JAX’s XLA optimisation yields substantial speed‑ups compared with PyTorch implementations.
Limitations are acknowledged. The JAX ecosystem currently offers fewer third‑party extensions than PyTorch, and its functional programming style may pose a learning curve for newcomers. Training ODE‑based flows can be memory‑intensive, and ensuring that NPE samples respect prior constraints sometimes requires extra bijective transforms. Future work includes adding more flow architectures (e.g., spline flows), automated hyper‑parameter tuning, and deeper TPU integration.
In summary, sbijax consolidates state‑of‑the‑art neural SBI algorithms, ABC tools, model diagnostics, and visualisation into a single, JAX‑optimised package. It enables researchers and practitioners to perform Bayesian inference for complex simulators with minimal code, high computational efficiency, and robust diagnostic support, representing a significant step forward for the simulation‑based inference community.