SimpleGPT: Improving GPT via a Simple Normalization Strategy
In this work, we revisit Transformer optimization through the lens of second-order geometry and establish a direct connection between architectural design, activation scale, the Hessian matrix, and the maximum tolerable learning rate. We introduce a simple normalization strategy, termed SimpleNorm, which stabilizes intermediate activation scales by construction. Then, by analyzing the Hessian of the loss with respect to network activations, we theoretically show that SimpleNorm significantly reduces the spectral norm of the Hessian, thereby permitting larger stable learning rates. We validate our theoretical findings through extensive experiments on large GPT models at parameter scales of 1B, 1.4B, 7B, and 8B. Empirically, SimpleGPT, our SimpleNorm-based network, tolerates learning rates 3$\times$-10$\times$ larger than conventional settings, consistently demonstrates strong optimization stability, and achieves substantially better performance than well-established baselines. Specifically, when training 7B-scale models for 60K steps, SimpleGPT achieves a training loss that is 0.08 lower than that of LLaMA2 with QKNorm, reducing the loss from 2.290 to 2.208. Our source code will be released at https://github.com/Ocram7/SimpleGPT.
💡 Research Summary
The paper revisits the optimization of large Transformer‑based language models from a second‑order geometry perspective, establishing a concrete link between architectural choices, activation scaling, the Hessian of the loss, and the maximum stable learning rate. Classical smoothness theory tells us that for a β‑smooth objective the gradient descent step size must satisfy η ≤ 2/β, where β equals the supremum of the Hessian spectral norm. Existing design heuristics—residual scaling, LayerNorm, RMSNorm, DeepNorm—are largely justified empirically, without a clear quantitative connection to the Hessian.
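The η ≤ 2/β threshold is easy to see on a one-dimensional quadratic, where β is exactly the Hessian. The following sketch (not from the paper; the step counts and tolerances are illustrative) shows gradient descent converging just below the threshold and diverging just above it:

```python
# For f(x) = 0.5 * beta * x**2, the Hessian (and smoothness constant) is beta.
# One GD step gives x <- x - eta * beta * x = (1 - eta * beta) * x, which
# contracts iff |1 - eta * beta| < 1, i.e. iff eta < 2 / beta.
beta = 4.0  # illustrative smoothness constant, so the threshold is 2/beta = 0.5

def run_gd(eta, steps=50, x0=1.0):
    """Run plain gradient descent on f and return the final distance to the optimum."""
    x = x0
    for _ in range(steps):
        x -= eta * beta * x
    return abs(x)

# eta = 0.4 < 0.5 shrinks |x| by a factor |1 - 1.6| = 0.6 per step (converges);
# eta = 0.6 > 0.5 grows it by |1 - 2.4| = 1.4 per step (diverges).
```

This is the mechanism behind the paper's claim: shrinking the Hessian spectral norm (β) directly raises the largest learning rate that still contracts.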
The authors introduce SimpleNorm, a minimalist normalization strategy that inserts a normalization operator immediately after every linear projection. Formally, SimpleNorm is defined as Ψ(x)=Norm(Wx), where Norm is instantiated as RMSNorm. The resulting mapping can be written as y = √d · γ ⊙ (Wx)/‖Wx‖₂, with γ a learnable per‑dimension scale. By construction, the output norm is bounded between γ_min √d and γ_max √d, guaranteeing that intermediate activations stay on the order of √d regardless of depth or weight growth.
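The Ψ(x) = Norm(Wx) mapping above is short enough to write out directly. Here is a minimal numpy sketch (not the authors' implementation; the `eps` guard against a zero-norm projection is an assumption) that makes the norm bound concrete:

```python
import numpy as np

def simplenorm(W, x, gamma, eps=1e-8):
    """SimpleNorm operator Psi(x) = RMSNorm(W x): normalize immediately
    after the linear projection.

    Output: y = sqrt(d) * gamma * (W x) / ||W x||_2, so ||y||_2 lies in
    [min|gamma| * sqrt(d), max|gamma| * sqrt(d)] no matter how large
    ||W|| or ||x|| become.
    """
    h = W @ x
    d = h.shape[0]
    return np.sqrt(d) * gamma * h / (np.linalg.norm(h) + eps)
```

With γ initialized to ones, the output norm is exactly √d even if the weights are scaled by an arbitrary factor, which is the "bounded by construction" property used in the analysis.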
A detailed Hessian analysis follows. The gradient with respect to the input x is ∇_xℓ = (√d/s) Wᵀ P D g_y, where s = ‖Wx‖₂, u = Wx/s, P = I − uuᵀ, and D = Diag(γ). The Hessian decomposes into a Gauss‑Newton term L = Jᵀ H_yy J (J being the Jacobian of Ψ) and a curvature term C arising from the normalization itself. Under high‑dimensional, non‑pathological assumptions, the authors prove (Theorem 4.1) that ‖C‖₂ is asymptotically negligible compared to ‖L‖₂; specifically, ‖L‖₂ ≈ τ κ² ‖H_yy‖₂ while ‖C‖₂ ≤ (3κ²/√d) ‖g_y‖₂, with κ = Θ(1). Consequently, the overall Hessian spectral norm of a SimpleNorm layer is far smaller than that of a plain linear layer, whose curvature scales as ‖W‖₂². This makes the loss function effectively smoother (β is reduced), allowing substantially larger learning rates without risking divergence.
Empirically, SimpleNorm is integrated into a GPT‑style architecture dubbed SimpleGPT. The authors replace every linear projection (Q, K, V, O, MLP weights) with the Ψ operator and discard the conventional pre‑normalization scaffold. Experiments span models of 1 B, 1.4 B, 7 B, and 8 B parameters, trained on the same token budget and schedule as strong baselines (LLaMA2 with QKNorm, DeepNorm, etc.). SimpleGPT tolerates learning rates 3–10× larger than the baselines while maintaining stable training dynamics. On a 7 B model trained for 60 K steps, SimpleGPT achieves a final training loss of 2.208 versus 2.290 for LLaMA2‑QKNorm, a reduction of 0.08. Additional diagnostics show lower gradient norms in early training, confirming that the smoother landscape translates into more controlled parameter updates.
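To illustrate the "replace every linear projection with Ψ and drop the pre-norm scaffold" recipe, here is a hypothetical numpy sketch of an MLP block built from Ψ layers. The class names, the 4× hidden expansion, and the ReLU are my illustrative assumptions, not details confirmed by the paper:

```python
import numpy as np

class SimpleNormLinear:
    """Hypothetical sketch of the Psi operator: a linear projection followed
    immediately by RMSNorm, used in place of each plain linear layer."""
    def __init__(self, d_in, d_out, rng):
        self.W = rng.normal(size=(d_out, d_in)) / np.sqrt(d_in)
        self.gamma = np.ones(d_out)  # learnable per-dimension scale

    def __call__(self, x, eps=1e-8):
        h = self.W @ x
        return np.sqrt(h.shape[0]) * self.gamma * h / (np.linalg.norm(h) + eps)

class SimpleMLPBlock:
    """An MLP block built from two Psi layers. No separate pre-normalization
    is needed: every projection output is already normalized by construction."""
    def __init__(self, d, rng):
        self.up = SimpleNormLinear(d, 4 * d, rng)    # 4x expansion (assumed)
        self.down = SimpleNormLinear(4 * d, d, rng)

    def __call__(self, x):
        # ReLU is a stand-in; the paper's activation choice is not specified here.
        return x + self.down(np.maximum(self.up(x), 0.0))
```

Because each Ψ output norm is pinned near √d, the residual stream receives updates of controlled magnitude at every depth, which is the property the larger learning rates rely on.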
The paper’s contribution is twofold: (1) it provides a rigorous second‑order justification for why placing normalization directly after linear maps improves optimization stability, and (2) it demonstrates that this simple design yields practical benefits at the scale of billions of parameters with negligible computational overhead. The authors release their code at https://github.com/Ocram7/SimpleGPT, inviting further exploration of normalization placement and its impact on Hessian geometry in future large‑scale language model research.