SPIRe: Boosting LLM Inference Throughput with Speculative Decoding
Speculative decoding (SD) has been shown to reduce the latency of autoregressive decoding (AD) by 2-3× for small batch sizes. However, increasing throughput and therefore reducing the cost per token requires decoding with large batch sizes. Recent work shows that SD can accelerate decoding with large batch sizes too if the context is sufficiently long and the draft model’s KV cache is sparse. We introduce SPIRe, a draft model that combines static sparse attention, pruned initialization, and feedback memory to increase the modeled throughput of speculative decoding by over 100% compared to speculation with a much smaller draft model and by over 35% compared to the strong baseline of sparse self-speculation. Our approach is particularly effective when context lengths vary significantly across requests. A PDF version of this article is available at https://arxiv.org/abs/2504.06419.
Introduction
Speculative decoding uses a cheap draft model to generate candidate tokens that are verified in parallel by an expensive target model (Leviathan et al. (2023); Chen et al. (2023)). Accepted draft tokens are guaranteed to be samples from the target model's distribution over the next token by a modified rejection sampling scheme. If draft tokens are generated quickly enough and have a high enough acceptance rate, speculative decoding improves both the latency and throughput of decoding by reducing the number of memory accesses needed to generate each token¹.
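To make the verification step concrete, the following is a minimal sketch of the modified rejection sampling scheme described by Leviathan et al. (2023) and Chen et al. (2023); the function name, array layout, and NumPy usage are ours, not from those papers.

```python
import numpy as np

def speculative_accept(p_target, q_draft, draft_tokens, rng):
    """Accept or reject draft tokens so every emitted token is an exact sample from the target.

    p_target: [k+1, V] target probabilities at each drafted position plus one bonus position.
    q_draft:  [k, V]   draft probabilities from which draft_tokens were sampled.
    draft_tokens: length-k sequence of drafted token ids.
    Returns the tokens emitted this round (between 1 and k+1 of them).
    """
    out = []
    for i, tok in enumerate(draft_tokens):
        p, q = p_target[i, tok], q_draft[i, tok]
        if rng.uniform() < min(1.0, p / q):  # accept with probability min(1, p/q)
            out.append(tok)
        else:
            # On rejection, resample from the renormalized residual max(p - q, 0).
            residual = np.maximum(p_target[i] - q_draft[i], 0.0)
            out.append(rng.choice(residual.size, p=residual / residual.sum()))
            return out  # stop at the first rejection
    # All k draft tokens accepted: emit a bonus token from the target's extra distribution.
    out.append(rng.choice(p_target.shape[1], p=p_target[-1]))
    return out
```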
Optimizing speculative decoding for higher throughput requires different strategies than optimizing for lower latency; the latter is well-documented in the literature (Xia et al. (2024); Miao et al. (2024); Yan et al. (2025)). For example, low latency can often be achieved with small batch sizes (Pope et al. (2022)) and small draft models are preferred in this setting since fetching weights is the bottleneck. On the other hand, high throughput requires large batch sizes, which enables the use of larger draft models since in this regime loading the KV cache is the bottleneck (Chen et al. (2024)). See Figure 1 and Figure 4 for examples of how the speedup due to speculative decoding varies with draft model architecture, batch size, and context length according to our performance model below.
Figure 1: Speedup of generating tokens at a fixed batch size of 64, as a function of the context length $s$. A round of speculation is the generation and subsequent verification of a block of draft tokens, the average generation length $\tau$ is the average number of tokens generated per round of speculation, and speedup is $\tau$ divided by how much longer a round of speculation takes compared to autoregressively decoding one token.
Prior work has shown that the acceptance rate of draft tokens can be increased using the target model’s intermediate activations and output logits (Du et al. (2024); Zhou et al. (2024)). Since a target model’s state can be cached during training at negligible cost, and since inference rather than training accounts for most of the cost of serving an LLM, the investment in training a draft model aligned to a particular target model may be worthwhile. Indeed, we show below that this is the case for SPIRe under conservative assumptions.
Our work, along with Chen et al. (2024), demonstrates that speculative decoding is an effective technique for reducing the cost of decoding, and our work enables more efficient exploration of the design space of draft models based on the batch sizes and context lengths expected in production. Our main contributions are the following:
- An implementation-agnostic model for evaluating the throughput of speculative decoding with different draft models.
- An efficient method for training a Feedback Transformer draft model.
- A draft model, SPIRe, that increases the modeled throughput of speculative decoding by over 100% compared to vanilla speculative decoding and by over 35% compared to the strong baseline of MagicDec (Chen et al. (2024)).
Evaluation Metric
It is problematic to compare different methods of training the draft model using the throughput achieved by the resulting models in a production environment. Measured throughput depends on myriad implementation details such as use of specialized kernels (Dao et al. (2022)), how models are partitioned across chips (Pope et al. (2022)), how queries are batched (Daniel et al. (2023)), and which queries are used for evaluation (Zhou et al. (2024)). Any of these variables could be confounders when comparing different draft model architectures.
To derive a proxy for measured throughput, we first note that the throughput of speculative decoding is equal to the throughput of vanilla decoding multiplied by

$$\frac{\tau}{\left(k\, t_D + t_V\right)/\, t_T},$$

where $k$ is the maximum speculation depth, a round of speculation is the generation and subsequent verification of a block of $k$ draft tokens, $\tau$ is the average number of tokens generated per round of speculation, $t_D$ and $t_T$ are the latencies of a draft and target model forward pass on a single token, and $t_V$ is the latency of a target model forward pass on $k$ draft tokens in parallel. In the sequel, we refer to $\tau$ as the average generation length and to $\mu = (k\, t_D + t_V)/t_T$ as the iteration time multiplier, so that the throughput multiplier is simply $\tau/\mu$.
To obtain a more implementation-agnostic expression for $\mu$, we can break down the cost of a forward pass into the compute cost and the memory cost, which are the number of FLOPs and the number of memory accesses that need to be performed. The memory cost can be further broken down into the size of the LLM's weights and the size of its KV cache. By multiplying the memory cost in bytes with the hardware operational intensity (HOI) in FLOPs per byte, we can obtain an estimate of the memory cost in FLOP-equivalents and compare the compute and memory costs on equal footing. Assuming that computation and memory accesses are maximally overlapped², we obtain the following expression for the cost of a forward pass:

$$C = \max\bigl(F,\; \text{HOI} \cdot (W + KV)\bigr),$$

where $F$ is the number of FLOPs performed during the forward pass, $W$ is the size of the model's weights in bytes, and $KV$ is the size of its KV cache in bytes.
Finally, by substituting $t_D$, $t_T$, and $t_V$ with their corresponding forward costs, we obtain the following implementation-agnostic³ approximation of the throughput multiplier and use it as our evaluation metric:

$$\frac{\tau}{\mu} \;\approx\; \frac{\tau \cdot \max\bigl(F_T,\; \text{HOI} \cdot (W_T + KV_T)\bigr)}{k \cdot \max\bigl(F_D,\; \text{HOI} \cdot (W_D + KV_D)\bigr) \;+\; \max\bigl(k\, F_T,\; \text{HOI} \cdot (W_T + KV_T)\bigr)}.$$

Here, $F_D$ is the number of FLOPs performed during a forward pass through the draft model, $W_D$ is the size of the draft model's weights in bytes, $KV_D$ is the size of the draft model's KV cache in bytes, and similarly for $F_T$, $W_T$, and $KV_T$; the FLOPs of verifying $k$ draft tokens in parallel are approximated as $k\, F_T$.
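The metric is straightforward to compute directly. Below is a small sketch of the performance model in the notation above; the HOI value and the FLOP and byte counts in the example are illustrative placeholders, not measurements from our experiments.

```python
from dataclasses import dataclass

@dataclass
class ForwardCost:
    flops: float         # F: FLOPs per forward pass
    weight_bytes: float  # W: bytes of weights loaded
    kv_bytes: float      # KV: bytes of KV cache loaded

def cost(fc: ForwardCost, hoi: float) -> float:
    # FLOP-equivalent cost, assuming compute and memory accesses fully overlap.
    return max(fc.flops, hoi * (fc.weight_bytes + fc.kv_bytes))

def throughput_multiplier(tau: float, k: int, draft: ForwardCost,
                          target: ForwardCost, hoi: float) -> float:
    # One round of speculation: k draft forward passes, then one verification
    # pass whose FLOPs scale with k but whose memory traffic matches the target's.
    verify = ForwardCost(k * target.flops, target.weight_bytes, target.kv_bytes)
    mu = (k * cost(draft, hoi) + cost(verify, hoi)) / cost(target, hoi)
    return tau / mu

# Illustrative: a memory-bound target whose draft loads ~1/4 as many bytes
# gives mu ~= k/4 + 1 and therefore a speedup of roughly tau / (k/4 + 1).
hoi = 300.0  # FLOPs per byte; hardware-dependent
target = ForwardCost(flops=2e9, weight_bytes=1.3e8, kv_bytes=4.0e8)
draft = ForwardCost(flops=5e8, weight_bytes=3.0e7, kv_bytes=1.0e8)
print(throughput_multiplier(tau=3.4, k=4, draft=draft, target=target, hoi=hoi))
```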
SPIRe
We combine the following techniques when training our draft model SPIRe in order to increase the throughput multiplier of speculative decoding. The name SPIRe stands for Sparse KV cache, Pruned Initialization, and Recurrent Feedback Transformer.
Sliding window KV cache with attention sink
In the large-batch long-context regime, the memory cost of loading the KV cache is the dominant cost of decoding, and reducing the size of the KV cache can reduce the iteration time multiplier $\mu$ by a larger factor than it reduces the average generation length $\tau$. Following MagicDec, we use a StreamingLLM attention mask⁴ (Xiao et al. (2024)), during training and inference, to ensure that the draft model's KV cache is sparse and that its size is constant with respect to the decoding sequence length. As shown in Figure 4, this results in the throughput multiplier increasing as the batch size increases, which is never the case for vanilla speculative decoding.
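As an illustration, a sliding-window-plus-sink mask of the kind StreamingLLM uses can be built in a few lines; the window and sink sizes below match those we use in our evaluation (64 and 1), and the helper name is ours.

```python
import numpy as np

def streaming_llm_mask(seq_len: int, window: int = 64, sink: int = 1) -> np.ndarray:
    """Boolean [seq_len, seq_len] mask: entry (i, j) is True iff query i may attend to key j.

    Each query attends to the first `sink` tokens plus the `window` most recent
    tokens (itself included), and never to future tokens, so the KV cache that
    must be kept has constant size regardless of how long decoding runs.
    """
    i = np.arange(seq_len)[:, None]  # query positions
    j = np.arange(seq_len)[None, :]  # key positions
    causal = j <= i
    recent = j > i - window
    return causal & (recent | (j < sink))
```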
Initialize by pruning the target model
In our experiments with an eight-layer target model, we initialized the draft model using the embedding layer, the last two transformer blocks, and the unembedding layer of the trained target model. As shown in Figure 5, initializing by pruning the target model produced a significantly higher value of $\tau$ than random initialization. Matching the embedding dimension of the target and draft models allowed us to further improve the quality of the draft model, as we explain below. Future work could explore more sophisticated pruning strategies, such as the one described by Muralidharan et al. (2024).
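A minimal sketch of this initialization, assuming parameters are stored in a nested dictionary keyed by block index (the layout and names are illustrative, not seqax's actual parameter tree):

```python
def init_draft_by_pruning(target_params: dict, keep_blocks=(6, 7)) -> dict:
    """Initialize a two-block draft model from an eight-block target by pruning.

    Copies the embedding and unembedding layers, plus the last two transformer
    blocks (indices 6 and 7 of 0..7), re-indexed as draft blocks 0 and 1.
    """
    return {
        "embed": target_params["embed"],
        "unembed": target_params["unembed"],
        "blocks": {i: target_params["blocks"][b] for i, b in enumerate(keep_blocks)},
    }
```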
With short-to-medium contexts, MagicDec may underperform vanilla speculative decoding and even autoregressive decoding because the large increase in FLOPs outweighs the small decrease in KV cache loads. Given that context lengths vary considerably in production, we sought a draft model which accelerates decoding across a wide range of context lengths. Consequently, our draft model has a quarter as many parameters as the target model (ignoring embedding and unembedding parameters, which become negligible as the models are scaled up).
Feedback Transformer and attending to target model activations
Feedback memory (Fan et al. (2021)) was proposed to increase the representation capacity of Transformers, creating Feedback Transformers with stronger performance at any given inference cost. They are significantly more expensive to train, however, because they inhibit parallelism over the context length. Fortunately, we find that this disadvantage can be almost entirely mitigated when training a draft model, since we can arrange for very short draft rollouts. Furthermore, the additional performance due to feedback memory has been shown to be even greater for shallow models such as ours, and a shallower draft model is preferable for its smaller forward pass latency $t_D$.
The $l$-th attention sublayer of the Feedback Transformer computes attention over memory vectors $m_1, \dots, m_{t-1}$ instead of the standard layer activations $x_1^{l-1}, \dots, x_{t-1}^{l-1}$. The memory vectors are defined as $m_t = \sum_{l=0}^{L} w_l\, x_t^l$, where $x_t^l$ is the activation of the $t$-th token after the $l$-th layer, the $w_l$ are learnable weights, and layer 0 is the embedding layer. The dependence of each memory vector on all activations from the previous timestep does not increase the latency of decoding tokens autoregressively, but it does prevent training and prefill from being parallelized.
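In code, the memory vector of a single token is just a learned mixture of that token's per-layer activations. A sketch follows; we normalize the mixing weights with a softmax, as in Fan et al. (2021), though the normalization is an implementation detail.

```python
import numpy as np

def feedback_memory(token_activations: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Compute the memory vector m_t for one token.

    token_activations: [L+1, d] activations of this token after layers 0..L,
                       where layer 0 is the embedding layer.
    w:                 [L+1]    learnable mixing weights.
    Later tokens attend to the returned [d] vector instead of per-layer activations.
    """
    mix = np.exp(w - w.max())
    mix /= mix.sum()
    return mix @ token_activations
```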
We require our Feedback Transformer draft model to be able to generate blocks of $k$ tokens given contexts of various lengths, where $k$ is the maximum speculation depth. Thus during training, rather than performing $T$ sequential forward passes on a prefix of length $T$, we perform just $k$ forward passes: we forego computing the first $T$ memory vectors by substituting them with target model activations.
Figure 2: The attention mask in the 1st and 4th forward passes on a training sequence, with a maximum speculation depth of $k = 4$. In general we perform $k$ forward passes on every $k$-th prefix of a sequence, in a process similar to batched autoregressive decoding with teacher forcing.

Specifically, during training, our Feedback Transformer draft model autoregressively generates $k$ tokens following every $k$-th prefix of a sequence by performing $k$ forward passes. Unlike Fan et al. (2021), we share neither key and value projections nor memory vectors across layers: for $t > T$, where $T$ is the length of a prefix, the memory vectors for the $l$-th layer are defined as $m_t^l = \sum_{l'} w_{l'}^l\, x_t^{l'}$, where the $w_{l'}^l$ are learnable weights. For $t \le T$, we let $m_t^l$ be a target model activation, where $\tilde{x}_t^j$ denotes the target model activation of the $t$-th token after the $j$-th layer. In other words, the weights of the first (second) layer of the draft model are initialized from the weights of the seventh (eighth) layer of the target model, and we let $m_t^1 = \tilde{x}_t^6$ and $m_t^2 = \tilde{x}_t^7$ for $t \le T$.
Had we not matched the embedding dimension of the draft and target model, we could have substituted past memory vectors with draft model embeddings, but this was worse than substituting with target model activations in our experiments. Our approach was also better than sharing memory vectors across layers and substituting past memory vectors with either draft model embeddings or post-final-layer target model activations.
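The training-time construction of memories can be sketched as follows, in the notation above; the array layouts, the softmax mixing, and the helper name are assumptions, and `target_acts[l]` holds the cached target activations used as memories for draft layer `l`.

```python
import numpy as np

def rollout_memories(target_acts: np.ndarray, draft_acts: np.ndarray,
                     w: np.ndarray) -> np.ndarray:
    """Memory vectors each draft layer attends to during a k-token rollout.

    target_acts: [L, T, d]    cached target activations for the length-T prefix,
                              one stream per draft layer (a free byproduct of target training).
    draft_acts:  [k, L+1, d]  draft activations (embedding + L layers) for the k rollout tokens.
    w:           [L, L+1]     per-layer mixing weights; memories are NOT shared across layers.
    Returns [L, T+k, d]: prefix memories come straight from the target, so no
    sequential forward passes over the prefix are needed.
    """
    L, T, d = target_acts.shape
    k = draft_acts.shape[0]
    memories = np.zeros((L, T + k, d))
    for l in range(L):
        memories[l, :T] = target_acts[l]  # substitute target activations for the prefix
        mix = np.exp(w[l] - w[l].max())
        mix /= mix.sum()
        memories[l, T:] = np.einsum("j,kjd->kd", mix, draft_acts)
    return memories
```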
Distillation loss
Early implementations of speculative decoding for LLMs (Leviathan et al. (2023); Chen et al. (2023)) optimize the draft model's parameters by minimizing the standard cross-entropy loss $\mathcal{L}_{\text{CE}} = \mathbb{E}_{x \sim \mathcal{D}}\bigl[-y(x)^\top \log q(x)\bigr]$, where $x$ is a prefix sampled from a dataset $\mathcal{D}$, $y(x) \in \{0,1\}^V$ is a one-hot vector corresponding to the token that follows $x$, $q(x) \in [0,1]^V$ is a dense vector corresponding to the draft model's probability distribution over the token that follows $x$, and $V$ is the vocabulary size. To train a draft model to produce outputs that are more similar to the outputs of the target model, we minimize $\mathcal{L} = \mathbb{E}_{x \sim \mathcal{D}}\bigl[-p(x)^\top \log q(x) - \alpha(p(x), q(x))\bigr]$, where $p(x) \in [0,1]^V$ is a dense vector corresponding to the target model's probability distribution over the token that follows $x$ and $\alpha(p, q) = \sum_i \min(p_i, q_i)$ is the expected acceptance probability (Leviathan et al. (2023)). Minimizing the first term minimizes the distillation loss (Hinton et al. (2015)), and minimizing the second term maximizes the average generation length $\tau$. As shown in Figure 5, including the acceptance term marginally outperformed the distillation loss alone.
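A per-example sketch of this loss; the equal weighting of the two terms is an assumption, and `p_target`/`q_draft` are the dense next-token distributions defined above.

```python
import numpy as np

def draft_loss(p_target: np.ndarray, q_draft: np.ndarray, eps: float = 1e-9) -> float:
    """Distillation cross-entropy minus the expected acceptance probability.

    p_target, q_draft: [V] next-token distributions from the target and draft models.
    The acceptance term is sum_i min(p_i, q_i) (Leviathan et al. (2023)), so
    minimizing the loss also pushes the draft toward higher acceptance rates.
    """
    distill = -np.sum(p_target * np.log(q_draft + eps))
    accept = np.sum(np.minimum(p_target, q_draft))
    return float(distill - accept)
```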
Evaluation
We fix an 8-layer, 67-million-body-parameter multi-head attention target model and vary the draft model, with all model architecture details provided in the Appendix. Each model with $N$ total parameters is trained on roughly $20N$ tokens (Hoffmann et al. (2022)) from the LongCrawl64 dataset using sequences of length 1024. We compare drafting tokens for the target model using the following draft models:
- Vanilla speculative decoding (Chen et al. (2023)). This draft model has a small fraction of the target model's body parameters, and is trained by minimizing the standard cross-entropy loss. This is the most widely understood implementation of speculative decoding, and, as can be seen in Figures 1 and 4, it is strong for short contexts and small batch sizes.
- MagicDec (Chen et al. (2024)). This draft model is identical to the target model, but it sparsely accesses its KV cache through a StreamingLLM attention mask with a window size of 64 and a sink size of 1. This is an elegant implementation of speculative decoding, and is strong in the large-batch long-context regime.
- SPIRe (ours). This draft model was described above in detail, and it has a quarter as many body parameters as the target model. During both training and inference, we use a StreamingLLM attention mask with a window size of 64 and a sink size of 1.
We measure average generation lengths $\tau$ by generating 64 tokens given contexts from the validation split of the LongCrawl64 dataset, and we calculate iteration time multipliers $\mu$ for different batch sizes and context lengths using our performance model.
We use the $\tau$'s corresponding to a medium context length of 512 to calculate all throughput multipliers in Figures 1, 4, 5, and 6, by assuming that $\tau$ is the same when generating with shorter or longer contexts. Table 1 shows that the average generation length is similar for contexts of length 256, 512, and 960, and our assumption is supported by e.g. Figure 5 of Xiao et al. (2024), which suggests that even with extremely long contexts, it suffices to attend to a window of the most recent tokens.
| Context Length | Output Length | Vanilla SD | MagicDec | SPIRe |
|---|---|---|---|---|
| 960 | 64 | 2.644 ± 0.004 | 3.793 ± 0.004 | 3.382 ± 0.004 |
| 512 | 64 | 2.647 ± 0.004 | 3.891 ± 0.004 | 3.401 ± 0.004 |
| 256 | 64 | 2.637 ± 0.004 | 4.005 ± 0.004 | 3.427 ± 0.004 |
Table 1: Average generation lengths $\tau$, with 95% confidence intervals, for different draft models and context lengths. Using a maximum speculation depth of $k = 4$, each round of speculation generates between 1 and 5 tokens, with between 0 and 4 draft tokens accepted before the first rejection.
We focus on a medium context of 512 tokens since we care about accelerating decoding across a range of context lengths. If a context of 128K tokens is considered "long" for Llama 3 70B (Llama Team (2024)), then a context of 64K may be considered "medium". As shown in Figure 3, 64K is four powers of two larger than the longest context for which verification with Llama 3 70B is compute-bound at some batch size. A context length of 512 is four powers of two larger than the corresponding context length for our target model, so we consider a context of 512 to be of medium length.
Figure 3: With a sufficiently long context, verification may never be compute-bound, even in the limit as batch size tends to infinity. For context lengths where verification is compute-bound at some batch size, there is no benefit in terms of throughput (and some cost in terms of latency) to increasing the batch size beyond that point.
Results
Our draft model SPIRe produces the largest increase in the modeled throughput of generating tokens for most batch sizes and context lengths, as shown in Figure 4. SPIRe achieves a lower average generation length than MagicDec, but drafting with SPIRe nevertheless produces a larger speedup due to its lower cost of performing forward passes.
Figure 4: Speedup of generating tokens in a batch of size $B$ given a context of length $s$. Using the highlighted values, we get that SPIRe increases the modeled throughput of speculative decoding by 100% compared to vanilla speculative decoding and by 35% compared to MagicDec.
Figure 1 is derived from Figure 4 by focusing on a batch size of 64. For medium contexts of around 512 tokens, Figure 3 shows that the operational intensity of verification barely increases as we increase the verification batch size beyond 64, but it does increase significantly as the verification batch size is increased from 1 to 64.
Ablations
See Figure 5 for ablations of the techniques used to train our draft model SPIRe. To ablate the sparse KV cache, we conservatively assume the same average generation length as SPIRe, and use our performance model to calculate new throughput multipliers. We ablate the other techniques by training new draft models and evaluating them.
Figure 5: Ablations of the techniques used to train our draft model SPIRe. In order of importance, the techniques are 1) sparse KV cache, 2) attending to target model activations, 3) distillation loss, 4) initializing by pruning the target model, 5) feedback memory, and 6) maximizing the expected acceptance probability.
The throughput multiplier for SPIRe (without sparse KV cache) is constant with respect to the context length in Figure 5, since here the memory cost of generating a token is higher than the compute cost, and since both $KV_D$ and $KV_T$ grow linearly with the context length, the context length cancels out in our expression for the iteration time multiplier $\mu$.
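To make the cancellation explicit, assume KV-cache loads dominate both weight loads and FLOPs for the draft and the target, and that both caches grow linearly with the context length $s$. Then

$$\mu \;\approx\; \frac{k \cdot \text{HOI} \cdot KV_D(s) + \text{HOI} \cdot KV_T(s)}{\text{HOI} \cdot KV_T(s)} \;=\; k \cdot \frac{KV_D(s)}{KV_T(s)} + 1,$$

and the ratio $KV_D(s)/KV_T(s)$ is a constant for the dense-KV draft (set by its depth relative to the target), so $s$ drops out of $\mu$; with $\tau$ held fixed, the throughput multiplier $\tau/\mu$ is constant as well.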
Extrapolation
What if we extrapolate the results above to drafting for Llama 3 70B while maintaining the same draft-to-target model size ratio, by assuming the same values of the average generation length $\tau$ but recalculating iteration time multipliers $\mu$ using our performance model? Our draft model SPIRe produces the largest increase in the modeled throughput of generating tokens for most context lengths supported by LLM providers, as shown in Figure 6.
Figure 6: Estimated speedup of generating tokens from Llama 3 70B at a fixed batch size, given a context of length $s$. This figure is very similar, but not identical, to Figure 1.
Cost Analysis
Each draft model costs a different amount in training FLOPs. We ignore FLOPs due to embedding and unembedding parameters, which become negligible as the draft model is scaled up. The cost to train our draft model SPIRe is a fraction of the cost to train the target model, since it has a quarter as many body parameters but achieves lower MFU during training.
Suppose that training the target model costs 1 unit of our compute budget, and that we expect to spend 10 units generating tokens for customers using autoregressive decoding. Using our draft model SPIRe for speculative decoding amounts to investing a fraction of a unit to train the draft model but saving $10 - 10/2.78 \approx 6.4$ units in decoding cost, and drafting with MagicDec amounts to saving $10 - 10/2.05 \approx 5.1$ units in decoding cost (the throughput multipliers 2.78 and 2.05 are highlighted in Figure 4). Overall, net of the investment in training the draft model, SPIRe saves over 20% more than what MagicDec saves.
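The arithmetic can be checked with a few lines; the 0.25-unit training cost for SPIRe below is an illustrative assumption (MagicDec reuses the target model, so it adds no training cost).

```python
def net_saving(decode_budget: float, speedup: float, train_cost: float = 0.0) -> float:
    """Units of compute saved relative to autoregressive decoding, net of draft training."""
    return decode_budget - decode_budget / speedup - train_cost

magicdec = net_saving(10.0, 2.05)                # ~5.1 units, no extra training
spire = net_saving(10.0, 2.78, train_cost=0.25)  # illustrative training cost
print(magicdec, spire, spire / magicdec - 1.0)   # SPIRe saves ~20% more under this assumption
```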
If we again extrapolate our results to drafting for Llama 3 70B, we find that drafting with SPIRe saves $10 - 10/2.82 \approx 6.5$ units in decoding cost and drafting with MagicDec saves $10 - 10/2.14 \approx 5.3$ units (the throughput multipliers 2.82 and 2.14 are highlighted in the Appendix). Overall, net of the draft model's training cost, SPIRe saves over 16% more than what MagicDec saves. If there is significant downward variation in batch sizes and context lengths, then SPIRe outperforms MagicDec by a much greater amount in terms of decoding cost saved.
Related Work
- The only other work we’re aware of which attempts to use speculative decoding to increase the throughput of decoding, rather than to minimize latency, is MagicDec (Chen et al. (2024)).
- Continuous batching reduces the need for speculative decoding, but both techniques are complementary. In vLLM, speculative decoding is integrated with the system’s continuous batching architecture, and studying their interaction further is an interesting direction for future work.
- Our technique is orthogonal to any technique used to accelerate prefill. In practice, accelerating both prefill and decoding is important for reducing the cost per token generated.
- We used the token verification algorithm in our experiments, but our results should hold when using the block verification (Sun et al. (2024)) algorithm too.
Conclusion
Minimizing cost per token requires maximizing throughput. We demonstrated that draft model architecture significantly impacts throughput, with optimal choices depending on the batch sizes and context lengths expected in production. We proposed an implementation-agnostic performance model for evaluating the throughput of speculative decoding with different draft models, and we proposed a draft model SPIRe which outperforms strong baselines when decoding with large batch sizes and medium-to-long contexts. Future work can empirically validate our performance model, use a more principled long-context evaluation, and analyze the sensitivity of speedup with respect to the maximum speculation depth $k$.
Contributions and Acknowledgements
Sanjit Neelam proposed most aspects of the architecture of our draft model SPIRe. Feedback Transformer was suggested by Reiner Pope, and attending to target activations was co-developed by Sanjit Neelam and Reiner Pope. Explicitly maximizing the expected acceptance probability was proposed by Akshay Mishra. The evaluation methodology was co-designed by Sanjit Neelam and Reiner Pope. All experiments were conducted by Sanjit Neelam, who also wrote all text and produced all figures. Daniel Heinlein and Vaclav Cvicek provided valuable feedback that improved the presentation of results.
We use seqax, our research-focused LLM codebase built on JAX, to perform all experiments. This work was supported by Cloud TPUs from Google’s TPU Research Cloud.
Appendix
See https://github.com/MatX-inc/seqax/blob/SPIRe/spire_appendix.ipynb.
Footnotes
1. The cost is that speculative decoding always requires performing more FLOPs than autoregressive decoding; the draft model performs $k$ forward passes per round of speculation, and the target model performs forward passes on not only the draft tokens that are accepted but also those that are rejected. Thus, speculative decoding only makes sense if autoregressive decoding is sufficiently memory-bound.
2. Optimized inference stacks with e.g. pipeline parallelism approach this ideal. Throughput multipliers are higher than if we assumed that computation and memory accesses could not be overlapped, in which case we would have $C = F + \text{HOI} \cdot (W + KV)$ rather than the maximum of the two terms.
3. The inclusion of HOI in the throughput multiplier makes it hardware-dependent, but this only makes a difference in the compute-bound (bottom-left) region of Figure 4. When decoding and verification are both memory-bound, HOI cancels out.
4. Xiao et al. (2024) use positions within the cache rather than those in the original text when adding positional information to tokens. We do the latter for SPIRe and the former for our implementation of MagicDec.