Beyond Random: Automatic Inner-loop Optimization in Dataset Distillation
The growing demand for efficient deep learning has positioned dataset distillation as a pivotal technique for compressing training datasets while preserving model performance. However, existing inner-loop optimization methods for dataset distillation typically rely on random truncation strategies, which lack flexibility and often yield suboptimal results. In this work, we observe that neural networks exhibit distinct learning dynamics across different training stages (early, middle, and late), making random truncation ineffective. To address this limitation, we propose Automatic Truncated Backpropagation Through Time (AT-BPTT), a novel framework that dynamically adapts both truncation positions and window sizes according to intrinsic gradient behavior. AT-BPTT introduces three key components: (1) a probabilistic mechanism for stage-aware timestep selection, (2) an adaptive window sizing strategy based on gradient variation, and (3) a low-rank Hessian approximation to reduce computational overhead. Extensive experiments on CIFAR-10, CIFAR-100, Tiny-ImageNet, and ImageNet-1K show that AT-BPTT achieves state-of-the-art performance, improving accuracy by an average of 6.16% over baseline methods. Moreover, our approach accelerates inner-loop optimization by 3.9x while saving 63% memory cost.
💡 Research Summary
The paper tackles a fundamental inefficiency in dataset distillation (DD) – the inner‑loop optimization that simulates model training on synthetic data. Existing inner‑loop methods such as BPTT, truncated BPTT (T‑BPTT), and random‑truncated BPTT (RaT‑BPTT) rely on fixed or randomly placed truncation windows to reduce memory consumption, but they ignore the fact that neural networks exhibit distinct learning dynamics at different training stages. The authors first empirically verify that early training steps (preliminary phase) carry large gradients and high variation, middle steps are relatively insensitive to truncation position, and late steps (post phase) have smaller gradients that are crucial for fine‑tuning. Controlled experiments on CIFAR‑10 with a ConvNet show that truncating preliminary steps in the early stage improves validation accuracy by ~2.9%, while truncating post steps in the late stage yields a ~1.8% gain; the truncation position in the middle stage has a negligible effect.
Building on these observations, the authors propose Automatic Truncated Backpropagation Through Time (AT‑BPTT), a framework that dynamically adapts both the truncation position and the window size based on intrinsic gradient information, and that incorporates a low‑rank Hessian approximation to keep computational cost low. The three core components are:
- **Dynamic Truncation Position** – For each timestep t, the L2 norm of the gradient ∥∇_θ L_t∥₂ is normalized with a temperature‑controlled softmax to produce a probability P_trunc(t). In the early stage, timesteps are sampled proportionally to P_trunc(t); in the middle stage, sampling is uniform; in the late stage, timesteps are sampled proportionally to 1 − P_trunc(t). This stage‑aware probability distribution aligns truncation with the learning phase.
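The stage-aware sampling can be sketched in a few lines of numpy. This is a minimal illustration of the mechanism as described above; the function names and the exact normalization details are assumptions, not the authors' implementation:

```python
import numpy as np

def truncation_probs(grad_norms, temperature=1.0):
    """Temperature-controlled softmax over per-timestep gradient L2 norms."""
    z = np.asarray(grad_norms, dtype=float) / temperature
    z -= z.max()                           # subtract max for numerical stability
    p = np.exp(z)
    return p / p.sum()

def sample_truncation_step(grad_norms, stage, temperature=1.0, rng=None):
    """Stage-aware sampling of the truncation timestep.

    stage: 'early'  -> sample proportionally to P_trunc(t)
           'middle' -> sample uniformly
           'late'   -> sample proportionally to 1 - P_trunc(t), renormalized
    """
    rng = rng or np.random.default_rng()
    p_trunc = truncation_probs(grad_norms, temperature)
    if stage == "early":
        p = p_trunc
    elif stage == "middle":
        p = np.full_like(p_trunc, 1.0 / len(p_trunc))
    else:  # 'late'
        q = 1.0 - p_trunc
        p = q / q.sum()
    return int(rng.choice(len(p), p=p))
```

In the early stage this concentrates truncation on high-gradient timesteps, while in the late stage it favors the low-gradient timesteps that the paper identifies as important for fine-tuning.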
- **Adaptive Window Size** – The magnitude of gradient variation |∥∇_θ L_t∥₂ − ∥∇_θ L_{t−1}∥₂| is similarly normalized to obtain a weight η(t). The original window size W (used in RaT‑BPTT) is linearly transformed: W*(t) = W − d + 2d·η(t), where d controls the allowable deviation. Consequently, timesteps with high variation (typically early) receive larger windows, preserving more gradient information, while stable later timesteps receive smaller windows, saving computation.
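The window-size rule W*(t) = W − d + 2d·η(t) maps η(t) ∈ [0, 1] onto windows in [W − d, W + d]. A short numpy sketch, assuming min-max normalization of the gradient variation (the paper normalizes analogously to the softmax above, so this detail is an assumption):

```python
import numpy as np

def adaptive_window_sizes(grad_norms, base_window, deviation):
    """W*(t) = W - d + 2*d*eta(t), with eta(t) in [0, 1].

    eta(t) is the normalized magnitude of gradient variation
    |g_t - g_{t-1}|; min-max normalization is assumed here for
    illustration.
    """
    g = np.asarray(grad_norms, dtype=float)
    var = np.abs(np.diff(g, prepend=g[0]))      # |g_t - g_{t-1}|, 0 at t = 0
    span = var.max() - var.min()
    eta = (var - var.min()) / span if span > 0 else np.zeros_like(var)
    return base_window - deviation + 2 * deviation * eta
```

Timesteps with the largest gradient variation get the maximal window W + d, and perfectly stable timesteps get the minimal window W − d, matching the intuition described in the item above.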
- **Low‑Rank Hessian Approximation** – Computing full Hessian‑vector products is prohibitive. AT‑BPTT replaces the full Hessian H_j with a low‑rank approximation obtained via random projection, dramatically reducing memory usage (by 63%) while preserving enough second‑order information to guide the meta‑gradient.
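To make the low-rank idea concrete, here is a generic randomized range-finder sketch for a symmetric Hessian accessed only through Hessian-vector products. This is a standard random-projection construction, not necessarily the paper's exact scheme; the full d×d Hessian is never materialized:

```python
import numpy as np

def low_rank_hessian(hvp, dim, rank, rng=None):
    """Randomized low-rank approximation of a symmetric Hessian.

    hvp(v) returns the Hessian-vector product H @ v. Only 2*rank HVP
    calls are needed; storage is O(dim * rank) instead of O(dim**2).
    """
    rng = rng or np.random.default_rng(0)
    omega = rng.standard_normal((dim, rank))      # random test matrix
    Y = np.column_stack([hvp(omega[:, j]) for j in range(rank)])
    Q, _ = np.linalg.qr(Y)                        # orthonormal basis of range(H @ omega)
    B = Q.T @ np.column_stack([hvp(Q[:, j]) for j in range(rank)])
    def approx_hvp(v):
        # H v ~= Q B Q^T v  (exact when rank >= rank(H))
        return Q @ (B @ (Q.T @ v))
    return approx_hvp
```

The returned closure can stand in for the full Hessian-vector product inside the meta-gradient computation, which is where the memory savings come from.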
The authors evaluate AT‑BPTT on four benchmarks: CIFAR‑10, CIFAR‑100, Tiny‑ImageNet, and ImageNet‑1K, using standard ConvNet architectures. Compared to the strongest prior inner‑loop baseline (RaT‑BPTT), AT‑BPTT achieves an average absolute accuracy improvement of 6.16% (e.g., a 4% gain on ImageNet‑1K). In addition to performance gains, the method speeds up inner‑loop training by 3.9× and reduces memory consumption by 63%, making it practical for large‑scale DD tasks.
In summary, AT‑BPTT introduces a principled, data‑driven way to schedule truncation in the inner‑loop of dataset distillation. By leveraging gradient magnitude for stage identification, gradient variation for window sizing, and low‑rank Hessian approximations for efficiency, it overcomes the limitations of random truncation, delivering both higher distilled‑data quality and substantial computational savings.