Gauss-Newton Unlearning for the LLM Era
Standard large language model training can create models that produce outputs their trainer deems unacceptable in deployment. The probability of these outputs can be reduced using methods such as LLM unlearning. However, unlearning a set of data (called the forget set) can degrade model performance on other distributions where the trainer wants to retain the model’s behavior. To improve this trade-off, we demonstrate that using the forget set to compute only a few uphill Gauss-Newton steps provides a conceptually simple, state-of-the-art unlearning approach for LLMs. While Gauss-Newton steps adapt Newton’s method to non-linear models, it is non-trivial to efficiently and accurately compute such steps for LLMs. Hence, our approach crucially relies on parametric Hessian approximations such as Kronecker-Factored Approximate Curvature (K-FAC). We call this combined approach K-FADE (K-FAC for Distribution Erasure). Our evaluation on the WMDP and ToFU benchmarks demonstrates that K-FADE suppresses outputs from the forget set and approximates, in output space, the results of retraining without the forget set. Critically, our method does this while altering the outputs on the retain set less than previous methods. This is because K-FADE transforms a constraint on the model’s outputs across the entire retain set into a constraint on the model’s weights, allowing the algorithm to minimally change the model’s behavior on the retain set at each step. Moreover, the unlearning updates computed by K-FADE can be reapplied later if the model undergoes further training, allowing unlearning to be cheaply maintained.
💡 Research Summary
The paper tackles a pressing problem in the deployment of large language models (LLMs): the need to suppress undesirable outputs that arise from specific training data (the “forget set”) while preserving the model’s performance on all other data (the “retain set”). Existing unlearning approaches for LLMs typically rely on first‑order gradient‑based methods that iteratively increase loss on the forget set and simultaneously constrain changes on the retain set. These methods suffer from instability, require careful hyper‑parameter tuning, and struggle to estimate the effect of each update on the massive retain distribution because they operate on small mini‑batches.
The authors propose a conceptually simple yet powerful alternative: compute a few “uphill” Gauss‑Newton steps using only the forget set, and embed the retain‑set constraint directly into the optimizer. The key insight is that the retain‑set constraint can be approximated by a second‑order Taylor expansion of the KL‑divergence between the original model’s output distribution and the updated model’s distribution. This yields a quadratic form defined by the Gauss‑Newton (or Fisher) matrix Gθ. The optimization problem becomes: maxδθ L_F(θ+δθ) subject to δθᵀ Gθ δθ ≤ ε, where L_F is a forget‑set loss (e.g., the negative log‑likelihood of the unwanted outputs), so maximizing it reduces their probability. The solution is a step along the natural gradient Gθ⁻¹∇θ L_F(θ), i.e., a Gauss‑Newton ascent step. By normalizing the step with ‖∇L_F‖_{G⁻¹}, each update makes a roughly constant change in KL‑divergence, so the retain‑set behavior is perturbed by only a controlled amount at every step.
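The normalized ascent step can be checked numerically on a toy problem. This is an illustrative sketch, not the paper's code: G stands in for the Gauss‑Newton/Fisher matrix and g for the gradient of the forget loss L_F, both assumed given.

```python
import numpy as np

# Toy illustration of a normalized natural-gradient ascent step.
rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
G = A @ A.T + 4 * np.eye(4)      # symmetric positive-definite curvature
g = rng.normal(size=4)           # gradient of the forget loss L_F

Ginv_g = np.linalg.solve(G, g)   # natural-gradient direction G^{-1} g
g_norm = np.sqrt(g @ Ginv_g)     # ||g||_{G^{-1}}
eta = 0.1
delta = eta * Ginv_g / g_norm    # normalized ascent step

# The quadratic KL proxy delta^T G delta equals eta^2 by construction,
# so each step perturbs retain-set behaviour by a roughly constant amount.
kl_proxy = delta @ G @ delta
print(np.isclose(kl_proxy, eta**2))   # True
```

The normalization is what turns the learning rate η into a direct budget on the (second‑order approximation of the) KL‑divergence per step.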
Directly computing Gθ⁻¹ is infeasible for modern LLMs with billions of parameters because the matrix is enormous and often ill‑conditioned. To overcome this, the authors adopt Kronecker‑Factored Approximate Curvature (K‑FAC) and its eigenvalue‑corrected variant (EK‑FAC). K‑FAC approximates each layer’s block of the Fisher matrix as a Kronecker product of two much smaller factors (one derived from the layer’s input activations, the other from its output gradients), allowing efficient storage and inversion. EK‑FAC refines this approximation by correcting the eigenvalues of the Kronecker approximation, improving fidelity without a large computational burden. These parametric Hessian approximations make it possible to compute a high‑quality Gauss‑Newton direction for LLMs in practice.
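The computational payoff of the Kronecker factorization can be seen in a few lines. This is a toy sketch with assumed shapes, not the paper's implementation: a 3‑input, 2‑output layer whose Fisher block is approximated as kron(A, S), where A is the input‑activation second‑moment matrix and S the output‑gradient second‑moment matrix.

```python
import numpy as np

rng = np.random.default_rng(1)

def spd(n):
    """Random symmetric positive-definite factor."""
    M = rng.normal(size=(n, n))
    return M @ M.T + n * np.eye(n)

A, S = spd(3), spd(2)            # small Kronecker factors for one layer
grad = rng.normal(size=(3, 2))   # weight gradient, one entry per parameter

# Naive route: materialize the full 6x6 Kronecker matrix and solve.
full = np.linalg.solve(np.kron(A, S), grad.ravel())

# K-FAC route: invert only the small factors, never the full matrix.
# Identity (row-major vec, symmetric S): kron(A, S)^{-1} vec(W) = vec(A^{-1} W S^{-1}).
fast = (np.linalg.solve(A, grad) @ np.linalg.inv(S)).ravel()

print(np.allclose(full, fast))   # True
```

For a real transformer layer with d_in, d_out in the thousands, the full block has (d_in·d_out)² entries, while the factors need only d_in² + d_out², which is what makes the Gauss‑Newton direction tractable at LLM scale.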
The resulting algorithm, named K‑FADE (K‑FAC for Distribution Erasure), proceeds as follows:
- Define the forget set D_F and a loss L_F that is high when the model produces the unwanted outputs.
- Estimate the Gauss‑Newton matrix Gθ on a retain set D_R using K‑FAC/EK‑FAC.
- Compute the natural‑gradient ascent direction Gθ⁻¹∇θ L_F(θ) and normalize it by ‖∇L_F‖_{Gθ⁻¹}.
- Apply a step with a chosen learning‑rate η; repeat for a small number of iterations (typically 1–3).
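The steps above can be sketched end to end on a model small enough that the exact Gauss‑Newton/Fisher matrix is computable, standing in for the K‑FAC approximation used at LLM scale. All data, losses, and hyper‑parameters below are illustrative assumptions, not the paper's setup.

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy "trained model": logistic regression fit by construction.
X_R = rng.normal(size=(200, 5))
w_true = rng.normal(size=5)
y_R = (X_R @ w_true > 0).astype(float)   # retain set D_R
X_F, y_F = X_R[:10], y_R[:10]            # forget set D_F (to unlearn)
w = w_true.copy()

def nll(w, X, y):
    """Negative log-likelihood; plays the role of L_F on the forget set."""
    p = np.clip(1 / (1 + np.exp(-X @ w)), 1e-12, 1 - 1e-12)
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))

nll_before = nll(w, X_F, y_F)
eta, damping = 0.5, 1e-3
for _ in range(3):                       # a few uphill Gauss-Newton steps
    p_R = 1 / (1 + np.exp(-X_R @ w))
    G = (X_R * (p_R * (1 - p_R))[:, None]).T @ X_R / len(X_R) \
        + damping * np.eye(5)            # Fisher estimated on the retain set
    p_F = 1 / (1 + np.exp(-X_F @ w))
    g = X_F.T @ (p_F - y_F) / len(X_F)   # gradient of the forget NLL
    Ginv_g = np.linalg.solve(G, g)
    w += eta * Ginv_g / np.sqrt(g @ Ginv_g)   # normalized *ascent* step

print(nll(w, X_F, y_F) > nll_before)     # True: forget loss went up
```

In K‑FADE proper, G would be built and inverted layer‑by‑layer via K‑FAC/EK‑FAC rather than formed explicitly, but the loop structure is the same.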
Empirical evaluation is conducted on two recent benchmarks:
- WMDP (Weapons of Mass Destruction Proxy) – measures the ability to suppress hazardous content while preserving general knowledge and fluency (MMLU, MT‑Bench). K‑FADE achieves substantially higher output suppression than prior first‑order methods and incurs only a marginal drop in the utility metrics, while requiring far less runtime than full retraining.
- ToFU (Test of Fictitious Unlearning) – tests removal of synthetic personal information while keeping non‑sensitive knowledge and matching the output distribution of a model trained without that data. A single Gauss‑Newton step with K‑FADE reaches or exceeds the “approximate retraining” baseline, delivering a new state‑of‑the‑art Pareto frontier between forget quality and model utility.
A notable practical advantage is that the computed K‑FADE update can be stored and reapplied after the model undergoes further fine‑tuning. Experiments show that, even after fine‑tuning on adversarial data, re‑applying the same update restores most of the unlearning effect, whereas other methods lose efficacy. This property enables cheap maintenance of unlearning in services that expose fine‑tuning APIs.
The paper also discusses limitations. No formal ε‑unlearning guarantees are provided, and the evaluation metrics (output suppression, KL‑divergence, benchmark scores) may not capture all real‑world utility or safety concerns. Moreover, the threat model differs from classic machine‑unlearning: here the model owner proactively removes undesirable behavior rather than responding to data‑subject deletion requests.
In summary, K‑FADE demonstrates that a few Gauss‑Newton ascent steps, efficiently approximated with K‑FAC, can achieve superior unlearning performance for LLMs—stronger output suppression, higher retain‑set fidelity, faster runtime than full retraining, and the ability to maintain unlearning across subsequent training phases. This work opens a promising direction for scalable, principled, and maintainable LLM safety interventions.