Weight Space Correlation Analysis: Quantifying Feature Utilization in Deep Learning Models
Deep learning models in medical imaging are susceptible to shortcut learning, relying on confounding metadata (e.g., scanner model) that is often encoded in image embeddings. The crucial question is whether the model actively utilizes this encoded information for its final prediction. We introduce Weight Space Correlation Analysis, an interpretable methodology that quantifies feature utilization by measuring the alignment between the classification heads of a primary clinical task and auxiliary metadata tasks. We first validate our method by successfully detecting artificially induced shortcut learning. We then apply it to probe the feature utilization of an SA-SonoNet model trained for Spontaneous Preterm Birth (sPTB) prediction. Our analysis confirmed that while the embeddings contain substantial metadata, the sPTB classifier’s weight vectors were highly correlated with clinically relevant factors (e.g., birth weight) but decoupled from clinically irrelevant acquisition factors (e.g., scanner model). Our methodology provides a tool to verify model trustworthiness, demonstrating that, in the absence of induced bias, the clinical model selectively utilizes features related to the genuine clinical signal.
💡 Research Summary
The paper tackles a critical gap in the evaluation of medical imaging deep‑learning models: distinguishing between the mere presence of confounding metadata in the learned latent space and the actual use of that information in the model’s final decision. While prior work has shown that variables such as scanner type, hospital site, or patient demographics can be predicted from image embeddings, none have directly measured whether the classification head of a clinical model relies on those embeddings. To fill this void, the authors propose Weight Space Correlation (WSC) analysis, an interpretable methodology that quantifies feature utilization by comparing the linear decision directions (weight vectors) of a primary clinical task with those of auxiliary metadata tasks.
The method proceeds in three steps. First, each task is represented by the weight matrix of its linear classification head (the rows correspond to class‑specific weight vectors). Second, these weight vectors are projected onto a low‑dimensional data manifold obtained by principal component analysis (PCA) of the backbone’s latent embeddings. The top k components that explain 99 % of variance (with a minimum of 50 components) form a projection matrix P; each weight matrix is transformed to W′ = W Pᵀ, ensuring that only directions supported by actual data variance are compared. Third, cosine similarity is computed between every pair of projected class‑specific weight vectors from the clinical task and a metadata task, yielding a task‑pair correlation matrix. High cosine similarity indicates that the two tasks attend to the same latent directions, suggesting that the clinical model may be exploiting the metadata as a shortcut.
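The three steps above can be sketched in a few lines of NumPy. This is an illustrative reimplementation, not the authors' code; the function name, the toy data, and the default thresholds (beyond the 99% variance / 50-component rule stated in the paper) are assumptions.

```python
import numpy as np

def wsc_matrix(embeddings, W_clinical, W_meta, var_threshold=0.99, min_components=50):
    """Cosine similarities between PCA-projected class weight vectors of two heads."""
    # Step 2a: PCA of the backbone's latent embeddings via SVD of the centered data.
    X = embeddings - embeddings.mean(axis=0)
    _, S, Vt = np.linalg.svd(X, full_matrices=False)
    var_ratio = S**2 / np.sum(S**2)
    # Smallest k whose components explain var_threshold of the variance,
    # but at least min_components (capped at the available dimensionality).
    k = int(np.searchsorted(np.cumsum(var_ratio), var_threshold) + 1)
    k = min(max(k, min_components), Vt.shape[0])
    P = Vt[:k]                                   # (k, d) projection matrix

    # Step 2b: project each head's class-specific weight rows, W' = W P^T,
    # so only directions supported by actual data variance are compared.
    Wc = W_clinical @ P.T                        # (n_clinical_classes, k)
    Wm = W_meta @ P.T                            # (n_meta_classes, k)

    # Step 3: pairwise cosine similarity -> task-pair correlation matrix.
    Wc = Wc / np.linalg.norm(Wc, axis=1, keepdims=True)
    Wm = Wm / np.linalg.norm(Wm, axis=1, keepdims=True)
    return Wc @ Wm.T

# Toy usage with random embeddings and two random linear heads.
rng = np.random.default_rng(0)
emb = rng.normal(size=(500, 128))
C = wsc_matrix(emb, rng.normal(size=(2, 128)), rng.normal(size=(4, 128)))
print(C.shape)  # (2, 4): one similarity per (clinical class, metadata class) pair
```

Rows of the returned matrix correspond to the clinical task's classes and columns to the metadata task's classes; entries near ±1 mean the two heads read out the same latent directions.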
To establish a baseline, the authors first verify that metadata is indeed encoded in the embeddings. They train separate ResNet‑50 classifiers on two private ultrasound datasets (a fetal plane classification set and a cervical‑ultrasound preterm‑birth set) to predict each metadata factor individually. All metadata classifiers achieve strong accuracy, confirming that scanner model, pixel spacing, hospital ID, and other variables are recoverable from the raw images.
Next, they construct a null distribution of WSC values by measuring inter‑task cosine similarities between weight vectors from unrelated classification heads across the fetal dataset. This null distribution is centered near zero with a narrow spread, reflecting the orthogonality expected for unrelated tasks. In contrast, intra‑task similarities (different classes within the same head) show a slight negative bias due to the competitive nature of soft‑max training. This reference enables statistical detection of shortcut behavior: any inter‑task WSC value that deviates significantly from the null indicates shared feature utilization.
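The detection rule described above can be expressed as a simple standardized-deviation test. The sketch below is illustrative (the z-score threshold and the synthetic null are assumptions, not values from the paper):

```python
import numpy as np

def zscore_flag(observed, null_values, z_threshold=3.0):
    """Standardize an observed inter-task WSC value against the null distribution."""
    mu = np.mean(null_values)
    sigma = np.std(null_values)
    z = (observed - mu) / sigma
    # Flag as potential shortcut behavior if the deviation is significant.
    return z, bool(abs(z) > z_threshold)

# Synthetic null: inter-task similarities centered near zero with narrow
# spread, mimicking the orthogonality expected for unrelated heads.
rng = np.random.default_rng(1)
null = rng.normal(loc=0.0, scale=0.05, size=1000)

z, flagged = zscore_flag(0.45, null)  # a suspiciously high similarity
print(z, flagged)
```

A clinical/scanner weight-vector similarity falling many standard deviations above this null would indicate shared feature utilization, i.e., a candidate shortcut.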
Applying WSC to the primary clinical model (SA‑SonoNet) trained for spontaneous preterm birth (sPTB) prediction, the authors find high correlations between the clinical head’s weight vectors and clinically relevant factors such as cervical length or birth weight, but near‑zero correlations with acquisition‑related metadata like scanner type. Thus, although the embeddings contain scanner information, the decision boundary for sPTB is orthogonal to the scanner direction, implying that the model does not rely on that shortcut.
To stress‑test this conclusion, a multitask version of the network is trained with auxiliary heads explicitly forced to predict all metadata simultaneously. Even in this setting, the correlation between the clinical head and scanner head remains low, reinforcing the claim that the model can store metadata in its latent space without using it for the primary prediction.
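Structurally, this multitask setup amounts to several linear heads reading from one shared embedding. A minimal NumPy sketch of the forward pass (class counts, dimensions, and head names are placeholders, not the paper's configuration):

```python
import numpy as np

rng = np.random.default_rng(2)
embed_dim = 128

# Shared backbone embedding (stand-in for SA-SonoNet features).
z = rng.normal(size=(8, embed_dim))              # a batch of latent embeddings

# One linear head per task: the clinical sPTB task plus auxiliary metadata heads.
heads = {
    "sptb": rng.normal(size=(2, embed_dim)),     # clinical task
    "scanner": rng.normal(size=(5, embed_dim)),  # metadata: scanner model
    "hospital": rng.normal(size=(3, embed_dim)), # metadata: hospital ID
}

# Every head maps the SAME embedding to its own logits; training would sum
# the per-task cross-entropy losses. WSC is then computed between the rows
# of heads["sptb"] and each metadata head's weight matrix.
logits = {name: z @ W.T for name, W in heads.items()}
print({name: out.shape for name, out in logits.items()})
```

The key observation is that even when the metadata heads are trained to high accuracy (forcing the metadata into the shared embedding), the clinical head's weight rows can remain nearly orthogonal to theirs.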
The paper’s contributions are threefold: (1) a principled metric (WSC) that directly measures alignment of decision directions across tasks, (2) a practical pipeline that leverages PCA to focus on data‑supported dimensions, and (3) empirical evidence that a state‑of‑the‑art obstetric ultrasound model can avoid shortcut learning despite abundant confounding signals. Limitations include the reliance on linear classification heads (extension to non‑linear heads or transformer‑style decoders would require adaptation) and sensitivity to the choice of PCA dimensionality, which may need dataset‑specific tuning.
Overall, the work provides a valuable tool for assessing model trustworthiness in settings where confounding factors are hard to control. By moving beyond group‑wise performance metrics and directly quantifying feature utilization, WSC analysis enables clinicians, regulators, and developers to verify that deep‑learning systems base their predictions on clinically meaningful signals rather than spurious artifacts. This approach is especially pertinent for ultrasound, where on‑screen annotations, device settings, and demographic cues are pervasive and can otherwise masquerade as predictive cues. The methodology could be extended to other imaging modalities and tasks, offering a general framework for shortcut detection and mitigation in medical AI.