Training Data Selection with Gradient Orthogonality for Efficient Domain Adaptation
Fine-tuning large language models (LLMs) for specialized domains often forces a trade-off between acquiring domain expertise and retaining general reasoning capabilities: domain fine-tuning can erode general skills, a failure mode known as catastrophic forgetting. Existing remedies face a dichotomy: gradient surgery methods offer geometric safety but incur prohibitive computational costs via online projections, while efficient data selection approaches reduce overhead but remain blind to conflict-inducing gradient directions. In this paper, we propose Orthogonal Gradient Selection (OGS), a data-centric method that harmonizes domain performance, general capability retention, and training efficiency. OGS shifts the geometric insights of gradient projection from the optimizer to the data selection stage by treating data selection as a constrained decision-making process. By leveraging a lightweight Navigator model and reinforcement learning techniques, OGS dynamically identifies training samples whose gradients are orthogonal to a general-knowledge anchor. This approach ensures naturally safe updates for target models without modifying the optimizer or incurring runtime projection costs. Experiments across medical, legal, and financial domains demonstrate that OGS significantly improves domain performance and training efficiency while maintaining, and in some cases improving, performance on general benchmarks such as GSM8K.
💡 Research Summary
The paper tackles the persistent problem of catastrophic forgetting that occurs when large language models (LLMs) are fine‑tuned for specialized domains such as medicine, law, or finance. Existing remedies fall into two camps. Gradient‑surgery approaches (e.g., GEM, SafeGrad) explicitly project task gradients onto a subspace orthogonal to a “protected” direction representing general knowledge. While theoretically sound, they require computing reference gradients and performing high‑dimensional projections at every optimization step, which becomes prohibitively expensive for models with billions of parameters. Data‑selection methods (e.g., LESS, GrADS) reduce overhead by filtering training examples offline, but they are blind to the geometric relationship between domain and general‑task gradients, often selecting data that maximizes domain performance at the expense of general capabilities.
Orthogonal Gradient Selection (OGS) bridges this gap by moving the geometric safety check from the optimizer to the data‑selection phase. The core idea is to pre‑define a general‑knowledge anchor gradient g_ref, computed once from a small, curated set of tasks (e.g., GSM8K math problems, MMLU factual questions, Alpaca instructions). For each candidate training example x_i, OGS measures two metrics on a lightweight “Navigator” proxy model: (1) Orthogonality = 1 – |cos(g_i, g_ref)|, indicating how close the sample’s gradient is to the hyperplane perpendicular to g_ref; (2) Conflict = –cos(g_i, g_ref), a signed measure that is positive when the sample’s update opposes the anchor (risking forgetting) and negative when it is synergistic.
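The two metrics above are simple functions of the cosine between a sample's gradient and the anchor. The following sketch (illustrative, not the paper's implementation) computes them for flattened gradient vectors; the function names are our own:

```python
import numpy as np

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two flattened gradient vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def ogs_scores(g_i: np.ndarray, g_ref: np.ndarray) -> tuple[float, float]:
    """Return (orthogonality, conflict) for a candidate gradient g_i
    against the general-knowledge anchor g_ref, as defined in the text:
      orthogonality = 1 - |cos(g_i, g_ref)|   (1 = perfectly orthogonal)
      conflict      = -cos(g_i, g_ref)        (> 0 = opposes the anchor)
    """
    c = cosine(g_i, g_ref)
    return 1.0 - abs(c), -c

# A gradient pointing opposite the anchor is maximally conflicting:
g_ref = np.array([1.0, 0.0])
orth, conf = ogs_scores(np.array([-1.0, 0.0]), g_ref)  # orth = 0.0, conf = 1.0
```

Note that a synergistic sample (gradient aligned with g_ref) yields a negative conflict score, so it is never penalized by the selection policy.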
Because gradient directions are highly correlated across model scales, the Navigator (e.g., a 0.5 B Qwen3) can compute these metrics for the entire candidate pool at negligible cost. The resulting geometric signals are then fed into a reinforcement‑learning (RL) policy that learns to select batches maximizing domain learning speed while satisfying an orthogonality/conflict constraint. The RL reward balances three components: (a) domain loss reduction, (b) penalization of positive conflict scores, and (c) encouragement of high orthogonality. This formulation is equivalent to a first‑order approximation of a bilevel optimization problem that maximizes domain performance subject to a lower bound on post‑update general‑task performance.
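The three-part reward can be sketched as a weighted sum; the coefficients alpha/beta/gamma and the exact aggregation over a batch are illustrative assumptions, not values from the paper:

```python
def batch_reward(domain_loss_before, domain_loss_after,
                 conflicts, orthogonalities,
                 alpha=1.0, beta=1.0, gamma=0.1):
    """Sketch of the three-component RL reward described above:
      (a) domain loss reduction on the selected batch,
      (b) penalty on positive (forgetting-inducing) conflict scores,
      (c) bonus for high orthogonality.
    `conflicts` and `orthogonalities` are per-sample scores for the batch;
    all weights are hypothetical placeholders.
    """
    loss_gain = domain_loss_before - domain_loss_after          # (a)
    conflict_penalty = sum(max(c, 0.0) for c in conflicts) / len(conflicts)  # (b)
    orth_bonus = sum(orthogonalities) / len(orthogonalities)    # (c)
    return alpha * loss_gain - beta * conflict_penalty + gamma * orth_bonus
```

Only positive conflict values are penalized, matching the signed definition of the conflict score: synergistic samples (negative conflict) contribute no penalty.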
OGS operates in two phases. Phase 1 (Strategy Learning) runs the Navigator over the full dataset, computes orthogonality and conflict scores, and trains the RL policy. Phase 2 (Strategy Application) applies the learned policy to the target LLM (e.g., 14 B parameters), selecting only those samples whose gradients are naturally orthogonal or synergistic with g_ref. Consequently, during actual fine‑tuning, no extra forward‑backward passes or projection operations are required, preserving the full throughput of standard LoRA‑based fine‑tuning pipelines.
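Phase 2 then reduces to a cheap offline filter over the scores computed in Phase 1. A minimal sketch, with illustrative thresholds (the paper's policy is learned, not hand-set):

```python
def select_safe_samples(scored_pool, conflict_max=0.0, orth_min=0.5):
    """Phase-2-style filter: keep samples that are synergistic with the
    anchor (conflict <= conflict_max) or near-orthogonal to it
    (orthogonality >= orth_min). `scored_pool` is a list of
    (sample_id, orthogonality, conflict) tuples precomputed once by the
    Navigator in Phase 1; thresholds here are assumptions for the sketch.
    """
    return [sid for sid, orth, conf in scored_pool
            if conf <= conflict_max or orth >= orth_min]

pool = [("a", 0.9, 0.3),    # near-orthogonal: kept
        ("b", 0.1, 0.8),    # strongly conflicting: dropped
        ("c", 0.2, -0.5)]   # synergistic: kept
selected = select_safe_samples(pool)  # ["a", "c"]
```

Because this filter runs once before fine-tuning, the target model's training loop is untouched, which is what preserves full LoRA throughput.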
Empirical evaluation spans three high‑stakes verticals—medical QA, legal case summarization, and financial risk assessment—using target models of 1.7 B, 7 B, and 14 B parameters. Compared with prior data‑selection and gradient‑surgery methods, OGS achieves:
- Domain accuracy improvements of 3–5 percentage points over the best prior data‑selection baselines.
- Preservation (or slight improvement, up to +0.5 pp) of general‑task performance on GSM8K and MMLU, demonstrating effective mitigation of forgetting.
- Training speedups of roughly 2× relative to gradient‑surgery methods, with dramatically lower memory consumption because the expensive geometric computation is performed only once on the Navigator.
An ablation study shows that “active anchor” selection—choosing anchor examples that exhibit maximal conflict with domain data—further reduces forgetting, confirming the importance of a well‑chosen g_ref.
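The active-anchor idea can be illustrated as follows: from a pool of candidate general-task gradients, keep the few that conflict most with the domain data and average them into g_ref. The function names, the mean-conflict score, and the averaging step are assumptions for this sketch, not the paper's exact recipe:

```python
import numpy as np

def active_anchor(candidate_grads, domain_grads, k=2):
    """Select the k candidate general-task gradients with the highest
    mean conflict (-cos) against the domain gradients, then average
    them into a unit-norm anchor g_ref."""
    def unit(v):
        return v / np.linalg.norm(v)
    d = np.stack([unit(g) for g in domain_grads])           # (m, dim)
    scores = [float((-d @ unit(g)).mean()) for g in candidate_grads]
    top = np.argsort(scores)[-k:]                           # most conflicting
    return unit(sum(candidate_grads[i] for i in top))

domain = [np.array([1.0, 0.0])]
cands = [np.array([1.0, 0.0]), np.array([-1.0, 0.0]), np.array([0.0, 1.0])]
g_ref = active_anchor(cands, domain, k=1)  # picks the opposing gradient
```

Intuitively, anchoring on the most-conflicted general examples makes the orthogonality constraint bind exactly where forgetting is most likely.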
In summary, OGS introduces three key innovations: (1) offline geometric awareness via orthogonal and conflict scores, (2) a Navigator‑Target architecture that transfers costly gradient analysis from the large target model to a small proxy, and (3) an RL‑driven dynamic curriculum that jointly optimizes domain learning and stability. This data‑centric approach delivers the safety guarantees of gradient surgery while retaining the efficiency of conventional data selection, offering a practical, scalable solution for safe domain adaptation of today’s and tomorrow’s massive LLMs.