BiSSL: Enhancing the Alignment Between Self-Supervised Pretraining and Downstream Fine-Tuning via Bilevel Optimization
Models initialized from self-supervised pretraining may suffer from poor alignment with downstream tasks, reducing the extent to which subsequent fine-tuning can adapt pretrained features toward downstream objectives. To mitigate this, we introduce BiSSL, a novel bilevel training framework that enhances the alignment of self-supervised pretrained models with downstream tasks prior to fine-tuning. BiSSL acts as an intermediate training stage conducted after conventional self-supervised pretraining and is tasked with solving a bilevel optimization problem that incorporates the pretext and downstream training objectives in its lower- and upper-level objectives, respectively. This approach explicitly models the interdependence between the pretraining and fine-tuning stages within the conventional self-supervised learning pipeline, facilitating enhanced information sharing between them that ultimately leads to a model initialization better aligned with the downstream task. We propose a general training algorithm for BiSSL that is compatible with a broad range of pretext and downstream tasks. Using SimCLR and Bootstrap Your Own Latent to pretrain ResNet-50 backbones on the ImageNet dataset, we demonstrate that our proposed framework significantly improves accuracy on the vast majority of 12 downstream image classification datasets, as well as on object detection. Exploratory analyses alongside investigative experiments further provide compelling evidence that BiSSL enhances downstream alignment.
💡 Research Summary
The paper tackles a fundamental mismatch problem in self‑supervised learning (SSL): representations learned by a pretext task are often poorly aligned with the downstream task, limiting the benefit of subsequent fine‑tuning. To address this, the authors propose BiSSL, a bilevel‑optimization framework that inserts an intermediate training stage between conventional SSL pretraining and downstream fine‑tuning.
In BiSSL, two copies of the backbone are maintained: θ_P for the lower‑level (pretext) optimization and θ_D for the upper‑level (downstream) optimization. The lower‑level problem minimizes the standard self‑supervised loss L_P plus a regularization term λ·r(θ_P, θ_D) that penalizes the distance between the two backbones. The upper‑level objective minimizes the downstream loss L_D evaluated both at the lower‑level solution θ_P*(θ_D) and at θ_D itself:

min_{θ_D, φ_D}  L_D(θ_P*(θ_D), φ_D) + L_D(θ_D, φ_D)

subject to

θ_P*(θ_D) = argmin_{θ_P, φ_P}  L_P(θ_P, φ_P) + λ·r(θ_P, θ_D).
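To make the lower-level objective concrete, here is a minimal numpy sketch. It assumes, purely for illustration, that the regularizer r is the squared L2 distance between the two backbones; the paper's exact choice of r and loss may differ.

```python
import numpy as np

def lower_level_loss(theta_p, theta_d, pretext_loss, lam=0.1):
    """Lower-level objective: pretext loss L_P plus a proximity
    regularizer lam * r(theta_p, theta_d) tying the pretext backbone
    to the downstream backbone.

    r is taken as the squared L2 distance here -- an illustrative
    choice, not necessarily the paper's exact regularizer."""
    r = 0.5 * np.sum((theta_p - theta_d) ** 2)
    return pretext_loss(theta_p) + lam * r

# Toy example: a quadratic stand-in for the self-supervised loss L_P.
toy_pretext = lambda th: np.sum(th ** 2)
theta_p = np.array([1.0, -2.0])
theta_d = np.array([0.5, -1.0])
print(lower_level_loss(theta_p, theta_d, toy_pretext, lam=0.1))
```

The λ term is what couples the two levels: with λ = 0 the lower level collapses to plain SSL pretraining, while large λ pins θ_P to θ_D.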
This formulation lets the downstream objective directly influence the pretext optimization through the implicit Jacobian ∂θ_P*/∂θ_D. The authors derive the gradient of the upper‑level objective and show that it contains a term involving the inverse Hessian of the lower‑level loss. Because explicitly inverting the Hessian is infeasible for deep networks, they approximate the required inverse‑Hessian–vector product with a Conjugate Gradient (CG) routine, which needs only cheap Hessian‑vector products, following prior meta‑learning work.
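The CG trick can be sketched in a few lines: to apply H⁻¹ to a vector v, solve H x = v iteratively using only Hessian-vector products. Below is a generic numpy version on a toy quadratic, not the paper's implementation (in practice the Hessian-vector product would come from autodiff over the lower-level loss).

```python
import numpy as np

def conjugate_gradient(hvp, v, iters=10, tol=1e-10):
    """Approximately solve H x = v given only a Hessian-vector
    product function hvp(u) = H @ u. Standard CG; assumes H is
    symmetric positive definite."""
    x = np.zeros_like(v)
    r = v - hvp(x)          # residual
    p = r.copy()
    rs_old = r @ r
    for _ in range(iters):
        Hp = hvp(p)
        alpha = rs_old / (p @ Hp)
        x += alpha * p
        r -= alpha * Hp
        rs_new = r @ r
        if rs_new < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

# Toy check against a direct solve on a small SPD matrix.
H = np.array([[4.0, 1.0], [1.0, 3.0]])
v = np.array([1.0, 2.0])
x = conjugate_gradient(lambda u: H @ u, v)
print(np.allclose(x, np.linalg.solve(H, v)))  # True
```

Each CG iteration costs one Hessian-vector product, so the inverse Hessian is never formed, which is what makes the implicit gradient tractable at ResNet-50 scale.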
Algorithm 1 alternates between lower‑level updates (the standard SSL gradient plus the gradient of the λ‑regularizer) and upper‑level updates (downstream gradients plus the CG‑approximated implicit‑Jacobian term). The number of upper‑level steps per alternation, N_U, may exceed one, a deviation from classic bilevel optimization that the authors find empirically beneficial.
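The alternation can be illustrated on a toy quadratic problem. Everything below is a simplified stand-in: the losses, step counts, and learning rate are invented for the example, and the CG-based implicit-Jacobian term is omitted for brevity, so this shows only the alternating structure, not the full algorithm.

```python
import numpy as np

# Toy quadratic gradients standing in for the pretext and downstream losses.
grad_LP = lambda th: 2.0 * (th - np.array([1.0, 0.0]))   # L_P minimized at (1, 0)
grad_LD = lambda th: 2.0 * (th - np.array([0.0, 1.0]))   # L_D minimized at (0, 1)

def bissl_alternation(steps=200, n_lower=5, n_upper=2, lam=1.0, lr=0.05):
    """Simplified BiSSL-style alternation (illustrative only).

    Lower level: pull theta_p toward the pretext optimum while the
    lam-regularizer keeps it close to theta_d.
    Upper level: update theta_d with the downstream gradient evaluated
    at both theta_p (the lower-level solution) and theta_d itself,
    mirroring the two terms of the upper-level objective."""
    theta_p = np.zeros(2)
    theta_d = np.zeros(2)
    for _ in range(steps):
        for _ in range(n_lower):   # lower level: grad of L_P + lam * r
            theta_p -= lr * (grad_LP(theta_p) + lam * (theta_p - theta_d))
        for _ in range(n_upper):   # upper level: grad of L_D at theta_p* and theta_d
            theta_d -= lr * (grad_LD(theta_p) + grad_LD(theta_d))
    return theta_p, theta_d

theta_p, theta_d = bissl_alternation()
print(theta_p, theta_d)  # theta_d is drawn toward the downstream optimum
```

Even without the implicit term, the λ-coupling makes the two levels negotiate: θ_P settles between the pretext optimum and θ_D, while θ_D is pulled toward the downstream optimum.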
Experiments pretrain ResNet‑50 backbones on ImageNet‑1K using SimCLR and BYOL, then evaluate on twelve image‑classification benchmarks, COCO object detection, and ADE20K semantic segmentation. BiSSL consistently improves top‑1 accuracy by 2–4 percentage points on most classification datasets, with the largest gains on domains far from ImageNet (e.g., sketches, medical images). Object detection mAP and segmentation mIoU also improve. Ablation studies confirm the importance of both the λ‑regularization and the upper‑level term that directly incorporates the downstream loss. Visualizations of representation similarity (CKA) across training stages show that BiSSL gradually reshapes the SSL features to become more downstream‑relevant, validating the intended alignment effect.
The paper’s contributions are threefold: (1) a principled bilevel formulation that jointly optimizes pretext and downstream objectives, (2) a practical training algorithm that scales to modern vision backbones via CG‑based Hessian‑vector approximation, and (3) extensive empirical evidence that this intermediate alignment stage yields robust performance gains across a variety of downstream tasks. Limitations include the additional computational overhead of CG and the current focus on vision; extending BiSSL to language or multimodal domains and exploring more efficient Hessian approximations are promising future directions. Overall, BiSSL offers a compelling, architecture‑agnostic method to bridge the gap between self‑supervised pretraining and downstream fine‑tuning, improving the utility of SSL representations without altering the standard pretraining pipeline.