The choice of attention mechanism in Transformer models involves a critical trade-off between modeling quality and inference efficiency. Multi-Head Attention (MHA) offers the best quality but suffers from large Key-Value (KV) cache memory requirements during inference. Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) reduce memory usage but often at the cost of model performance. In this work, we propose Mixture of Attention Schemes (MoAS), a novel architecture that dynamically selects the optimal attention scheme (MHA, GQA, or MQA) for each token via a learned router. We demonstrate that dynamic routing performs better than static averaging of schemes and achieves performance competitive with the MHA baseline while offering potential for conditional compute efficiency. Experimental results on WikiText-2 show that dynamic routing (val loss 2.3074) outperforms a static mixture (2.3093), validating the effectiveness of the proposed method. Our code is available at https://github.com/Esmail-ibraheem/Mixture-of-Attention-Schemes-MoAS.
Large Language Models (LLMs) based on the Transformer architecture [13] have achieved remarkable success. However, their deployment is constrained by the memory required to store the Key-Value (KV) cache during autoregressive generation.
Standard Multi-Head Attention (MHA) maintains unique keys and values for every query head, resulting in a memory footprint that scales linearly with the number of heads. To mitigate this, Multi-Query Attention (MQA) [10] shares a single key-value head across all query heads, significantly reducing memory bandwidth and capacity requirements. Grouped-Query Attention (GQA) [1] interpolates between these extremes by grouping query heads to share key-value pairs.
While MQA and GQA offer efficiency gains, they generally underperform MHA in terms of perplexity and downstream task accuracy. We hypothesize that not all tokens require the full expressivity of MHA. Some tokens may be adequately processed with the approximated context of MQA, while others require the fine-grained relationships captured by MHA.
To address this, we introduce Mixture of Attention Schemes (MoAS). Inspired by Mixture-of-Experts (MoE) [11], MoAS employs a router to dynamically weight or select between MHA, MQA, and GQA branches for each token. This allows the model to learn an optimal balance between quality and efficiency.
Efficient Transformers Numerous works have attempted to reduce the quadratic complexity of self-attention. Sparse Transformers [4] and Longformer [2] introduce fixed sparse patterns. Linformer [14] and Reformer [7] utilize low-rank approximations and hashing, respectively. FlashAttention [5] optimizes memory access patterns for hardware efficiency.
Mixture of Experts Conditional computation has been popularized by Mixture-of-Experts (MoE) models like the Switch Transformer [6], which route tokens to different feed-forward networks. Recently, Mixture-of-Depths [9] proposed dynamically allocating compute by routing tokens around blocks entirely. MoAS extends this philosophy specifically to the attention mechanism’s internal structure.
KV Cache Optimization As LLMs scale [3,12], KV cache management becomes critical. PagedAttention [8] optimizes memory allocation. MQA [10] and GQA [1] structurally reduce the cache size. Our work builds directly on these structural innovations.
We define three distinct attention variants as our “experts”:
• Type A: Multi-Head Attention (MHA): $H_Q = H_{KV} = H$. This is the standard mechanism with maximal expressivity.
• Type B: Grouped-Query Attention (GQA): $H_Q = H$, $H_{KV} = G$, where $1 < G < H$. This provides a middle ground. In our experiments, we use $G = 2$ for $H = 6$.
• Type C: Multi-Query Attention (MQA): $H_Q = H$, $H_{KV} = 1$. This minimizes KV cache size but imposes the strongest bottleneck on the attention mechanism.
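The three experts can be expressed as a single attention module parameterized by the number of KV heads. The PyTorch sketch below is illustrative only (names such as CausalSelfAttention and n_kv_heads are ours, not necessarily those of the released code): setting n_kv_heads to $H$, $G$, or 1 yields Type A, B, or C, respectively.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    """Attention with n_heads query heads and n_kv_heads key/value heads.

    n_kv_heads == n_heads      -> MHA (Type A)
    1 < n_kv_heads < n_heads   -> GQA (Type B), e.g. 2 of 6
    n_kv_heads == 1            -> MQA (Type C)
    """
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert d_model % n_heads == 0 and n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim)
        self.kv_proj = nn.Linear(d_model, 2 * n_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        kv = self.kv_proj(x).view(B, T, 2, self.n_kv_heads, self.head_dim)
        k, v = kv.unbind(dim=2)
        k = k.transpose(1, 2)  # (B, n_kv_heads, T, head_dim)
        v = v.transpose(1, 2)
        # Repeat KV heads so each group of query heads shares one KV head.
        rep = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(rep, dim=1)
        v = v.repeat_interleave(rep, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        return self.out_proj(y)
```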
Given an input token representation $x_i \in \mathbb{R}^d$, we compute the output of every scheme:
$$o_i^{(A)} = \mathrm{MHA}(x_i), \qquad o_i^{(B)} = \mathrm{GQA}(x_i), \qquad o_i^{(C)} = \mathrm{MQA}(x_i).$$
A learned router projects the input to a categorical distribution over the schemes:
$$p_i = \mathrm{softmax}\big(W_2\,\sigma(W_1 x_i)\big),$$
where $W_1 \in \mathbb{R}^{d/4 \times d}$ and $W_2 \in \mathbb{R}^{3 \times d/4}$ form a lightweight Multi-Layer Perceptron (MLP) with nonlinearity $\sigma$. The final output $y_i$ for token $i$ is the weighted sum:
$$y_i = \sum_{s \in \{A,B,C\}} p_{i,s}\, o_i^{(s)}.$$
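A minimal sketch of the router and the weighted combination follows. The GELU activation and the names MoASRouter and moas_output are illustrative assumptions; only the projection shapes ($d \rightarrow d/4 \rightarrow 3$) and the softmax are fixed by the equations above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoASRouter(nn.Module):
    """Per-token router: d-dim input -> probabilities over the 3 schemes."""
    def __init__(self, d_model: int, n_schemes: int = 3):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_model // 4)    # W1 in R^{d/4 x d}
        self.w2 = nn.Linear(d_model // 4, n_schemes)  # W2 in R^{3 x d/4}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, d) -> p: (B, T, 3), a categorical distribution per token
        return F.softmax(self.w2(F.gelu(self.w1(x))), dim=-1)

def moas_output(x, mha, gqa, mqa, router):
    """y_i = sum_s p_{i,s} * o_i^{(s)} over schemes s in {MHA, GQA, MQA}."""
    outs = torch.stack([mha(x), gqa(x), mqa(x)], dim=-1)  # (B, T, d, 3)
    p = router(x)                                         # (B, T, 3)
    y = (outs * p.unsqueeze(2)).sum(dim=-1)               # (B, T, d)
    return y, p
```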
To prevent the router from collapsing to a single scheme (e.g., always choosing MHA), we add an auxiliary load-balancing loss.
This encourages uniform usage of all attention types on average across the batch.
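One plausible realization of this auxiliary term, under the assumption that it penalizes the squared deviation of the batch-averaged routing probabilities from the uniform distribution, is sketched below; the exact formula may differ in the released code.

```python
import torch

def load_balance_loss(p: torch.Tensor) -> torch.Tensor:
    """p: routing probabilities of shape (B, T, n_schemes)."""
    usage = p.mean(dim=(0, 1))                        # average usage per scheme
    uniform = torch.full_like(usage, 1.0 / p.shape[-1])
    return ((usage - uniform) ** 2).sum()

# Total objective (lambda_aux is a hypothetical weighting coefficient):
#   loss = lm_loss + lambda_aux * load_balance_loss(p)
```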
We evaluate our method on the WikiText-2 language modeling benchmark. We train a decoder-only Transformer with the following specifications:
• Layers: 4
• Model Dimension ($d_{\mathrm{model}}$): 384
• Heads ($H$): 6
• Block size: 256
• Dropout: 0.1
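For reference, these hyperparameters can be collected into a single configuration object; the field names below are illustrative, and the GQA branch uses the $G = 2$ KV heads noted earlier.

```python
from dataclasses import dataclass

@dataclass
class MoASConfig:
    n_layers: int = 4        # Transformer blocks
    d_model: int = 384       # model dimension
    n_heads: int = 6         # query heads H
    n_kv_heads_gqa: int = 2  # G for the GQA branch (MHA uses 6, MQA uses 1)
    block_size: int = 256    # context length
    dropout: float = 0.1
```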
We compare three models:
• Baseline MHA: Standard Transformer with MHA.
• Static MoAS: A static average of MHA, GQA, and MQA outputs (no routing).
• Dynamic MoAS: The proposed method with learned routing.
All models are trained for 500 iterations with a batch size of 12 and a learning rate of $3 \times 10^{-4}$.
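A hedged sketch of the corresponding training loop is shown below, assuming an AdamW optimizer and a hypothetical get_batch helper returning input tokens and shifted targets (neither is specified above); the model is assumed to return logits together with the auxiliary routing loss.

```python
import torch
import torch.nn.functional as F

def train(model, get_batch, max_iters=500, batch_size=12, lr=3e-4):
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    for _ in range(max_iters):
        x, y = get_batch(batch_size)           # token ids and shifted targets
        logits, aux_loss = model(x)            # assumed (logits, aux) interface
        lm_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        loss = lm_loss + aux_loss              # language modeling + load balancing
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
```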
Table 1 presents the validation loss (a perplexity-related metric) on WikiText-2. The Baseline MHA achieves the lowest loss, which is expected given the small scale and the absence of capacity constraints. However, Dynamic MoAS outperforms Static MoAS (2.3074 vs. 2.3093), confirming that the router learns non-trivial routing policies that are superior to simple averaging.
The parameter count for MoAS variants is higher because we instantiate all three attention branches in parallel for this proof-of-concept. In a production inference scenario, one would only execute the selected branch(es).