Efficient Exact Gradient Update for training Deep Networks with Very Large Sparse Targets
An important class of problems involves training deep neural networks with sparse prediction targets of very high dimension D. These occur naturally in e.g. neural language models or the learning of word-embeddings, often posed as predicting the probability of next words among a vocabulary of size D (e.g. 200 000). Computing the equally large, but typically non-sparse D-dimensional output vector from a last hidden layer of reasonable dimension d (e.g. 500) incurs a prohibitive O(Dd) computational cost for each example, as does updating the D × d output weight matrix and computing the gradient needed for backpropagation to previous layers. While efficient handling of large sparse network inputs is trivial, the case of large sparse targets is not, and has thus so far been sidestepped with approximate alternatives such as hierarchical softmax or sampling-based approximations during training. In this work we develop an original algorithmic approach which, for a family of loss functions that includes squared error and spherical softmax, can compute the exact loss, gradient update for the output weights, and gradient for backpropagation, all in O(d²) per example instead of O(Dd), remarkably without ever computing the D-dimensional output. The proposed algorithm yields a speedup of D/4d, i.e. two orders of magnitude for typical sizes, for that critical part of the computations that often dominates the training time in this kind of network architecture.
💡 Research Summary
The paper addresses a fundamental bottleneck in training deep neural networks whose output targets are extremely high‑dimensional (size D ≈ 10⁵–10⁶) but sparse (only K ≪ D non‑zero entries). In conventional architectures the last hidden layer h ∈ ℝᵈ is multiplied by a weight matrix W ∈ ℝᴰˣᵈ to produce a dense output o = Wh, after which a loss (e.g., squared error or softmax) is evaluated. Computing Wh and the corresponding gradient ∂L/∂W requires O(D·d) operations and memory accesses, which becomes prohibitive for large vocabularies or item sets.
The authors propose a mathematically exact alternative that avoids ever forming the D‑dimensional output. The key ideas are:
- Factorisation of the output weight matrix: W is represented implicitly as the product V U, where U ∈ ℝᵈˣᵈ is a small square matrix (kept invertible) and V ∈ ℝᴰˣᵈ is a large but cheaply updatable factor; W itself is never formed explicitly.
- Maintenance of a compact Gram matrix: Q = WᵀW ∈ ℝᵈˣᵈ is kept up to date with O(d²) rank‑one corrections after each parameter update. Because Q is only d × d, all operations involving it are cheap.
With these definitions the squared‑error loss can be rewritten as
L = ‖Wh − y‖² = hᵀQh − 2 hᵀUᵀ(Vᵀy) + yᵀy.
All terms are now either O(d²) (the quadratic form hᵀQh) or O(K·d) (the product Vᵀy, which touches only the K non‑zero entries of y). Consequently the loss and the gradient with respect to the hidden representation, ∇ₕL = 2(Qh − Uᵀ(Vᵀy)), are computed without ever materialising Wh. Moreover, because the gradient with respect to W has the rank‑one form ∂L/∂W = 2(Wh − y)hᵀ, the corresponding weight update can be applied implicitly through U and V, without forming either Wh or the D × d gradient matrix.
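The loss and hidden‑gradient computation above can be sketched in NumPy. This is a minimal illustration, not the authors' implementation: all sizes, variable names, and the random initialisation are hypothetical, and Q is formed from scratch here for clarity, whereas the real algorithm maintains it incrementally.

```python
import numpy as np

rng = np.random.default_rng(0)
D, d, K = 10_000, 50, 5          # large output dim, small hidden dim, K-sparse target

# Hypothetical factored parameters: W = V @ U  (V: D x d, U: d x d)
U = rng.standard_normal((d, d)) / np.sqrt(d)
V = rng.standard_normal((D, d)) / np.sqrt(d)
Q = (V @ U).T @ (V @ U)          # Q = WᵀW; maintained incrementally in the real algorithm

h = rng.standard_normal(d)

# K-sparse target y, stored as (indices, values)
idx = rng.choice(D, size=K, replace=False)
val = rng.standard_normal(K)

# O(d²) + O(K·d): loss and ∇ₕL without ever forming the D-dimensional output Wh
Vty = V[idx].T @ val             # Vᵀy touches only the K non-zero rows of V
Wty = U.T @ Vty                  # Wᵀy
loss_fast = h @ Q @ h - 2 * h @ Wty + val @ val
grad_h_fast = 2 * (Q @ h - Wty)  # ∇ₕL = 2(Qh − Wᵀy)

# Naive O(D·d) reference, for checking only
y = np.zeros(D)
y[idx] = val
r = V @ (U @ h) - y              # residual Wh − y
assert np.allclose(loss_fast, r @ r)
assert np.allclose(grad_h_fast, 2 * (V @ U).T @ r)
```

The assertions confirm that the factored O(d² + K·d) path agrees with the naive dense computation to floating‑point precision.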
The gradient update for W is decomposed into two cheap updates:
- U‑update: U ← U − 2η (Uh)hᵀ (O(d²)).
- V‑update: V ← V + 2η y ĥᵀ, where ĥ = (U_newᵀ)⁻¹ h. Since y has only K non‑zeros, this touches just K rows of V (O(K·d), plus O(d²) to obtain ĥ when the inverse of U is maintained incrementally alongside U).
After the updates, Q is refreshed with a rank‑one (Sherman‑Morrison‑type) correction, also in O(d²). The overall per‑example cost becomes O(d² + K·d), which for the common regime K ≈ d is a small constant multiple of d² operations, a dramatic reduction compared with the naïve O(D·d) cost. The theoretical speed‑up factor is D/(4d); for D = 200 000 and d = 500 the authors report a 100× acceleration. Memory traffic is similarly reduced: only the K touched rows of V (K·d entries) and a few d × d matrices (U, Q, and the bookkeeping needed to apply U⁻ᵀ) are accessed per example, instead of the full D·d weight matrix.
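One full implicit SGD step can be sketched as follows. This is a didactic NumPy sketch under the summary's notation, not the paper's code: it obtains ĥ with a direct d × d solve (O(d³), cheap for small d) where the real algorithm would maintain U⁻ᵀ incrementally in O(d²), and it keeps a dense W only to verify the result.

```python
import numpy as np

rng = np.random.default_rng(1)
D, d, K, eta = 2_000, 40, 5, 0.01    # hypothetical sizes and learning rate

U = np.eye(d) + 0.1 * rng.standard_normal((d, d))   # kept well-conditioned
V = rng.standard_normal((D, d)) / np.sqrt(d)
W = V @ U                             # dense reference, for checking only
Q = W.T @ W

h = rng.standard_normal(d)
idx = rng.choice(D, size=K, replace=False)
val = rng.standard_normal(K)

# Quantities reused by all three updates, O(d² + K·d)
Wty = U.T @ (V[idx].T @ val)          # Wᵀy
Wtr = Q @ h - Wty                     # Wᵀ(Wh − y)
rr = h @ Q @ h - 2 * h @ Wty + val @ val   # ‖Wh − y‖²  (the squared-error loss)

# 1) U-update, O(d²):  U ← U − 2η (Uh) hᵀ
U_new = U - 2 * eta * np.outer(U @ h, h)

# 2) V-update: V ← V + 2η y ĥᵀ with ĥ = (U_newᵀ)⁻¹ h; only the K rows of V
#    where y is non-zero are touched
hhat = np.linalg.solve(U_new.T, h)
V_new = V.copy()
V_new[idx] += 2 * eta * np.outer(val, hhat)

# 3) Q-refresh, O(d²):  Q ← Q − 2η(Wᵀr hᵀ + h (Wᵀr)ᵀ) + 4η² ‖r‖² h hᵀ
Q_new = Q - 2 * eta * (np.outer(Wtr, h) + np.outer(h, Wtr)) \
          + 4 * eta**2 * rr * np.outer(h, h)

# Check against the naive dense step  W ← W − 2η (Wh − y) hᵀ
y = np.zeros(D)
y[idx] = val
W_ref = W - 2 * eta * np.outer(W @ h - y, h)
assert np.allclose(V_new @ U_new, W_ref)
assert np.allclose(Q_new, W_ref.T @ W_ref)
```

The two assertions verify the exactness claim: the factored update reproduces the naive dense SGD step and its Gram matrix, while only ever manipulating d × d quantities and K rows of V.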
The method applies to a family of loss functions that depend only on the squared norm of the full output and on the output components at the non‑zero target entries. This includes ordinary squared error and the “spherical softmax”, which normalises squared outputs, p_j = o_j² / ∑_k o_k², and therefore yields proper class probabilities while satisfying the required algebraic form. The standard (exponential) softmax is not covered, which is a limitation for many classification tasks.
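Why the spherical softmax fits the required form can be seen directly: its negative log‑likelihood needs only the target component o_c and the total squared norm ∑_k o_k² = hᵀQh. A minimal NumPy sketch, with all sizes and the class index hypothetical:

```python
import numpy as np

rng = np.random.default_rng(2)
D, d = 5_000, 30
U = rng.standard_normal((d, d)) / np.sqrt(d)
V = rng.standard_normal((D, d)) / np.sqrt(d)
Q = (V @ U).T @ (V @ U)      # Q = WᵀW; maintained incrementally in the real algorithm

h = rng.standard_normal(d)
c = 123                      # target class index (hypothetical)

# Spherical softmax: p_j = o_j² / Σ_k o_k².  Its negative log-likelihood needs
# only the target component o_c (O(d²) for Uh, then O(d) for one row of V) and
# the total squared norm Σ_k o_k² = hᵀQh (O(d²)) — never the full D-dim output.
Uh = U @ h
o_c = V[c] @ Uh
nll_fast = -np.log(o_c**2) + np.log(h @ Q @ h)

# Naive O(D·d) reference
o = V @ Uh
assert np.allclose(nll_fast, -np.log(o[c]**2 / np.sum(o**2)))
```

The same structure fails for the standard softmax, whose normaliser ∑_k exp(o_k) cannot be recovered from ‖o‖² and a few individual components.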
The paper also sketches extensions to minibatch training. For a batch of m hidden vectors H ∈ ℝᵈˣᵐ, the U, V and Q updates become rank‑m corrections that can be applied collectively in O(m·d²) time. However, if m grows comparable to d, recomputing the required quantities from scratch (e.g. solving the associated d × d linear systems directly) can become more efficient than chaining m incremental Sherman‑Morrison updates.
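The rank‑m batch form can be sketched as below; again a hypothetical NumPy illustration (targets are stored densely only so the check is easy to write; the real update would touch only the non‑zero rows of Y):

```python
import numpy as np

rng = np.random.default_rng(3)
D, d, m, eta = 1_000, 20, 8, 0.005
U = np.eye(d) + 0.1 * rng.standard_normal((d, d))
V = rng.standard_normal((D, d)) / np.sqrt(d)
W = V @ U                                 # dense reference, for checking only
Q = W.T @ W

H = rng.standard_normal((d, m))           # minibatch of hidden vectors
Y = np.zeros((D, m))                      # sparse targets, K = 5 non-zeros each
for j in range(m):
    Y[rng.choice(D, 5, replace=False), j] = rng.standard_normal(5)

# Batch-gradient step  W ← W − 2η (WH − Y)Hᵀ, applied implicitly as rank-m updates:
U_new = U @ (np.eye(d) - 2 * eta * H @ H.T)          # O(m·d²)
V_new = V + 2 * eta * Y @ np.linalg.solve(U_new.T, H).T   # touches non-zero rows of Y

WtY = U.T @ (V.T @ Y)                     # Wᵀ Y, cheap when Y is sparse
WtR = Q @ H - WtY                         # Wᵀ(WH − Y)
RtR = H.T @ Q @ H - H.T @ WtY - WtY.T @ H + Y.T @ Y  # (WH − Y)ᵀ(WH − Y)
Q_new = Q - 2 * eta * (WtR @ H.T + H @ WtR.T) + 4 * eta**2 * H @ RtR @ H.T

# Check against the naive dense batch step
W_ref = W - 2 * eta * (W @ H - Y) @ H.T
assert np.allclose(V_new @ U_new, W_ref)
assert np.allclose(Q_new, W_ref.T @ W_ref)
```

Note the design trade‑off the summary mentions: here the d × d system is solved from scratch once per batch, which amortises well over m examples, whereas the per‑example algorithm would instead apply m successive rank‑one corrections.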
Experimental validation on a CPU (Intel Core i7) with d = 500 and K ≈ 5 demonstrates that the proposed algorithm yields loss trajectories identical to those of naïve back‑propagation while reducing wall‑clock time by roughly two orders of magnitude and cutting memory usage dramatically.
Strengths
- Exact gradient computation (no approximation) despite the massive output dimension.
- Simple linear‑algebraic derivation; implementation requires only a few matrix‑vector products and rank‑one updates.
- Substantial speed‑up and memory savings for the common K ≪ D, d ≈ K regime.
Limitations
- Restricted to loss functions that can be expressed via the total ℓ₂ norm and sparse inner products; standard softmax is excluded.
- The factorisation W = VU introduces an implicit constraint on the weight space; it is not clear whether this limits representational capacity in practice.
- Extending to large minibatches or to distributed settings requires careful handling of the Q updates.
Future directions suggested include: evaluating the impact of spherical softmax on classification accuracy in large‑scale language models, exploring dynamic rank‑adaptation of U and V, and designing distributed implementations that exploit the small‑matrix nature of the updates.
In summary, the paper delivers a clever algebraic trick that transforms the prohibitive O(D·d) computation of large sparse‑target networks into a tractable O(d²) procedure, opening the door to exact training of models that were previously forced to rely on hierarchical or sampling‑based approximations.