PRISM: Parametrically Refactoring Inference for Speculative Sampling Draft Models
Large Language Models (LLMs), constrained by their auto-regressive nature, suffer from slow decoding. Speculative decoding methods have emerged as a promising solution to accelerate LLM decoding, attracting attention from both systems and AI research communities. Recently, the pursuit of better draft quality has driven a trend toward parametrically larger draft models, which inevitably introduces substantial computational overhead. While existing work attempts to balance the trade-off between prediction accuracy and compute latency, we address this fundamental dilemma through architectural innovation. We propose PRISM, which disaggregates the computation of each predictive step across different parameter sets, refactoring the computational pathways of draft models to successfully decouple model capacity from inference cost. Through extensive experiments, we demonstrate that PRISM outperforms all existing draft architectures, achieving exceptional acceptance lengths while maintaining minimal draft latency for superior end-to-end speedup. We also re-examine scaling laws with PRISM, revealing that PRISM scales more effectively with expanding data volumes than other draft architectures. Through rigorous and fair comparison, we show that PRISM boosts the decoding throughput of an already highly optimized inference engine by more than 2.6x.
💡 Research Summary
The paper tackles the fundamental latency bottleneck of large language model (LLM) decoding, which stems from the auto‑regressive nature of token generation. While speculative decoding (also called draft‑and‑verify) reduces the number of forward passes by letting a fast “draft” model propose a sequence of tokens that a larger “target” model verifies in a single pass, the overall speed‑up depends heavily on the quality of the draft model. Recent works have tried to improve draft quality by scaling up the draft model’s parameter count, but this introduces a proportional increase in per‑token compute, eroding the benefits of speculation.
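The draft-and-verify loop described above can be sketched in a few lines of plain Python. This is a toy illustration of greedy speculative decoding, not the paper's implementation: `draft` and `target` are hypothetical stand-ins for real models, each mapping a token-id context to the next token id.

```python
# Toy sketch of greedy draft-and-verify speculative decoding.
# "draft" and "target" are hypothetical stand-ins for real models:
# each maps a context (list of token ids) to the next token id.

def speculative_step(context, draft, target, k=4):
    """Draft k tokens auto-regressively, then verify them against the
    target model; return the accepted prefix plus one target token."""
    # 1) Drafting: k cheap sequential calls to the small model.
    proposal = []
    ctx = list(context)
    for _ in range(k):
        tok = draft(ctx)
        proposal.append(tok)
        ctx.append(tok)

    # 2) Verification: the target scores every drafted position
    #    (a real engine does this in a single batched forward pass;
    #    here it is emulated with k calls for clarity).
    accepted = []
    ctx = list(context)
    for tok in proposal:
        expected = target(ctx)
        if tok != expected:
            # First mismatch: keep the target's own token and stop.
            accepted.append(expected)
            break
        accepted.append(tok)
        ctx.append(tok)
    else:
        # All drafts accepted; append one bonus token from the target.
        accepted.append(target(ctx))
    return accepted
```

The key property is that the target model's cost is one forward pass per `speculative_step`, regardless of how many drafted tokens are accepted, which is why a higher-quality draft directly translates into fewer target passes per generated token.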
PRISM (Parametrically Refactoring Inference for Speculative Sampling Draft Models) proposes a novel architecture that decouples draft model capacity from its inference cost. The key insight is that the difficulty of predicting tokens varies across a generation sequence: early tokens are easier, while later tokens become increasingly uncertain. PRISM exploits this by assigning a distinct set of parameters to each drafting step, effectively creating a cascade of sub‑networks whose depth grows with the step index. Early steps use a shallow sub‑network, keeping latency low; later steps invoke deeper sub‑networks, providing the additional representational power needed for harder predictions. This conditional‑computing scheme is reminiscent of Mixture‑of‑Experts routing, but instead of routing based on input content, PRISM routes based on the generation step itself.
Because only a fixed subset of parameters is active at any given step, the total number of trainable parameters can be increased dramatically without raising the per‑token memory‑bandwidth cost. Consequently, PRISM achieves higher acceptance rates (the proportion of drafted tokens that the target model accepts) while maintaining a constant draft latency. Empirical results on LLaMA‑3‑8B show that PRISM outperforms prior draft architectures such as EAGLE‑2, EAGLE‑3, and HASS: acceptance rates improve by 5–10 percentage points, and end‑to‑end decoding throughput gains exceed 2.6× when integrated into the highly optimized SGLang inference engine.
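The acceptance-rate/latency trade-off can be made concrete with a standard back-of-the-envelope model. The formulas below assume a per-position acceptance probability `p` that is independent across the `k` drafted tokens, and a draft step costing a fraction `c` of one target forward pass; these are common simplifying assumptions, not the paper's exact analysis.

```python
# Back-of-the-envelope speedup model for speculative decoding, under
# the simplifying assumption of an i.i.d. per-token acceptance
# probability p across the k drafted positions.

def expected_accepted(p, k):
    """Expected tokens committed per verification pass: the expected
    accepted prefix length plus the target's one corrected/bonus token."""
    # E[prefix length] = sum_{i=1..k} p^i ; plus 1 token from the target.
    return sum(p ** i for i in range(1, k + 1)) + 1

def speedup(p, k, c):
    """Tokens per unit time relative to plain decoding, where c is the
    cost of one draft step relative to one target forward pass."""
    tokens = expected_accepted(p, k)
    cost = 1 + c * k  # one target pass plus k draft steps
    return tokens / cost
```

Plugging in numbers shows why PRISM's design matters: raising `p` while holding `c` constant increases the numerator without touching the denominator, whereas a naively deeper draft model raises both.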
The authors also revisit scaling laws for draft models. Traditional vertical scaling (stacking more transformer layers) inflates both capacity and per‑step compute, leading to diminishing returns. PRISM’s parameter‑disaggregation breaks this coupling, allowing the total parameter count to grow while the active‑parameter count per step stays constant. Experiments across multiple data scales confirm that PRISM scales more efficiently than naïve depth‑increase approaches, delivering higher acceptance rates with lower latency even as training data volume grows.
From a systems perspective, the paper bridges the gap between AI‑centric speculative decoding research and practical deployment. By implementing PRISM inside SGLang—a state‑of‑the‑art inference engine—the authors demonstrate that the architectural gains translate into real‑world performance improvements beyond what is observable in pure PyTorch baselines (additional 1.3×–1.5× speedup).
In summary, PRISM introduces a paradigm shift for draft model design: it separates model capacity from inference cost through step‑wise parameter specialization, thereby achieving superior draft quality without sacrificing speed. This work paves the way for more efficient speculative decoding in production LLM services and suggests future directions such as combining PRISM with expert‑model techniques or extending the conditional‑computing principle to other auto‑regressive tasks.