Optimal scaling laws in learning hierarchical multi-index models
In this work, we provide a sharp theory of scaling laws for two-layer neural networks trained on a class of hierarchical multi-index targets, in a genuinely representation-limited regime. We derive exact information-theoretic scaling laws for subspace recovery and prediction error, revealing how the hierarchical features of the target are sequentially learned through a cascade of phase transitions. We further show that these optimal rates are achieved by a simple, target-agnostic spectral estimator, which can be interpreted as the small learning-rate limit of gradient descent on the first-layer weights. Once an adapted representation is identified, the readout can be learned statistically optimally, using an efficient procedure. As a consequence, we provide a unified and rigorous explanation of scaling laws, plateau phenomena, and spectral structure in shallow neural networks trained on such hierarchical targets.
💡 Research Summary
This paper develops a rigorous theory of scaling laws for two‑layer neural networks trained on hierarchical multi‑index (HMI) targets in a genuinely representation‑limited regime. An HMI target is defined as f★(x)=∑_{k=1}^{m★} a_k g_k(⟨w★_k,x⟩) where the directions w★_k are orthogonal, the coefficients a_k decay as k^{−γ} (γ>0), and each g_k is an even, smooth nonlinearity. The data consist of Gaussian inputs x∈ℝ^d and noisy labels y=f★(x)+√Δ ξ.
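The generative model above can be sketched as follows. This is an illustrative sketch under stated assumptions, not the paper's experimental setup. We use the same link g_k(z)=z²−1 (the second Hermite polynomial, which is even and smooth) for every k and take a_k = k^{−γ} exactly. The function names `hermite2` and `sample_hmi` are hypothetical.

```python
import numpy as np

def hermite2(z):
    # Assumed link: He_2(z) = z^2 - 1 (even, smooth, zero mean under N(0, 1))
    return z**2 - 1.0

def sample_hmi(n, d, m_star, gamma, delta, rng):
    """Draw n samples (x, y) from a hierarchical multi-index target (illustrative sketch)."""
    # Orthonormal directions w*_k: columns of a random d x m_star orthonormal matrix
    q, _ = np.linalg.qr(rng.standard_normal((d, m_star)))
    w_star = q.T                                           # shape (m_star, d), orthonormal rows
    a = np.arange(1, m_star + 1, dtype=float) ** (-gamma)  # a_k = k^{-gamma}
    x = rng.standard_normal((n, d))                        # Gaussian inputs
    proj = x @ w_star.T                                    # <w*_k, x>, shape (n, m_star)
    f_star = hermite2(proj) @ a                            # sum_k a_k g_k(<w*_k, x>)
    y = f_star + np.sqrt(delta) * rng.standard_normal(n)   # y = f*(x) + sqrt(Delta) * xi
    return x, y, w_star, a
```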
The authors first derive information‑theoretic lower bounds on the weighted mean‑squared error (MSE_γ) for recovering the subspace span(W★). Using Approximate Message Passing (AMP) and its state evolution, they compute the Bayes‑optimal MMSE_γ and show that, when the sample size n=α d with α≫1 and all m★ directions have been detected, the error scales as
MMSE_γ ≈ ∑_{k=1}^{m★} a_k^2 · d/(n a_k^2) = m★ d/n,
where index k becomes statistically detectable only after n_k≈k^{2γ} d samples (equivalently, once n a_k^2 ≳ d, since a_k^2 = k^{−2γ}). This reveals a cascade of sharp phase transitions: each time n crosses a threshold n_k, a new feature emerges and the prediction error drops abruptly, while between transitions the error plateaus. The derived rates coincide with the minimax rates for quasi‑sparse recovery (as achieved, e.g., by the LASSO) and extend earlier conjectures for quadratic networks to a broad class of nonlinear targets.
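The threshold arithmetic behind the cascade can be made concrete with a few lines. This sketch only encodes the scalings stated in the summary (n_k ≈ k^{2γ} d and, once every direction is detected, MMSE_γ ≈ m★ d/n) with all constant prefactors dropped; the function names are hypothetical and the numbers are not the paper's exact curves.

```python
import numpy as np

def detection_thresholds(d, m_star, gamma):
    """Heuristic thresholds n_k ~ k^(2*gamma) * d (constant prefactors ignored)."""
    k = np.arange(1, m_star + 1, dtype=float)
    return d * k ** (2 * gamma)

def num_detected(n, d, m_star, gamma):
    """Number of indices k whose threshold n_k has been crossed at sample size n."""
    return int(np.count_nonzero(detection_thresholds(d, m_star, gamma) <= n))

def mmse_all_detected(n, d, m_star, gamma):
    """Large-n rate m_star * d / n; only meaningful once n exceeds every n_k."""
    if num_detected(n, d, m_star, gamma) < m_star:
        raise ValueError("n is below the last threshold n_{m_star}")
    return m_star * d / n

# Staircase of detected features for d=100, m_star=4, gamma=1 (thresholds 100, 400, 900, 1600):
# prints 0, 1, 2, 3, 4 detected features for the five sample sizes below
for n in [50, 150, 500, 1000, 2000]:
    print(n, num_detected(n, 100, 4, 1.0))
```

The step function returned by `num_detected` is the discrete counterpart of the plateaus and abrupt drops described above: nothing changes between consecutive thresholds, and one new feature appears at each n_k.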
Next, the paper introduces a simple, target‑agnostic spectral estimator. It computes the top eigenvectors of a suitably constructed data‑label covariance matrix and maps them to estimates of the directions w★_k. The estimator can be interpreted as the infinitesimal‑learning‑rate limit of gradient descent on the first‑layer weights, or equivalently as an AMP‑derived optimal spectral method. The authors prove that this estimator achieves the Bayes‑optimal MMSE_γ, thereby attaining the optimal scaling laws. Moreover, the estimator’s performance exhibits the same hierarchical phase‑transition pattern predicted by the information‑theoretic analysis.
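As a rough illustration of a target‑agnostic spectral method, the sketch below takes the top eigenvectors (by eigenvalue magnitude) of M = (1/n) ∑_i y_i (x_i x_iᵀ − I), a standard second‑order data‑label matrix that responds to even link functions. This is an assumption. It is not the paper's AMP‑derived optimal matrix, which is not reproduced here, and `spectral_estimator` is a hypothetical name.

```python
import numpy as np

def spectral_estimator(x, y, m):
    """Estimate an (m, d) basis of the target subspace from the top eigenvectors of
    M = (1/n) sum_i y_i (x_i x_i^T - I)  (a hypothetical choice of data-label matrix)."""
    n, d = x.shape
    y_c = y - y.mean()                    # centring the labels makes the -I term vanish
    m_mat = (x.T * y_c) @ x / n           # (1/n) sum_i y_c,i x_i x_i^T, shape (d, d)
    evals, evecs = np.linalg.eigh(m_mat)  # symmetric eigendecomposition, ascending order
    order = np.argsort(-np.abs(evals))    # rank by magnitude (a link may give a negative spike)
    return evecs[:, order[:m]].T          # orthonormal rows spanning the estimated subspace
```

Only (x, y) enter the computation; the link functions g_k are never used, which is the sense in which such an estimator is target‑agnostic. Subspace quality can be measured by the normalized overlap ‖Ŵ W★ᵀ‖_F² / m, which equals 1 for perfect recovery.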
Finally, the authors propose a two‑stage training procedure for two‑layer networks f(x;Θ)=a^⊤σ(Wx+b). In the first stage, the spectral estimator provides an accurate representation Ŵ of the target subspace without any knowledge of the link functions g_k. In the second stage, with the first layer held fixed, the readout weights a are learned by ordinary least squares (or regularized regression) on the resulting features. The excess risk of the resulting network matches the optimal MMSE_γ rate, showing that no additional statistical penalty is incurred once the representation is learned. This holds even in the presence of label noise and under regularization.
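The two‑stage procedure can be sketched as below, under explicit assumptions: a ReLU activation, first‑layer rows drawn as random Gaussian mixtures of the recovered directions Ŵ, fixed random biases b, and a ridge‑regularized least‑squares readout. The summary does not specify how W is built from Ŵ, so that construction and the names `fit_readout` and `predict` are hypothetical.

```python
import numpy as np

def fit_readout(x, y, w_hat, width, ridge, rng):
    """Second stage: freeze first-layer weights inside span(w_hat), fit readout a by ridge regression."""
    n = x.shape[0]
    m = w_hat.shape[0]
    # Hypothetical construction: each hidden unit is a random mixture of the recovered directions
    w = rng.standard_normal((width, m)) @ w_hat    # (width, d), rows lie in span(w_hat)
    b = rng.standard_normal(width)                 # random biases, kept fixed
    feats = np.maximum(x @ w.T + b, 0.0)           # sigma = ReLU, shape (n, width)
    # Ridge solution a = (F^T F + ridge * n * I)^{-1} F^T y
    gram = feats.T @ feats + ridge * n * np.eye(width)
    a = np.linalg.solve(gram, feats.T @ y)
    return w, b, a

def predict(x, w, b, a):
    """Network output a^T sigma(W x + b) for each row of x."""
    return np.maximum(x @ w.T + b, 0.0) @ a
```

Because the first layer is frozen, the second stage is a convex problem that is linear in a, which is why the readout can be fitted efficiently once a good representation is available.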
Overall, the paper makes three major contributions: (1) exact Bayes‑optimal scaling laws for subspace recovery in hierarchical multi‑index models; (2) a provably optimal spectral algorithm that attains these limits and explains the observed plateaus and abrupt drops in learning curves; (3) a demonstration that two‑layer neural networks, trained via a simple two‑step procedure, can achieve these optimal rates, thereby providing the first rigorous, non‑lazy, feature‑learning scaling theory for shallow networks. The work bridges empirical observations of progressive concept learning, spectral structure of trained weights, and neural scaling laws with a solid mathematical foundation.