Out-of-Support Generalisation via Weight-Space Sequence Modelling
As breakthroughs in deep learning transform key industries, models are increasingly required to extrapolate to datapoints outside the range of the training set, a challenge we term out-of-support (OoS) generalisation. However, neural networks frequently fail catastrophically on OoS samples, yielding unrealistic but overconfident predictions. We address this challenge by reformulating the OoS generalisation problem as a sequence modelling task in weight space, wherein the training set is partitioned into concentric shells corresponding to discrete sequential steps. Our WeightCaster framework yields plausible, interpretable, and uncertainty-aware predictions without requiring explicit inductive biases, all the while maintaining high computational efficiency. Empirical validation on a synthetic cosine dataset and real-world air quality sensor readings demonstrates performance competitive with or superior to the state of the art. By enhancing reliability beyond in-distribution scenarios, these results hold significant implications for the wider adoption of artificial intelligence in safety-critical applications.
💡 Research Summary
The paper tackles the problem of out‑of‑support (OoS) generalisation, where test inputs lie in regions of the input space that contain no training data. Traditional out‑of‑distribution (OoD) methods rely on strong inductive biases, distributionally robust optimisation, meta‑learning, or non‑parametric approaches such as Gaussian Processes (GPs). These solutions either require prior knowledge that is often unavailable or become computationally prohibitive for large datasets.
WeightCaster reframes OoS generalisation as a weight‑space sequence modelling problem. First, a distance metric is defined on the input space and an anchor point is chosen. The space is then partitioned into T concentric shells (called “rings”) of equal width δ. Each training sample is assigned to a ring based on its distance to the anchor. The outermost ring that still contains training data is indexed as T_tr; rings with index > T_tr contain only test points.
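The ring assignment above can be sketched in a few lines of NumPy. This is an illustrative reading of the construction, not the paper's code; the choice of anchor and δ is left open by the authors.

```python
import numpy as np

def ring_index(x, anchor, delta):
    """Assign each sample to a concentric ring around the anchor.

    Rings are indexed from 1; ring t covers distances in [(t-1)*delta, t*delta).
    The Euclidean metric here is an assumption -- the paper leaves the
    distance metric and anchor choice as design decisions.
    """
    dist = np.linalg.norm(np.atleast_2d(x) - anchor, axis=1)
    return (dist // delta).astype(int) + 1
```

With a one-dimensional input and anchor at 0, a sample at distance 0.55 with δ = 0.5 falls in ring 2, while samples at 0.0 and 0.25 fall in ring 1.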
Instead of learning a single parameter vector θ for the whole domain, WeightCaster learns a distinct parameter set θ_t for each ring t. The collection {θ_t} is treated as a time‑ordered sequence and modelled by a higher‑level neural functional G_ϕ, which predicts the next set of weights from the previous one (θ_{t+1}=G_ϕ(θ_t)). G_ϕ can be any sequence‑to‑sequence architecture (Transformer, LSTM, SSM, etc.). Training minimises the sum of per‑ring losses ℓ(f_{θ_t}(x),y) while jointly updating the initial weights θ_1 and the sequence model parameters ϕ. No supervision is required beyond the rings that contain training data, forcing the model to learn dynamics that can be extrapolated to unseen rings.
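A minimal sketch of the rollout and training objective, assuming (as in the paper's cosine experiment) a two-parameter linear model f_θ and a linear 2×2 transition matrix for G_ϕ; any sequence model could stand in for `G`, and the data layout `ring_data` is a hypothetical structure introduced here for illustration.

```python
import numpy as np

def rollout_weights(theta1, G, T):
    """Unroll theta_{t+1} = G @ theta_t for T rings.

    A linear G matches the paper's cosine setup; a Transformer, LSTM,
    or SSM could replace it as the sequence model G_phi.
    """
    thetas = [theta1]
    for _ in range(T - 1):
        thetas.append(G @ thetas[-1])
    return np.stack(thetas)

def ring_loss(thetas, ring_data):
    """Sum of per-ring mean squared errors.

    ring_data maps ring index t -> (X, y); f_theta is the linear model
    y = slope * x + intercept, matching the paper's cosine experiment.
    Only rings with training data contribute, so unseen rings are
    reached purely by rolling the learned dynamics forward.
    """
    total = 0.0
    for t, (X, y) in ring_data.items():
        slope, intercept = thetas[t - 1]
        total += np.mean((slope * X + intercept - y) ** 2)
    return total
```

Training would backpropagate this loss through the rollout to update both θ_1 and the entries of G jointly (the paper uses JAX with gradient descent for exactly this).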
To obtain uncertainty estimates, the framework is made stochastic: G_ϕ outputs a mean μ_t and a diagonal standard‑deviation σ_t for the weight distribution at each step. Using the re‑parameterisation trick, θ_t = μ_t + σ_t ⊙ ε (ε∼N(0,I)) is sampled. A first‑order Taylor expansion of the downstream network f_θ(x) around μ_t yields an approximate predictive distribution p(y|x) ≈ N(μ_y, Σ_y) with μ_y = f_{μ_t}(x) and Σ_y = J diag(σ_t²) Jᵀ + σ_noise² I, where J is the Jacobian of f with respect to the weights. This linearisation provides a closed‑form covariance that captures both epistemic and aleatoric uncertainty.
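The linearised predictive distribution can be written out concretely for the two-parameter linear model, where the Jacobian with respect to the weights is exact rather than a first-order approximation. This is a sketch under that assumption, not the paper's implementation.

```python
import numpy as np

def predictive_gaussian(mu_t, sigma_t, x, sigma_noise):
    """Linearised predictive distribution p(y|x) = N(mu_y, Sigma_y).

    For f_theta(x) = slope * x + intercept, the Jacobian w.r.t. the
    weights is J = [x, 1], so
        Sigma_y = J diag(sigma_t^2) J^T + sigma_noise^2 I
    is exact here (for deeper networks it is a first-order approximation).
    """
    slope, intercept = mu_t
    mu_y = slope * x + intercept
    J = np.array([x, 1.0])                     # d f / d (slope, intercept)
    var_y = J @ np.diag(sigma_t ** 2) @ J + sigma_noise ** 2
    return mu_y, var_y
```

Note how the weight uncertainty σ_t enters the predictive variance through the Jacobian: far from the anchor, growing σ_t widens the predictive interval, which is the mechanism behind the uncertainty-aware OoS predictions.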
A KL‑divergence regulariser between the predictive Gaussian and a standard normal prior (scaled by β) is added to the loss to discourage over‑confident predictions in OoS regions, encouraging the model to revert toward the prior when extrapolating far from the training support.
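The KL term between a diagonal Gaussian and a standard normal has a standard closed form, sketched below; the summation over dimensions and the placement of β in the total loss follow the usual variational convention and are assumptions about the paper's exact bookkeeping.

```python
import numpy as np

def kl_to_standard_normal(mu, sigma):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over dims.

    Added to the loss scaled by beta; it penalises confident deviations
    from the prior, nudging predictions back toward it far off-support.
    """
    return 0.5 * np.sum(sigma ** 2 + mu ** 2 - 1.0 - 2.0 * np.log(sigma))
```

The term vanishes when the predictive Gaussian equals the prior (μ = 0, σ = 1) and grows as the model becomes more confident or more displaced from it.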
Experiments were conducted on two regression benchmarks. The synthetic cosine task uses y = cos(10x) + 0.5x + ε, with disjoint training and test intervals. The authors set T = 600 rings, T_tr = 300, and β = 1e‑2, using a linear regression model (slope and intercept) as f_θ. The sequence model G_ϕ is a simple 2×2 linear transformation matrix learned via gradient descent. The real‑world Air Quality task derives from the UCI dataset, predicting NOx from O₃ sensor readings after normalising and splitting the data at a threshold that creates a support shift. Here T = 80, T_tr = 40, and β = 5e‑2. Both tasks use the AdaBelief optimiser in JAX.
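A hedged sketch of the cosine benchmark's data generation: the target function and the disjoint train/test split come from the paper, but the interval endpoints, sample count, and noise level here are illustrative assumptions the summary does not specify.

```python
import numpy as np

def make_cosine_data(n=500, train_range=(-1.0, 0.0), test_range=(0.0, 1.0),
                     noise=0.1, seed=0):
    """Synthetic cosine benchmark: y = cos(10x) + 0.5x + eps, with the
    training and test inputs drawn from disjoint intervals so that the
    test set lies entirely outside the training support.

    n, the interval endpoints, and the noise scale are illustrative
    assumptions, not values stated in the paper.
    """
    rng = np.random.default_rng(seed)

    def sample(lo, hi):
        x = rng.uniform(lo, hi, n)
        y = np.cos(10 * x) + 0.5 * x + rng.normal(0.0, noise, n)
        return x, y

    return sample(*train_range), sample(*test_range)
```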
Baselines include a standard multilayer perceptron (MLP), a Gaussian Process, and the Engression method (a recent distribution‑aware regression technique). Results show that WeightCaster dramatically reduces out‑of‑support mean‑squared error (MSE) compared with the MLP (e.g., 0.35 vs 2.37 on the cosine test set) and outperforms the GP (0.35 vs 1.40) while using far fewer parameters. On the Air Quality data, WeightCaster achieves test MSE = 0.138, comparable to Engression (0.160) and far better than the GP (0.705). In‑distribution errors are slightly higher than the MLP's but remain acceptable, highlighting the trade‑off between in‑distribution accuracy and robust extrapolation.
The key advantages of WeightCaster are: (1) no need for explicit domain knowledge or handcrafted inductive biases; (2) computational efficiency comparable to standard parametric models, since the sequence model operates on a low‑dimensional weight vector rather than the full dataset; (3) principled uncertainty quantification via linearisation and KL regularisation; and (4) interpretability, as the learned weight trajectory can be visualised and analysed. Limitations include sensitivity to the choice of anchor point and ring width δ, potential inefficiency in high‑dimensional input spaces where constructing concentric shells may be costly, and the reliance on a first‑order Taylor approximation which may be inaccurate for highly non‑linear networks.
Future work could explore adaptive ring construction (e.g., data‑driven radii), richer stochastic dynamics (e.g., normalising flows in weight space), and extensions to classification, high‑dimensional vision, and time‑series domains. Overall, the paper presents a novel and promising direction for reliable out‑of‑support generalisation, combining weight‑space dynamics with uncertainty‑aware forecasting.