Rod Flow: A Continuous-Time Model for Gradient Descent at the Edge of Stability
How can we understand gradient-based training over non-convex landscapes? The edge of stability phenomenon, introduced in Cohen et al. (2021), indicates that the answer is not so simple: gradient descent (GD) with large step sizes often diverges away from the gradient flow. In this regime, the “Central Flow”, recently proposed in Cohen et al. (2025), provides an accurate ODE approximation to the GD dynamics over many architectures. In this work, we propose Rod Flow, an alternative ODE approximation, which carries the following advantages: (1) it rests on a principled derivation stemming from a physical picture of GD iterates as an extended one-dimensional object – a “rod”; (2) it better captures GD dynamics on simple toy examples and matches the accuracy of Central Flow on representative neural network architectures; and (3) it is explicit and cheap to compute. Theoretically, we prove that Rod Flow correctly predicts the critical sharpness threshold and explains self-stabilization in quartic potentials. We validate our theory with a range of numerical experiments.
💡 Research Summary
The paper tackles the puzzling “edge of stability” phenomenon observed when training neural networks with full‑batch gradient descent (GD) at large learning rates. Empirically, the sharpness (largest eigenvalue of the Hessian) rises until it reaches the critical value \(2/\eta\) and then hovers there, while the loss continues to decrease on average. Classical continuous‑time analysis based on gradient flow (the ODE \(\dot w = -\nabla L(w)\)) fails in this regime because it cannot capture the rapid oscillations that occur along the sharpest direction. A recent proposal, Central Flow (Cohen et al., 2025), models the time‑averaged trajectory by tracking a covariance matrix \(\Sigma\) constrained by a semidefinite complementarity problem, but its derivation is heuristic and its computational cost is high.
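The \(2/\eta\) threshold is easy to see on a quadratic, where the Hessian (and hence the sharpness) is a constant. The following minimal sketch (not from the paper; the function name and parameters are illustrative) runs GD on \(L(w) = \tfrac{\lambda}{2} w^2\), whose update \(w \leftarrow (1 - \eta\lambda)\,w\) is stable exactly when \(\lambda < 2/\eta\):

```python
import numpy as np

def gd_final_abs(lam, eta, w0=1.0, steps=100):
    """Run GD on L(w) = (lam/2) * w**2 and return the final |w|.

    The update is w <- w - eta * lam * w = (1 - eta*lam) * w,
    so the iterates contract iff |1 - eta*lam| < 1, i.e. lam < 2/eta.
    """
    w = w0
    for _ in range(steps):
        w -= eta * lam * w
    return abs(w)

eta = 0.1                         # critical sharpness is 2/eta = 20
below = gd_final_abs(19.0, eta)   # sharpness just below 2/eta: oscillates but converges
above = gd_final_abs(21.0, eta)   # sharpness just above 2/eta: oscillates and diverges
print(below, above)
```

With `lam = 19` the multiplier is \(-0.9\), so the iterates flip sign every step yet shrink; with `lam = 21` it is \(-1.1\) and they blow up. At the edge of stability, the sharpness of the actual (non-quadratic) loss hovers at this boundary instead of crossing it.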
The authors introduce Rod Flow, a new ODE model derived from first principles. They rewrite the GD update in terms of the center \(\bar w_t = (w_{t+1}+w_t)/2\) and the half‑difference \(\delta_t = (w_{t+1}-w_t)/2\). The outer product \(\delta_t\otimes\delta_t\) is denoted \(\Sigma_t\) and interpreted as the “extent” of a one‑dimensional rod whose endpoints sample the loss at \(\bar w_t\pm\delta_t\). This representation is crucial: while \(\delta_t\) flips sign at each step when the system is unstable, \(\delta_t\otimes\delta_t\) remains positive semidefinite and varies smoothly, making it amenable to ODE approximation.
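This smoothing effect can be checked numerically in one dimension, where \(\Sigma_t\) reduces to the scalar \(\delta_t^2\). The sketch below (illustrative, not the paper's code) runs GD on a quadratic just inside the stability boundary, where the iterates oscillate: \(\delta_t\) flips sign at every step, while \(\Sigma_t = \delta_t^2\) shrinks by the same smooth factor each step.

```python
import numpy as np

eta, lam = 0.1, 19.9            # eta * lam = 1.99, just below the threshold 2
w = 1.0
deltas, sigmas = [], []
for _ in range(50):
    w_next = w - eta * lam * w              # plain GD step on L(w) = (lam/2) * w**2
    deltas.append((w_next - w) / 2.0)       # half-difference delta_t
    sigmas.append(((w_next - w) / 2.0)**2)  # 1-D analogue of Sigma_t = delta_t (x) delta_t
    w = w_next

signs = np.sign(deltas)
flips = bool(np.all(signs[1:] == -signs[:-1]))   # delta_t alternates sign every step
ratios = np.array(sigmas[1:]) / np.array(sigmas[:-1])
smooth = bool(np.allclose(ratios, ratios[0]))    # Sigma_t decays by a constant factor
print(flips, smooth)
```

The sign-flipping sequence \(\delta_t\) has no sensible continuous limit, but the smoothly varying \(\Sigma_t\) does, which is exactly why the rod representation is the right object to approximate with an ODE.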
Using exact difference equations for \(\bar w_t\) and \(\Sigma_t\), the authors apply backward error analysis to obtain second‑order corrections to the naive continuous limit, yielding explicit coupled ODEs for the rod's center and extent.