Improving Discrete Optimisation Via Decoupled Straight-Through Estimator
The Straight-Through Estimator (STE) is the dominant method for training neural networks with discrete variables, enabling gradient-based optimisation by routing gradients through a differentiable surrogate. However, existing STE variants conflate two fundamentally distinct concerns: forward-pass stochasticity, which controls exploration and latent-space utilisation, and backward-pass gradient dispersion, i.e., how learning signals are distributed across categories. We show that these concerns are qualitatively different and that tying them to a single temperature parameter leaves significant performance gains untapped. We propose Decoupled Straight-Through (Decoupled ST), a minimal modification that introduces separate temperatures for the forward pass ($τ_f$) and the backward pass ($τ_b$). This simple change enables independent tuning of exploration and gradient dispersion. Across three diverse tasks (Stochastic Binary Networks, Categorical Autoencoders, and Differentiable Logic Gate Networks), Decoupled ST consistently outperforms Identity STE, Softmax STE, and Straight-Through Gumbel-Softmax. Crucially, optimal $(τ_f, τ_b)$ configurations lie far off the diagonal $τ_f = τ_b$, confirming that the two concerns require different settings and that single-temperature methods are fundamentally constrained.
💡 Research Summary
The paper addresses a fundamental limitation of existing Straight‑Through Estimators (STE) used for training neural networks with discrete variables. Current STE variants—Identity STE, Softmax STE, and Straight‑Through Gumbel‑Softmax (ST‑GS)—conflate two distinct design questions: (1) how much stochasticity should be introduced in the forward pass to encourage exploration and latent‑space utilization, and (2) how gradients should be distributed across categories in the backward pass to balance signal strength against the risk of “dead” categories. By tying both aspects to a single temperature parameter, these methods force a compromise that leaves performance on the table.
To resolve this, the authors propose Decoupled Straight‑Through (Decoupled ST), a minimal modification that introduces separate temperatures for the forward pass (τ_f) and the backward pass (τ_b). In the forward pass, logits are divided by τ_f and a categorical sample (or the argmax when τ_f→0) is drawn from the resulting softmax distribution p_f. In the backward pass, a distinct softmax with temperature τ_b yields p_b, and the gradient is propagated through the softmax Jacobian J(τ_b) = (1/τ_b)(diag(p_b) − p_b p_bᵀ). The forward temperature therefore governs only exploration, while the backward temperature governs only how the gradient signal is dispersed across categories.
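The decoupling described above can be sketched as two independent routines: a forward pass that samples from softmax(logits/τ_f), and a backward pass that routes incoming gradients through the softmax Jacobian at τ_b. This is a minimal NumPy illustration of that idea, not the paper's implementation; the function names and the explicit-Jacobian formulation are this sketch's own choices (a real training loop would fuse the two into one autograd-enabled operation).

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    e = np.exp(z - z.max())  # shift for numerical stability
    return e / e.sum()

def decoupled_st_forward(logits, tau_f):
    """Forward pass: draw a one-hot categorical sample from
    p_f = softmax(logits / tau_f); tau_f controls exploration."""
    p_f = softmax(logits / tau_f)
    idx = rng.choice(len(logits), p=p_f)
    one_hot = np.zeros_like(logits)
    one_hot[idx] = 1.0
    return one_hot

def decoupled_st_backward(logits, tau_b, grad_out):
    """Backward pass: propagate grad_out through the softmax Jacobian
    at temperature tau_b,  J = (1/tau_b) (diag(p_b) - p_b p_b^T),
    which controls how the gradient disperses across categories."""
    p_b = softmax(logits / tau_b)
    J = (np.diag(p_b) - np.outer(p_b, p_b)) / tau_b
    return J @ grad_out  # gradient w.r.t. the logits (J is symmetric)
```

Because the forward sample and the backward Jacobian use different temperatures, one can, for example, keep sampling nearly deterministic (small τ_f) while still spreading gradient mass over many categories (large τ_b), which a single shared temperature cannot express.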