Learning Graphical Model Parameters with Approximate Marginal Inference
Likelihood-based learning of graphical models faces challenges of computational complexity and robustness to model mis-specification. This paper studies methods that fit parameters directly to maximize a measure of the accuracy of predicted marginals, taking into account both model and inference approximations at training time. Experiments on imaging problems suggest marginalization-based learning performs better than likelihood-based approximations on difficult problems where the model being fit is approximate in nature.
💡 Research Summary
The paper addresses a fundamental difficulty in learning parameters of high-dimensional graphical models such as Markov random fields (MRFs) and conditional random fields (CRFs). Traditional maximum-likelihood (ML) learning requires exact computation of the log-partition function and the model marginals, which is intractable for graphs with large tree-width. Approximate ML approaches (e.g., pseudo-likelihood, contrastive divergence) replace the exact marginals with those obtained from an approximate inference algorithm, but they still suffer from a mismatch: the inference algorithm used at test time is not taken into account during training, and the learning objective does not directly reflect the quality of the predictions that will actually be used.
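To make the difficulty concrete, for a CRF in exponential-family form the ML gradient requires exactly these marginal quantities (a standard identity, restated here for context):

```latex
\ell(\theta) = \theta^{\top} f(x, y) - A(\theta, x),
\qquad
A(\theta, x) = \log \sum_{y'} \exp\!\big(\theta^{\top} f(x, y')\big),
\qquad
\nabla_{\theta}\, \ell(\theta) = f(x, y) - \mathbb{E}_{p(y' \mid x;\, \theta)}\big[ f(x, y') \big].
```

The expectation term is precisely the vector of model marginals, and computing it (or the log-partition function A) exactly is what becomes intractable on loopy, high-tree-width graphs.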
The authors propose a "marginalization-based" learning framework that directly optimizes a loss defined on the approximate marginals produced by the inference algorithm that will be employed at test time. In other words, instead of maximizing a surrogate likelihood, the training objective measures how close the inferred marginal distributions are to the ground-truth labels (e.g., using cross-entropy or squared error). This approach has two major benefits: (1) it incorporates the inference approximation into the learning process, allowing the parameters to compensate for systematic inference errors; (2) it is more robust when the model is misspecified, because the loss directly reflects the quality of the final predictions rather than the fidelity of the underlying probabilistic model.
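For instance, a cross-entropy loss on the predicted per-variable marginals (one of the loss choices mentioned above) can be written in a few lines; the function name and array conventions here are an illustrative sketch, not the paper's code:

```python
import numpy as np

def marginal_loss(marginals, labels, eps=1e-12):
    """Cross-entropy between predicted marginals and ground-truth labels.

    marginals: (n, k) array of per-variable class probabilities.
    labels:    (n,) array of true class indices.
    """
    n = len(labels)
    # Probability the model assigns to each variable's true label.
    p_true = marginals[np.arange(n), labels]
    return -np.mean(np.log(p_true + eps))

# Toy usage: 3 variables, 2 classes.
marginals = np.array([[0.9, 0.1],
                      [0.2, 0.8],
                      [0.6, 0.4]])
labels = np.array([0, 1, 0])
loss = marginal_loss(marginals, labels)
```

Because the loss is a function of the marginals alone, it can be evaluated on the output of any inference routine, exact or approximate.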
To make this framework practical, the paper develops two key technical contributions for computing gradients of the marginal-based loss with respect to the model parameters:
- Perturbation-based gradient estimation: The authors observe that the marginal vector is the gradient of an (approximate) log-partition function. By running the approximate inference algorithm twice, once with the current parameters θ and once with a slightly perturbed parameter vector θ + Δθ, and measuring the change in the inferred marginals, they obtain a finite-difference estimate of the loss gradient. This "perturbation" method is simple to implement, works with any black-box inference routine, and can be integrated with automatic-differentiation frameworks.
- Truncated fitting: Conventional variational inference iterates until convergence before computing the loss, which is computationally expensive during training. The authors propose stopping the inference after a fixed number of updates (e.g., a few mean-field sweeps or message-passing iterations) and using the resulting intermediate marginals in the loss. This truncated approach dramatically reduces training time while still providing useful gradient information, because the loss is evaluated on the same approximate marginals that will be used at test time.
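The two-run perturbation idea can be sketched as follows. The interface names (`infer`, `dloss_dmu`) are hypothetical, and this is a minimal illustration of the finite-difference principle rather than the paper's implementation. The key fact it relies on is stated above: because the marginals are the gradient of an approximate log-partition function, their Jacobian with respect to θ is symmetric, so perturbing θ in the direction of the loss gradient with respect to the marginals recovers the parameter gradient:

```python
import numpy as np

def perturbation_gradient(infer, dloss_dmu, theta, r=1e-6):
    """Estimate d loss / d theta using two runs of black-box inference.

    infer:      maps parameters theta to approximate marginals mu(theta).
    dloss_dmu:  gradient of the marginal-based loss w.r.t. the marginals.

    Since mu(theta) is the gradient of an approximate log-partition
    function, the Jacobian J = d mu / d theta is symmetric, giving
        d loss / d theta = J^T g = J g  ~=  (mu(theta + r*g) - mu(theta)) / r
    where g = dloss_dmu(mu(theta)).
    """
    mu0 = infer(theta)              # first inference run
    g = dloss_dmu(mu0)              # perturbation direction
    mu1 = infer(theta + r * g)      # second inference run
    return (mu1 - mu0) / r
```

On an exactly solvable toy model (a single softmax distribution, whose "marginals" are its class probabilities), this two-run estimate matches the analytic gradient up to the finite-difference error.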
The paper reviews two popular approximate inference algorithms within this framework:
- Mean-field (MF): A fully factorized variational approximation that replaces the marginal polytope with a tractable subset. The MF updates are derived as block-coordinate ascent steps, and the approximate log-partition function is a lower bound on the true value.
- Tree-reweighted belief propagation (TRW): A convex relaxation that expands the feasible set to the local polytope and replaces the true entropy with a tractable upper bound involving singleton entropies and mutual informations weighted by edge appearance probabilities ρ_c. When TRW messages converge, the resulting marginals are the maximizers of an upper-bounded variational objective.
Both algorithms fit naturally into the marginal-based learning scheme: the loss is computed on the MF or TRW marginals, and gradients are obtained via perturbation or truncated fitting.
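As a concrete illustration of truncated fitting with the mean-field back-end, the sketch below runs a fixed number of block-coordinate-ascent sweeps on a small pairwise MRF and returns whatever factorized marginals are available at that point. The function signature and the use of a single shared pairwise potential are simplifying assumptions for illustration, not the authors' code:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def truncated_mean_field(unary, pairwise, edges, n_iters=5):
    """Fixed-iteration mean-field on a pairwise MRF.

    unary:    (n, k) log-potentials per variable.
    pairwise: (k, k) shared pairwise log-potential.
    edges:    list of (i, j) index pairs.
    Returns (n, k) marginals, possibly unconverged (truncated fitting).
    """
    n, k = unary.shape
    q = softmax(unary)                 # initialize from the unaries
    nbrs = [[] for _ in range(n)]
    for i, j in edges:
        nbrs[i].append(j)
        nbrs[j].append(i)
    for _ in range(n_iters):           # stop after n_iters sweeps
        for i in range(n):
            msg = unary[i].copy()
            for j in nbrs[i]:
                msg += pairwise @ q[j]  # expected pairwise log-potential
            q[i] = softmax(msg)         # block-coordinate ascent update
    return q
```

On a 3-node chain with an attractive pairwise potential, a strong class preference at one end propagates to its neighbors even after only a handful of sweeps, which is the behavior truncated fitting exploits.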
Experimental evaluation focuses on image segmentation tasks, a domain where accurate marginal predictions (pixel-wise class probabilities) are crucial. The authors train CRF models on standard datasets using MF and TRW as the inference back-ends. They compare three training regimes:
- Exact or approximate ML (using pseudo-likelihood or contrastive divergence),
- Marginal-based learning with full convergence of the inference algorithm,
- Marginal-based learning with truncated inference (a few updates).
Results show that marginal-based learning consistently outperforms ML-based baselines in terms of pixel accuracy and Intersection-over-Union (IoU). The advantage is most pronounced when the model is deliberately misspecified (e.g., using a simplified graph structure) or when the inference algorithm is far from exact. Moreover, truncated fitting reduces training time by 30–50% without sacrificing accuracy, confirming the practical benefit of the proposed approach.
Significance and limitations: By aligning the learning objective with the inference algorithm that will be used at deployment, the paper introduces a principled way to handle inference approximation and model misspecification. This "inference-aware" learning can be applied to any differentiable approximate inference method, making it broadly relevant for computer vision, natural language processing, and other fields that rely on large graphical models. However, the method assumes that the chosen inference algorithm is stable; if message-passing diverges or mean-field updates become trapped in poor local optima, gradient estimates may become noisy. Additionally, the design of the marginal loss (choice of divergence, weighting) remains problem-specific and may require domain expertise.
In conclusion, the work presents a compelling alternative to likelihood-based training for graphical models. By directly optimizing the quality of approximate marginals, it achieves better predictive performance, greater robustness to model errors, and faster training through truncated inference. Future directions include extending the framework to more sophisticated variational families (e.g., structured mean-field, amortized inference) and automating the selection of loss functions that best reflect downstream task metrics.