Orthogonal Self-Attention
Softmax Self-Attention (SSA) is a key component of Transformer architectures. However, when utilised within skipless architectures, which aim to improve representation learning, recent work has highlighted the inherent instability of SSA, which induces rank collapse and poorly-conditioned Jacobians. In this work, we design a novel attention mechanism, Orthogonal Self-Attention (OSA), which aims to bypass these issues with SSA, allowing (non-causal) Transformers without skip connections or normalisation layers to be trained more easily. In particular, OSA parametrises the attention matrix to be orthogonal by mapping a skew-symmetric matrix, formed from query-key values, through the matrix exponential. We show that this can be practically implemented by exploiting the low-rank structure of our query-key values, so that the computational complexity and memory cost of OSA scale linearly with sequence length. Furthermore, we derive an initialisation scheme which we prove ensures that the Jacobian of OSA is well-conditioned.
💡 Research Summary
The paper addresses a fundamental instability of Softmax Self‑Attention (SSA) when used in deep Transformer architectures that lack skip connections and normalization layers. Prior work has shown that SSA can cause rank collapse—token representations quickly become rank‑one—and produce poorly‑conditioned Jacobians, which hinder training. To overcome these issues, the authors propose Orthogonal Self‑Attention (OSA), a novel attention mechanism that forces the attention matrix to be orthogonal.
Methodology
Given input token matrix X∈ℝ^{N×d}, queries Q and keys K are computed as Q = XW_Q and K = XW_K with projection matrices W_Q, W_K∈ℝ^{d×d_v}. A skew‑symmetric matrix S is formed as
S = α·√(1/d_v)·(QKᵀ – KQᵀ),
where α is a learnable scalar. Because S is skew‑symmetric, its matrix exponential A = exp(S) lies in the special orthogonal group SO(N), guaranteeing orthogonality. The OSA output is then
OSA(X) = A·X·W_V·W_O,
with value and output projections W_V, W_O.
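The construction above can be sketched directly in NumPy/SciPy. This is a naive O(N³) illustration (it forms the full N×N matrix exponential), not the authors' implementation; the function and argument names are assumed:

```python
import numpy as np
from scipy.linalg import expm

def osa(X, W_Q, W_K, W_V, W_O, alpha=1.0):
    """Naive O(N^3) sketch of Orthogonal Self-Attention (names assumed)."""
    Q = X @ W_Q                       # (N, d_v) queries
    K = X @ W_K                       # (N, d_v) keys
    d_v = Q.shape[1]
    # Skew-symmetric generator: S^T = -S by construction
    S = (alpha / np.sqrt(d_v)) * (Q @ K.T - K @ Q.T)
    A = expm(S)                       # attention matrix, orthogonal: A in SO(N)
    return A @ X @ W_V @ W_O
```

Because A is orthogonal, applying it preserves Frobenius norms, which is one way to sanity-check the parametrisation numerically.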
Low‑rank exploitation
Since d_v ≪ N, the rank of S is bounded by r ≤ 2d_v. Theorem 2.1 shows that the exponential of S can be computed via a low‑dimensional matrix: let B∈ℝ^{N×r} be an orthonormal basis for the subspace spanned by the columns of Q and K. Then
exp(S) = I_N + B·(exp(BᵀSB) – I_r)·Bᵀ.
Thus only an r×r exponential (cost O(r³)) is required, reducing overall complexity to O(N·d_v² + d_v³), i.e., linear in sequence length N.
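A sketch of this low-rank route, never forming the N×N matrix S or its exponential: the core M = BᵀSB is computed from the r×d_v coordinates of Q and K in the basis B, and exp(S) is applied to the values via the identity above. Function names are assumed, and `numpy.linalg.qr` stands in for any method of obtaining the basis:

```python
import numpy as np
from scipy.linalg import expm

def osa_lowrank(X, W_Q, W_K, W_V, W_O, alpha=1.0):
    """O(N * d_v^2 + d_v^3) sketch of OSA via the r x r core (Theorem 2.1)."""
    Q, K = X @ W_Q, X @ W_K
    N, d_v = Q.shape
    B, _ = np.linalg.qr(np.hstack([Q, K]))   # (N, r) basis, r <= 2*d_v: O(N d_v^2)
    Qb, Kb = B.T @ Q, B.T @ K                # coordinates of Q, K in the basis
    # r x r core M = B^T S B, still skew-symmetric
    M = (alpha / np.sqrt(d_v)) * (Qb @ Kb.T - Kb @ Qb.T)
    E = expm(M) - np.eye(M.shape[0])         # only an r x r exponential: O(r^3)
    V = X @ W_V
    return (V + B @ (E @ (B.T @ V))) @ W_O   # exp(S) @ V without forming exp(S)
```

For small N this agrees (up to floating-point error) with the naive version that exponentiates the full N×N generator.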
Basis construction
Two practical ways to obtain B are described; one is a reduced QR decomposition of the concatenated matrix [Q K].
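As a small check of the QR route (a sketch, assuming the basis is taken from the concatenated matrix [Q K]): the reduced QR factor has orthonormal columns whose span contains the columns of both Q and K, which is exactly what Theorem 2.1 requires of B.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d_v = 16, 3
Q = rng.standard_normal((N, d_v))
K = rng.standard_normal((N, d_v))
# Reduced QR of the (N, 2*d_v) matrix [Q K]: B is (N, r) with r = 2*d_v here
B, R = np.linalg.qr(np.hstack([Q, K]))
assert np.allclose(B.T @ B, np.eye(B.shape[1]))   # orthonormal columns
# Projecting onto span(B) reproduces Q and K exactly
assert np.allclose(B @ (B.T @ Q), Q)
assert np.allclose(B @ (B.T @ K), K)
```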