PSDNorm: Test-Time Temporal Normalization for Deep Learning in Sleep Staging
Distribution shift poses a significant challenge in machine learning, particularly in biomedical applications using data collected across different subjects, institutions, and recording devices, such as sleep data. While existing normalization layers such as BatchNorm, LayerNorm, and InstanceNorm help mitigate distribution shifts, when applied over the time dimension they ignore the dependencies and auto-correlation inherent to the vector coefficients they normalize. In this paper, we propose PSDNorm, which leverages Monge mapping and temporal context to normalize feature maps in deep learning models for signals. Evaluations with architectures based on U-Net or transformer backbones, trained on 10K subjects across 10 datasets, show that PSDNorm achieves state-of-the-art performance on unseen left-out datasets while being more robust to data scarcity.
💡 Research Summary
The paper addresses the pervasive problem of distribution shift in biomedical time‑series, focusing on sleep staging from EEG recordings collected across many subjects, institutions, and devices. Conventional normalization layers—BatchNorm, LayerNorm, InstanceNorm—standardize activations using statistics computed over batches, channels, or samples, but they treat each time point independently and ignore the intrinsic temporal autocorrelation and spectral characteristics of signals. To overcome this limitation, the authors propose PSDNorm, a novel deep‑learning layer that normalizes intermediate feature maps by aligning their power spectral density (PSD) to a running Riemannian (Wasserstein‑2) barycenter.
The method consists of three steps. First, for each feature map in a batch, the PSD is estimated using the Welch method: the map is centered, segmented into overlapping windows of length f, Fourier‑transformed, and the squared magnitudes are averaged. Second, the PSDs of all batch samples are combined to compute a batch barycenter via a simple arithmetic mean in the log‑domain (Equation 5). This batch barycenter is then merged with a globally maintained running barycenter using an exponential geodesic average under the Bures metric (Equation 6), providing a smooth, gradual adaptation to the overall data distribution while preventing abrupt changes. Third, each centered feature map is transformed by an f‑Monge mapping that rescales its frequency components according to the ratio between the running barycenter PSD and its own PSD. The mapping is implemented as a convolution with a filter b_H derived from the square‑root of the PSD ratio and an inverse Fourier transform. The hyper‑parameter f controls the spectral resolution: f = ℓ yields a full‑scale Monge map (exact Gaussian alignment), while f = 1 reduces to a simple per‑channel scalar scaling.
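The first and third steps can be sketched in NumPy. This is a minimal illustration of the general idea, not the paper's implementation: the window shape, overlap, and the way the length-f filter is stretched onto the full spectrum (here a simple linear interpolation) are assumptions.

```python
import numpy as np

def welch_psd(x, f, hop=None):
    """Welch-style PSD estimate of a 1-D feature map: center the
    signal, cut it into overlapping windows of length f, FFT each
    window, and average the squared magnitudes.  Rectangular windows
    and 50% overlap are assumptions of this sketch."""
    hop = hop or f // 2
    x = x - x.mean()
    starts = range(0, len(x) - f + 1, hop)
    segs = np.stack([x[s:s + f] for s in starts])
    return (np.abs(np.fft.rfft(segs, axis=1)) ** 2).mean(axis=0)

def monge_map(x, psd_x, psd_bar):
    """f-Monge mapping sketch: rescale the frequency content of x by
    sqrt(psd_bar / psd_x), so the output's PSD is pushed toward the
    barycenter PSD (exact Gaussian alignment when f equals the
    signal length)."""
    h = np.sqrt(psd_bar / np.maximum(psd_x, 1e-12))
    X = np.fft.rfft(x - x.mean(), n=len(x))
    # Stretch the length-(f//2 + 1) filter onto the full spectrum
    # (linear interpolation is an assumption of this sketch).
    grid = np.linspace(0.0, 1.0, len(X))
    h_full = np.interp(grid, np.linspace(0.0, 1.0, len(h)), h)
    return np.fft.irfft(X * h_full, n=len(x))
```

Note that when the barycenter PSD equals the signal's own PSD, the filter is all-ones and the mapping reduces to plain centering, which matches the intuition that samples already at the barycenter are left untouched.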
Computationally, PSDNorm adds an overhead of O(N·c·ℓ·f·log f) per forward pass (N = batch size, c = channels, ℓ = signal length), dominated by FFT operations that are highly optimized on modern GPUs/TPUs. The layer is fully differentiable; gradients flow through the Monge mapping but stop at the running barycenter to keep it a non‑learnable statistic.
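The running-barycenter update described above can be sketched as a geodesic step under the Bures metric. For diagonal covariances (PSDs), the Bures-Wasserstein geodesic interpolates square roots of the spectra; the momentum term below plays the role of the interpolation weight. This is a sketch of the idea under that diagonal assumption, not the paper's exact Equation 6, and in a deep-learning framework the result would be detached from the graph to keep it a non-learnable statistic.

```python
import numpy as np

def update_running_barycenter(running, batch_bar, momentum=0.9):
    """One step along the Bures-Wasserstein geodesic from the running
    barycenter PSD toward the current batch barycenter PSD.  With
    momentum close to 1 the running statistic adapts slowly, which
    prevents abrupt changes across batches."""
    t = 1.0 - momentum
    # Geodesic between diagonal Gaussians: interpolate sqrt-spectra.
    return ((1.0 - t) * np.sqrt(running) + t * np.sqrt(batch_bar)) ** 2
```

With `momentum=1.0` the running barycenter is frozen, and when the batch barycenter already equals the running one the update is a fixed point, two sanity properties one would expect of any exponential-average-style statistic.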
The authors evaluate PSDNorm on a massive benchmark comprising ten publicly available sleep datasets (≈10 k subjects, 10 M epochs). They adopt a leave‑one‑dataset‑out (LODO) protocol with three random seeds, training two backbone families—U‑Net‑style encoder‑decoder and transformer‑based models—both with and without PSDNorm. Compared to baseline normalizations, PSDNorm consistently improves macro‑averaged accuracy by 2.3–3.1 percentage points on unseen datasets, and it achieves comparable performance to the best baseline while using only 25 % of the labeled data. Ablation studies show that the method is robust to the choice of f (optimal values between 1 and 17) and that the benefit is larger for architectures that explicitly model temporal dependencies (e.g., transformers).
The paper also discusses limitations. PSDNorm relies on a Gaussian‑stationary assumption for the underlying signal, which may be violated in highly non‑stationary recordings. Very long sequences increase memory consumption for the overlapping windows, and an overly small f can make the method sensitive to high‑frequency noise. Future work is suggested to extend the barycenter concept to multimodal or mixture distributions, to develop online adaptive updates for real‑time monitoring, and to relax the Gaussian assumption.
In summary, PSDNorm introduces a principled, test‑time‑compatible normalization that leverages optimal‑transport‑based spectral alignment. By embedding this operation as a drop‑in layer, the authors demonstrate substantial gains in cross‑domain generalization for sleep staging, offering a practical solution for deploying deep learning models in heterogeneous clinical environments.