Evaluating Prediction Uncertainty Estimates from BatchEnsemble
Deep learning models struggle with uncertainty estimation: many approaches are either computationally infeasible or underestimate uncertainty. We investigate BatchEnsemble as a general and scalable method for uncertainty estimation across both tabular and time-series tasks. To extend BatchEnsemble to sequential modeling, we introduce GRUBE, a novel BatchEnsemble GRU cell. We compare BatchEnsemble with Monte Carlo dropout and deep ensembles. Our results show that BatchEnsemble matches the uncertainty estimation quality of deep ensembles and clearly outperforms Monte Carlo dropout, while GRUBE achieves similar or better performance in both prediction and uncertainty estimation. BatchEnsemble and GRUBE deliver this performance with fewer parameters and reduced training and inference time than traditional ensembles.
💡 Research Summary
Deep learning models excel at predictive performance across many domains, yet they typically provide only point estimates and lack reliable uncertainty quantification. Accurate uncertainty estimates are crucial for safety‑critical applications because they enable detection of out‑of‑distribution (OOD) inputs and inform risk‑aware decision making. Uncertainty can be decomposed into aleatoric (data noise) and epistemic (model) components; the latter is especially important for OOD detection. Traditional Bayesian neural networks offer principled uncertainty but are computationally demanding. Monte‑Carlo (MC) dropout provides a lightweight Bayesian approximation but often suffers from over‑confidence. Deep ensembles, which train multiple independent models and average their predictions, achieve state‑of‑the‑art calibration and robustness but require a full copy of the network for each ensemble member, leading to prohibitive memory and compute costs.
BatchEnsemble addresses this trade‑off by sharing a single weight matrix per layer while introducing small, per‑member adapter vectors (input scaling r_k, output scaling s_k, and optional bias b_k). Each member’s forward pass consists of element‑wise scaling of the input, a shared linear transformation, and a second scaling of the output. Because the adapters add only O(p + q) parameters per member (where p and q are the input and output dimensions), each additional ensemble member costs p + q extra parameters rather than a full O(pq) weight matrix, as in a deep ensemble. Moreover, the method is fully vectorized: inputs are duplicated K times, adapters are stacked into matrices, and a single forward/backward pass computes losses for all members simultaneously. Training minimizes the average negative log‑likelihood across members, while inference averages the predictive distributions (a mixture of Gaussians for regression, class probabilities for classification).
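The adapter mechanics described above can be sketched in NumPy. This is an illustrative toy, not the paper's implementation: the variable names (W, R, S, b) and the tiling layout are assumptions, and training machinery is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
K, p, q, B = 4, 3, 2, 5           # ensemble size, input/output dims, batch size

W = rng.normal(size=(p, q))       # shared weight matrix (one copy for all members)
R = rng.normal(size=(K, p))       # per-member input scaling vectors r_k
S = rng.normal(size=(K, q))       # per-member output scaling vectors s_k
b = rng.normal(size=(K, q))       # per-member biases b_k

def member_forward(x, k):
    """One ensemble member: scale input by r_k, shared linear map, scale by s_k."""
    return ((x * R[k]) @ W) * S[k] + b[k]

def vectorized_forward(x):
    """All K members in a single pass: tile the batch, broadcast the adapters."""
    xt = np.tile(x, (K, 1))                 # (K*B, p): K stacked copies of the batch
    r = np.repeat(R, B, axis=0)             # align adapters with the tiled batch
    s = np.repeat(S, B, axis=0)
    bb = np.repeat(b, B, axis=0)
    return ((xt * r) @ W) * s + bb          # (K*B, q); member k owns rows k*B:(k+1)*B

x = rng.normal(size=(B, p))
out = vectorized_forward(x)
```

The vectorized pass produces exactly the same outputs as looping over members, which is what makes the single forward/backward pass over all members possible.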
To bring BatchEnsemble to sequential data, the authors propose GRUBE, a BatchEnsemble‑augmented GRU cell. All three linear transformations inside the GRU (reset, update, and candidate hidden state) are replaced with shared‑weight ensembles modulated by per‑member adapters. This yields K parallel hidden states at each time step, preserving epistemic diversity without replicating the full recurrent network. For multi‑step forecasting, the model uses ancestral sampling: each member draws stochastic trajectories, and the final predictive mean and variance are obtained by aggregating over members and sampled paths.
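A minimal NumPy sketch of the idea behind GRUBE follows, under stated assumptions: the gate equations are the standard GRU ones, each of the six linear maps (three input-side, three recurrent) is wrapped with per-member scaling adapters, and the function names (`be_mat`, `grube_step`) are hypothetical, not from the paper.

```python
import numpy as np

rng = np.random.default_rng(1)
K, d_in, d_h = 3, 4, 5            # ensemble size, input dim, hidden dim

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Shared weights for the three GRU gates (input-side and recurrent parts).
Wz, Wr, Wh = (rng.normal(size=(d_in, d_h)) for _ in range(3))
Uz, Ur, Uh = (rng.normal(size=(d_h, d_h)) for _ in range(3))
# Per-member adapters: input/output scaling vectors for every shared matrix.
rx = rng.normal(size=(3, K, d_in)); sx = rng.normal(size=(3, K, d_h))
rh = rng.normal(size=(3, K, d_h));  sh = rng.normal(size=(3, K, d_h))

def be_mat(x, W, r, s, k):
    """BatchEnsemble linear map: scale input by r_k, shared W, scale output by s_k."""
    return ((x * r[k]) @ W) * s[k]

def grube_step(x, h, k):
    """One GRU time step for ensemble member k, adapters on every linear map."""
    z = sigmoid(be_mat(x, Wz, rx[0], sx[0], k) + be_mat(h, Uz, rh[0], sh[0], k))
    r = sigmoid(be_mat(x, Wr, rx[1], sx[1], k) + be_mat(h, Ur, rh[1], sh[1], k))
    h_tilde = np.tanh(be_mat(x, Wh, rx[2], sx[2], k)
                      + be_mat(r * h, Uh, rh[2], sh[2], k))
    return (1 - z) * h + z * h_tilde

# Each member carries its own hidden state across the sequence,
# giving K parallel trajectories from one shared recurrent network.
x_seq = rng.normal(size=(6, d_in))
h = np.zeros((K, d_h))
for x_t in x_seq:
    h = np.stack([grube_step(x_t, h[k], k) for k in range(K)])
```

The loop over members is written out for clarity; as in the feed-forward case it can be vectorized by tiling the batch dimension.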
The experimental suite covers tabular regression (California Housing, Diabetes), tabular classification, and multivariate time‑series forecasting. For each task the authors compare five configurations: a single deterministic network, MC‑dropout, a deep ensemble (K = 5), BatchEnsemble (K = 5), and GRUBE (for the time‑series case). Evaluation metrics include predictive accuracy (RMSE or classification accuracy), proper scoring rules (negative log‑likelihood, Brier score), calibration measures (RMSCE, Expected Calibration Error, miscalibration area), selective prediction curves (coverage vs. performance), and an explicit aleatoric‑epistemic decomposition. Distribution‑shift robustness is assessed by constructing OOD test sets that exclude extreme quantiles of selected features.
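Of the calibration measures listed above, Expected Calibration Error has a simple standard formulation: bin predictions by confidence and average the gap between confidence and accuracy per bin. The binning choices below (ten equal-width bins) are a common convention, not necessarily the paper's exact setup.

```python
import numpy as np

def expected_calibration_error(conf, correct, n_bins=10):
    """ECE: |accuracy - mean confidence| per equal-width confidence bin,
    weighted by the fraction of samples falling in that bin."""
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (conf > lo) & (conf <= hi)
        if mask.any():
            ece += mask.mean() * abs(correct[mask].mean() - conf[mask].mean())
    return ece

# Perfectly calibrated toy case: confidence 0.75, and 3 of 4 predictions correct.
conf = np.array([0.75, 0.75, 0.75, 0.75])
correct = np.array([1.0, 1.0, 1.0, 0.0])
print(expected_calibration_error(conf, correct))  # → 0.0
```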
Results show that BatchEnsemble consistently matches deep ensembles in predictive performance and uncertainty quality while using dramatically fewer parameters (5–10× reduction) and only modestly increased training/inference time (≈2–3×). On in‑distribution regression, BatchEnsemble ties or slightly outperforms deep ensembles in RMSE and achieves comparable NLL; MC‑dropout lags behind. Calibration is generally better for BatchEnsemble and deep ensembles than for MC‑dropout, though the single model sometimes attains the lowest calibration error on specific datasets where the ensemble methods over‑estimate uncertainty. Under distribution shift, BatchEnsemble and deep ensembles maintain stable RMSE and NLL, whereas MC‑dropout and the single model degrade sharply. In the time‑series domain, GRUBE attains equal or lower NLL and RMSE relative to deep ensembles and provides well‑calibrated predictive intervals. Selective prediction experiments reveal that BatchEnsemble’s uncertainty scores effectively rank difficult samples, preserving high accuracy when low‑uncertainty subsets are selected.
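The selective prediction protocol mentioned above can be sketched as follows: rank test samples by uncertainty and measure accuracy on the most-confident fraction at each coverage level. The function name and coverage grid are illustrative assumptions.

```python
import numpy as np

def selective_curve(uncertainty, correct, coverages=(0.5, 0.8, 1.0)):
    """Accuracy on the lowest-uncertainty fraction of the test set,
    for each requested coverage level."""
    order = np.argsort(uncertainty)            # most confident samples first
    n = len(correct)
    return {c: float(correct[order[: max(1, int(round(c * n)))]].mean())
            for c in coverages}

# Toy case: the one high-uncertainty sample is also the one mistake,
# so accuracy improves as coverage shrinks.
unc = np.array([0.1, 0.2, 0.3, 0.9])
hit = np.array([1.0, 1.0, 1.0, 0.0])
print(selective_curve(unc, hit))  # → {0.5: 1.0, 0.8: 1.0, 1.0: 0.75}
```

A useful uncertainty score produces a curve whose accuracy rises monotonically as coverage decreases, which is the behavior reported for BatchEnsemble.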
Ablation analyses indicate that the epistemic component captured by the adapters is the primary driver of improved calibration; aleatoric uncertainty is similarly estimated across all methods. The authors note that the choice of adapter dimensionality influences the balance between model capacity and ensemble diversity, and that very large ensemble sizes could re‑introduce memory bottlenecks despite the shared‑weight design.
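The aleatoric–epistemic decomposition used in such ablations typically follows the law of total variance for a K-member Gaussian mixture: aleatoric uncertainty is the mean of the members' predicted variances, epistemic uncertainty the variance of their predicted means. A minimal sketch (function name is hypothetical):

```python
import numpy as np

def decompose_uncertainty(means, variances):
    """Law of total variance for a K-member Gaussian mixture:
    aleatoric = E_k[sigma_k^2] (average data noise),
    epistemic = Var_k[mu_k]   (disagreement between members)."""
    aleatoric = variances.mean(axis=0)
    epistemic = means.var(axis=0)
    return aleatoric, epistemic

mu = np.array([[1.0], [3.0]])     # two members' predicted means
var = np.array([[0.5], [0.5]])    # two members' predicted variances
alea, epi = decompose_uncertainty(mu, var)
print(alea, epi)  # → [0.5] [1.]
```

The sum of the two terms is the total predictive variance of the mixture, so the decomposition is exact rather than an approximation.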
In conclusion, the paper demonstrates that BatchEnsemble offers a practical, parameter‑efficient alternative to deep ensembles for uncertainty quantification across tabular and sequential tasks. The introduction of GRUBE extends these benefits to recurrent architectures, enabling accurate and calibrated multi‑step forecasts with substantially lower resource demands. Future work is suggested on optimizing adapter architectures, exploring asynchronous member updates, and scaling the approach to larger ensembles and transformer‑based time‑series models.