Reference-Guided Machine Unlearning
Machine unlearning aims to remove the influence of specific data from trained models while preserving general utility. Existing approximate unlearning methods often rely on performance-degradation heuristics, such as loss maximization or random labeling. However, these signals can be poorly conditioned, leading to unstable optimization and harming the model’s generalization. We argue that unlearning should instead prioritize distributional indistinguishability, aligning the model’s behavior on forget data with its behavior on truly unseen data. Motivated by this, we propose Reference-Guided Unlearning (ReGUn), a framework that leverages a disjoint held-out dataset to provide a principled, class-conditioned reference for distillation. We demonstrate across various model architectures, natural image datasets, and varying forget fractions that ReGUn consistently outperforms standard approximate baselines, achieving a superior forgetting-utility trade-off.
💡 Research Summary
Machine unlearning (MU) seeks to erase the influence of specific training examples from a deployed model without the prohibitive cost of full retraining. Existing approximate MU techniques typically rely on heuristics that force the model to perform poorly on the “forget” set—such as maximizing loss, assigning random labels, or applying aggressive gradient ascent. While these signals are easy to compute, they are often poorly conditioned: the resulting gradients can be excessively large or misdirected, causing decision boundaries to shift far beyond the intended region and degrading overall generalization. Consequently, many recent works introduce constraints (e.g., staying close to the original parameters, repair mechanisms, or constrained parameter editing) to balance forgetting against stability, yet the fundamental mismatch between the proxy loss and the true objective—making the model behave as if it has never seen the forgotten data—remains unresolved.
The authors propose a new paradigm called Reference‑Guided Unlearning (ReGUn). The central idea is to replace the “make‑wrong” signal with a distributional indistinguishability objective: the model’s predictions on forget examples should be indistinguishable from its predictions on data it has truly never encountered. To obtain a concrete reference for “unseen behavior,” ReGUn leverages a disjoint held‑out dataset \(D_h\) that is not used during the original training. For each minibatch of forget samples \(B_f\), the method first computes the class histogram of the batch, then draws a matching set of \(m\) held‑out examples \(\tilde{D}_h\) whose class frequencies mirror those of \(B_f\). The reference model (chosen as the initial model \(f_{\theta_0}\) to avoid extra training) produces softmax probabilities on \(\tilde{D}_h\); these probabilities are averaged to form a soft target distribution \(q(B_f)\). This reference is class‑conditioned, ensuring that the target reflects the same label priors as the forget batch while still representing “unseen” inputs.
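The class-conditioned sampling and soft-target steps above can be sketched in plain NumPy. This is an illustrative reconstruction, not the authors' code; the helper names (`sample_reference_batch`, `soft_target`) are hypothetical.

```python
import numpy as np

def sample_reference_batch(forget_labels, holdout_inputs, holdout_labels, rng):
    """Draw held-out examples whose class histogram mirrors the forget batch.

    Hypothetical helper sketching ReGUn's class-conditioned sampling step.
    """
    idx = []
    for cls, count in zip(*np.unique(forget_labels, return_counts=True)):
        pool = np.where(holdout_labels == cls)[0]  # held-out examples of this class
        idx.extend(rng.choice(pool, size=count, replace=False))
    idx = np.asarray(idx)
    return holdout_inputs[idx], holdout_labels[idx]

def soft_target(ref_logits):
    """Average the reference model's softmax outputs into one soft target q."""
    exp = np.exp(ref_logits - ref_logits.max(axis=1, keepdims=True))  # stable softmax
    probs = exp / exp.sum(axis=1, keepdims=True)
    return probs.mean(axis=0)
```

In this sketch, `ref_logits` would come from running the frozen initial model \(f_{\theta_0}\) on the sampled held-out batch.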
During the unlearning phase, two loss terms are jointly minimized: (1) a KL‑divergence term \(\mathrm{KL}\bigl(q(B_f)\,\|\,p_\theta(\cdot \mid x)\bigr)\) applied to each forget example \(x \in B_f\), which is equivalent to a distillation loss toward the held‑out teacher distribution, and (2) a standard cross‑entropy loss on a retain minibatch \(B_r\) to preserve performance on the remaining data. Hyperparameters \(\lambda_f\) and \(\lambda_r\) control the trade‑off between forgetting strength and utility preservation. The full algorithm iterates over \(T\) steps, updating parameters with stochastic gradient descent on the combined loss.
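The combined objective can be written compactly as a function of the model's probability outputs. A minimal sketch, assuming the model outputs normalized probabilities (the function name and signature are illustrative):

```python
import numpy as np

def unlearning_loss(p_forget, q_ref, p_retain, y_retain, lam_f=1.0, lam_r=1.0):
    """Combined ReGUn-style objective (illustrative sketch):
    lam_f * mean KL(q || p_theta) over the forget batch
    + lam_r * mean cross-entropy over the retain batch.
    p_forget, p_retain: (n, C) probability rows; q_ref: (C,) soft target.
    """
    eps = 1e-12  # numerical floor for the logs
    kl = np.sum(q_ref * (np.log(q_ref + eps) - np.log(p_forget + eps)), axis=1).mean()
    ce = -np.log(p_retain[np.arange(len(y_retain)), y_retain] + eps).mean()
    return lam_f * kl + lam_r * ce
```

Note the KL direction matches the text: the soft reference \(q\) is the target distribution and the current model's prediction \(p_\theta\) is pulled toward it, as in standard distillation.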
Empirical evaluation covers three image classification benchmarks—CIFAR‑10, CIFAR‑100, and Tiny‑ImageNet—using two model families: a CNN (ResNet‑18) and a Vision Transformer (Swin‑T). Forget fractions of 1 %, 10 %, and 50 % are examined, with each configuration repeated over three random seeds. The authors assess (i) retained accuracy (RETAIN‑ACC), (ii) forget accuracy (FORGET‑ACC), (iii) test accuracy (TEST‑ACC), and (iv) membership inference risk using the robust membership inference attack (RMIA) AUC. They also report a composite metric GAP‑RFTP Avg, the average deviation of each method from a full‑retrain‑from‑scratch baseline across the four primary metrics.
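As described, the composite GAP‑RFTP Avg score averages a method's deviation from the retrain-from-scratch reference across the four metrics. A hedged reconstruction of that computation (the exact aggregation in the paper may differ, e.g. in scaling or signed vs. absolute gaps):

```python
def gap_rftp_avg(method_metrics, retrain_metrics):
    """Average absolute deviation from the retrain-from-scratch baseline
    across the four reported metrics. Hypothetical reconstruction; the
    dictionary keys below are illustrative names, not the paper's.
    """
    keys = ["retain_acc", "forget_acc", "test_acc", "rmia_auc"]
    return sum(abs(method_metrics[k] - retrain_metrics[k]) for k in keys) / len(keys)
```

Under this reading, a score of zero means the unlearned model is indistinguishable from full retraining on all four axes, which is why lower is better.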
Results show that ReGUn consistently yields the smallest GAP‑RFTP Avg across most settings, indicating that its outputs are closest to those of a model trained without the forget set. In the CNN experiments, ReGUn matches or exceeds baseline methods in test accuracy while achieving notably lower RMIA scores, demonstrating stronger privacy protection without sacrificing utility. The Transformer experiments reveal an even clearer advantage: at higher forget fractions (especially 50 %), ReGUn is the only method that reduces RMIA AUC to the level of the full‑retrain baseline, while maintaining competitive test accuracy. This suggests that the reference‑guided distillation provides a stable forgetting signal that does not require aggressive loss‑ascent updates, which can be particularly destabilizing for attention‑based architectures.
The paper also highlights that many recent, more sophisticated unlearning techniques (e.g., SALUN, AMUN) still underperform the simple baselines NEG‑GRAD+ and FINE‑TUNE in the Transformer setting, underscoring a gap in current research, which focuses predominantly on CNNs.
Limitations are acknowledged: the reference model is fixed to the initial parameters \(f_{\theta_0}\), which still contain some influence from the forget set, potentially limiting the purity of the “unseen” behavior. The authors argue this choice prevents reference drift, but future work could explore updating the reference or using an external oracle model. Additionally, the method assumes access to a sufficiently large, labeled held‑out dataset; the practicality of acquiring such data in real‑world deployments warrants further investigation.
In summary, Reference‑Guided Unlearning reframes the unlearning objective from “make the model wrong on forgotten data” to “make the model’s behavior on forgotten data indistinguishable from its behavior on truly unseen data.” By grounding the forgetting signal in a class‑conditioned held‑out reference distribution, ReGUn achieves a more stable optimization, better privacy guarantees, and competitive utility across both CNN and Transformer architectures. This work offers a principled, easily implementable framework that could become a new standard for practical machine unlearning in compliance‑driven AI systems.