Lightweight and Interpretable Transformer via Mixed Graph Algorithm Unrolling for Traffic Forecast
Unlike conventional “black-box” transformers built on the classical self-attention mechanism, we build a lightweight and interpretable transformer-like neural net by unrolling a mixed-graph-based optimization algorithm to forecast traffic along spatial and temporal dimensions. We construct two graphs: an undirected graph $\mathcal{G}^u$ capturing spatial correlations across geography, and a directed graph $\mathcal{G}^d$ capturing sequential relationships over time. We predict future samples of signal $\mathbf{x}$, assuming it is “smooth” with respect to both $\mathcal{G}^u$ and $\mathcal{G}^d$, where we design new $\ell_2$- and $\ell_1$-norm variational terms to quantify and promote signal smoothness (low-frequency reconstruction) on a directed graph. We design an iterative algorithm based on the alternating direction method of multipliers (ADMM), and unroll it into a feed-forward network for data-driven parameter learning. We periodically insert graph learning modules for $\mathcal{G}^u$ and $\mathcal{G}^d$ that play the role of self-attention. Experiments show that our unrolled networks achieve traffic forecast performance competitive with state-of-the-art prediction schemes, while drastically reducing parameter counts.
💡 Research Summary
The paper proposes a novel transformer‑like architecture for traffic forecasting that is both lightweight and interpretable. Instead of relying on the conventional self‑attention mechanism, the authors construct a mixed graph that captures spatial relationships through an undirected graph 𝔾ᵘ and temporal dependencies through a directed graph 𝔾ᵈ. For the undirected component they employ the classic graph Laplacian regularizer (GLR). For the directed component they introduce two new variational terms: (i) a directed graph Laplacian regularizer (DGLR) based on the symmetrized matrix ℒᵈ = (Lᵈ)ᵀLᵈ, where Lᵈ is the (asymmetric) directed graph Laplacian, which penalizes squared differences between each node and its temporal predecessors, and (ii) a directed graph total variation (DGTV), defined as the ℓ₁‑norm of the difference between the signal and its random‑walk shifted version, capturing asymmetric variations. The combination of DGLR and DGTV yields an elastic‑net‑style regularization that improves robustness to noise.
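As a concrete illustration, the three regularizers can be evaluated on a toy mixed graph. This is a minimal NumPy sketch; the graph construction and the convention for source nodes without predecessors are assumptions for illustration, not the paper's exact definitions:

```python
import numpy as np

# Toy mixed graph on 4 nodes: a spatial path (undirected) and a
# temporal chain 1 -> 2 -> 3 -> 4 (directed).
Wu = np.array([[0, 1, 0, 0],
               [1, 0, 1, 0],
               [0, 1, 0, 1],
               [0, 0, 1, 0]], dtype=float)
Lu = np.diag(Wu.sum(axis=1)) - Wu        # combinatorial graph Laplacian

# Directed adjacency: Win[i, j] = 1 iff there is an edge j -> i,
# so row i lists the temporal predecessors of node i.
Win = np.array([[0, 0, 0, 0],
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0]], dtype=float)
deg_in = Win.sum(axis=1, keepdims=True)
# Random-walk normalization; rows without predecessors stay zero.
Wdr = np.divide(Win, deg_in, out=np.zeros_like(Win), where=deg_in > 0)
Ld = np.eye(4) - Wdr                     # directed graph Laplacian
Ld[deg_in.ravel() == 0] = 0.0            # no predecessors -> no penalty

x = np.array([1.0, 1.1, 0.9, 1.2])      # signal on the 4 nodes

glr  = x @ Lu @ x                 # GLR:  sum of squared edge differences
dglr = np.sum((Ld @ x) ** 2)      # DGLR: x^T (Ld^T Ld) x
dgtv = np.sum(np.abs(Ld @ x))     # DGTV: l1 of x minus its shifted version
```

On this chain, DGLR sums squared consecutive differences (an ℓ₂ penalty), while DGTV sums their absolute values (an ℓ₁ penalty), which is why combining them behaves like an elastic net.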
The overall optimization problem minimizes a data‑fidelity term ‖y − Hx‖₂² together with the three regularizers (μᵤ·GLR, μ_{d,2}·DGLR, μ_{d,1}·DGTV). Because the objective mixes smooth ℓ₂ terms with a non‑smooth ℓ₁ term, the authors solve it via the Alternating Direction Method of Multipliers (ADMM). By introducing auxiliary variables (φ for the ℓ₁ term and zᵤ, z_d for the two ℓ₂ terms) they decompose each ADMM iteration into three linear systems that can be solved in linear time using Conjugate Gradient, and a soft‑thresholding step for the ℓ₁ sub‑problem.
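The two computational building blocks of each ADMM iteration can be sketched in a few lines of NumPy. This is a hedged illustration: the actual system matrices combine H, the Laplacians, and the penalty parameter ρ, which are replaced here by a placeholder SPD matrix:

```python
import numpy as np

def soft_threshold(v, tau):
    """Closed-form solution of the l1 sub-problem:
    argmin_phi tau*||phi||_1 + 0.5*||phi - v||_2^2."""
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

def conjugate_gradient(A, b, iters=100, tol=1e-10):
    """Solve A x = b for symmetric positive-definite A, as used for
    the linear systems inside each ADMM iteration."""
    x = np.zeros_like(b)
    r = b - A @ x
    p = r.copy()
    rs = r @ r
    for _ in range(iters):
        Ap = A @ p
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

# Example: a small SPD system standing in for (H^T H + mu*L + rho*I) x = b.
rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))
A = M.T @ M + 5.0 * np.eye(5)     # SPD by construction
b = rng.standard_normal(5)
x_cg = conjugate_gradient(A, b)
```

In practice A is never formed explicitly: conjugate gradient only needs matrix-vector products, which for sparse graph Laplacians cost time linear in the number of edges.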
Crucially, each ADMM iteration is “unrolled” into a neural network layer. The linear solves become differentiable operations, and all hyper‑parameters (μ, ρ, Lagrange multipliers) are treated as learnable tensors. This yields a feed‑forward network whose depth corresponds to the number of ADMM iterations. To emulate self‑attention, the authors periodically insert graph‑learning modules that update the adjacency matrices of 𝔾ᵘ and 𝔾ᵈ from the current hidden representations. These modules act as data‑driven attention mechanisms but retain a clear graph‑theoretic interpretation.
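A minimal version of such a graph-learning module might look as follows. The Gaussian-kernel-on-embeddings form is an assumption chosen for illustration (the paper's exact parameterization may differ); the point is that pairwise affinities between hidden node representations are recomputed, which is the same quantity scaled dot-product attention derives from queries and keys:

```python
import numpy as np

def learn_undirected_graph(features, sigma=1.0):
    """Rebuild the adjacency of the undirected graph from hidden node
    embeddings via a Gaussian kernel on pairwise distances. This plays
    the role self-attention plays in a standard transformer, but the
    output is an interpretable graph adjacency matrix."""
    d2 = np.sum((features[:, None, :] - features[None, :, :]) ** 2, axis=-1)
    W = np.exp(-d2 / (2.0 * sigma ** 2))
    np.fill_diagonal(W, 0.0)          # no self-loops
    return W

emb = np.random.default_rng(1).standard_normal((4, 8))  # 4 nodes, 8-dim
Wu_new = learn_undirected_graph(emb)
```

Unlike a generic attention map, the result is guaranteed symmetric and non-negative, so the graph-theoretic interpretation (a valid Laplacian, well-defined graph frequencies) survives end-to-end training.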
Experiments on public traffic datasets (e.g., METR‑LA, PEMS‑BAY) demonstrate that the unrolled model achieves forecasting accuracy comparable to state‑of‑the‑art transformer‑based methods such as PDFormer, while using only about 7 % of the parameters (≈1 M vs. ≈14 M). The reduction translates into lower memory consumption and faster inference, which are critical for real‑time traffic management. Ablation studies confirm that both DGLR and DGTV contribute meaningfully: removing DGLR degrades low‑frequency reconstruction, while omitting DGTV reduces robustness to outliers.
The paper acknowledges limitations: the graph‑learning modules are not yet as expressive as full nonlinear attention, and the size of the temporal window W can cause the directed graph to become dense. Moreover, validation is limited to traffic data; extending the framework to other spatio‑temporal domains (e.g., climate, power grids) remains future work. Potential extensions include multi‑scale graph constructions, adaptive ADMM parameter scheduling, and graph‑based positional encodings to further boost performance while preserving interpretability.