Optimize for inference too, not just training FLOPs
Large Language Models (LLMs) have been shaped by the scaling laws, as established by Kaplan et al. (2020) and Hoffmann et al. (2022). They guide us to design models that optimize for training cost but often overlook inference costs. Although training and inference costs are highly related, they diverge sharply in one critical area: attention key-value (KV) computation. During training, KV computation is usually cheaper than the rest of the model, but during inference, loading the KV cache becomes the dominant expense. This disconnect has led to models with ~70% efficiency during training but only ~10% efficiency during inference.
Perhaps a methodology that considers both inference and training cost would lead to more balanced models? This post sketches out some options in this direction.
Scaling Laws and Model Size Selection
Kaplan et al. (2020) demonstrated that transformer model quality improves predictably with increasing model parameters $N$ and training tokens $D$. They estimated the compute cost of training a transformer model at $6ND$ floating-point operations (FLOPs). Other resource requirements, such as loading weight matrices from memory, are amortized by using large training batches. Since training is compute-bound, the training FLOPs budget effectively determines the optimal model size and the number of training tokens needed to maximize model performance.
Inference and the Memory Bottleneck
During inference, we distinguish between two phases: prefill and decode.
Prefill initializes the KV cache by processing continuous input text, such as a question or background information. In this phase, the model processes all input tokens in parallel, similar to training. As a result, the cost of loading weight matrices is amortized over the sequence, and computation remains the dominant cost. The compute cost per token is approximately $2N$ FLOPs, where $N$ is the number of model parameters, since only the forward pass is performed.
In contrast, the decode phase generates text one token at a time. For each generated token, the model still requires $2N$ FLOPs for the forward pass. In addition, it must load all parameters and the KV cache associated with all previous tokens. While the loading of model parameters can be amortized by using larger batch sizes, the KV cache grows with both the batch size and the sequence length, making memory bandwidth a potential bottleneck (Pope et al., 2022).
Consider decoding using a transformer model with multi-head attention (MHA) (Vaswani et al., 2017), a batch size of 128, a context length of 8,192, and using a 1-byte numeric type. A typical model with 8 billion parameters will require loading 280 GB of KV cache, i.e., 30 times larger than the model parameters themselves! When running such a model on typical hardware, the arithmetic logic will spend much of its time waiting for the KV cache to load.
Llama 3 Models and Grouped Query Attention
The Llama 3 models (Llama Team, 2024) address the inference memory bottleneck by incorporating Grouped Query Attention (GQA; Ainslie et al., 2023), which significantly reduces the KV cache size by sharing key and value projections across multiple attention heads. For instance, in the 8B model, the KV cache size drops from 280 GB with MHA to 69 GB with GQA. While GQA may cause a slight decrease in model quality compared to MHA, this trade-off is acceptable given the substantial gain in inference speed. The use of GQA in Llama 3 models illustrates how architectural choices can balance compute and memory demands.
The size of the KV cache for different model sizes is shown in the following table:
Model Configuration | KV Cache Size, 8B Model | KV Cache Size, 70B Model | KV Cache Size, 405B Model |
---|---|---|---|
Transformer with MHA | 280 GB | 1,400 GB | 4,300 GB |
Llama 3 (with GQA) | 69 GB | 170 GB | 270 GB |
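The table entries can be approximately reproduced with a few lines of arithmetic. The sketch below is not from the original post: it assumes the layer and head counts published for the Llama 3 models, the setting used above (batch size 128, context length 8,192, 1 byte per value), and hypothetical helper names (`AttnConfig`, `kv_cache_bytes`, `as_mha`).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttnConfig:
    n_layers: int
    n_heads: int      # number of query heads
    n_kv_heads: int   # number of key/value heads (equals n_heads for MHA)
    head_dim: int


def kv_cache_bytes(cfg, batch_size, context_len, bytes_per_value=1):
    """Total KV cache size in bytes; the factor 2 covers keys and values."""
    per_token = 2 * cfg.n_layers * cfg.n_kv_heads * cfg.head_dim * bytes_per_value
    return per_token * context_len * batch_size


def as_mha(cfg):
    """Same model, but with one KV head per query head (plain MHA)."""
    return AttnConfig(cfg.n_layers, cfg.n_heads, cfg.n_heads, cfg.head_dim)


# Assumed configurations, taken from the published Llama 3 model sizes.
LLAMA3 = {
    "8B": AttnConfig(n_layers=32, n_heads=32, n_kv_heads=8, head_dim=128),
    "70B": AttnConfig(n_layers=80, n_heads=64, n_kv_heads=8, head_dim=128),
    "405B": AttnConfig(n_layers=126, n_heads=128, n_kv_heads=8, head_dim=128),
}

for name, cfg in LLAMA3.items():
    mha_gb = kv_cache_bytes(as_mha(cfg), batch_size=128, context_len=8192) / 1e9
    gqa_gb = kv_cache_bytes(cfg, batch_size=128, context_len=8192) / 1e9
    print(f"{name}: MHA {mha_gb:,.0f} GB, GQA {gqa_gb:,.0f} GB")
```

Under these assumptions the results land close to the table, with small differences due to rounding (for example, roughly 275 GB rather than 280 GB for the 8B model with MHA).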
FLOP-Equivalent
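To compare computational FLOPs and memory accesses on equal footing, we use the Hardware Operational Intensity (HOI):
$$\text{HOI} = \frac{\text{peak compute throughput (FLOPs/s)}}{\text{memory bandwidth (bytes/s)}}$$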
HOI indicates the hardware’s computational capacity relative to its memory bandwidth—it tells us how many floating-point operations can be performed per byte of memory access. By multiplying the KV cache size (in bytes) by HOI, we obtain the FLOP-equivalent cost required to load the KV cache, allowing us to directly compare memory operations with compute operations. An important aspect of HOI is that it tends to remain relatively stable across hardware generations.
For an Nvidia H100 GPU using FP8 arithmetic, HOI is approximately 600 FLOPs/byte. Let’s revisit the Llama 3 8B model, and consider the cost of decoding the next token in each of 128 sequences with a context length of 8192:
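- Forward pass computation cost: $2 \times (8 \times 10^{9}\ \text{parameters}) \times 128\ \text{sequences} \approx 2 \times 10^{12}$ FLOPs
- Cost to load the full KV cache: $69\ \text{GB} \times 600\ \text{FLOPs/byte} \approx 4 \times 10^{13}$ FLOP-equivalents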
In this case, the cost of loading the KV cache in FLOP-equivalents is 20 times greater than the compute cost of performing the forward pass. This indicates that the 8B model is memory-bound during decoding in this setting.
For the Llama 3 405B model, the ratio of KV cache load cost to compute cost is approximately 1.5, indicating a closer balance between compute and memory demands. However, if we increase the context length, the KV cache size grows proportionally, making decode memory-bound again.
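The two ratios quoted above follow from a short calculation. The sketch below uses only numbers stated in the text (HOI of about 600 FLOPs/byte, batch size 128, and the GQA cache sizes of 69 GB and 270 GB); the function name `decode_step_costs` is a hypothetical helper, and the weight-loading cost is ignored, as in the comparison above.

```python
HOI_FLOPS_PER_BYTE = 600  # approximate H100 FP8 value quoted in the text


def decode_step_costs(n_params, kv_cache_bytes, batch_size, hoi=HOI_FLOPS_PER_BYTE):
    """Costs of decoding one token for every sequence in the batch.

    Returns (compute cost in FLOPs, KV cache load cost in FLOP-equivalents).
    """
    compute_flops = 2 * n_params * batch_size   # forward pass: ~2N FLOPs per token
    kv_load_flop_equiv = kv_cache_bytes * hoi   # bytes loaded, converted via HOI
    return compute_flops, kv_load_flop_equiv


ratios = {}
for name, n_params, kv_bytes in [("8B", 8e9, 69e9), ("405B", 405e9, 270e9)]:
    compute, kv_load = decode_step_costs(n_params, kv_bytes, batch_size=128)
    ratios[name] = kv_load / compute
    print(f"{name}: KV load / compute = {ratios[name]:.1f}")
```

A ratio well above 1 means the step is memory-bound in this simplified model; because the KV cache term scales linearly with context length, doubling the context doubles the ratio.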
Rethinking Model Selection: Total Cost Optimization
Traditional model selection focuses on training FLOPs, defined by the cost $6ND$ (for $N$ parameters and $D$ training tokens), but this approach overlooks the significant costs associated with inference. To achieve a model of the same quality that is cheaper over its lifetime, we consider the total estimated lifetime cost, which includes the training, prefill, and decode phases:
Training cost: The number of FLOPs performed during training, including forward and backward passes.
Prefill cost: The number of FLOPs performed during the prefill phase of inference, forward pass only.
Decode cost: The dominant cost between computation and memory accesses during the decode phase, multiplied by the number of decode tokens.
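As a rough illustration, the three terms can be combined into a single estimate. This is a minimal sketch under stated assumptions, not the post's exact formula: decode is modeled at one fixed average context length, weight loading is amortized evenly over the batch, and every name, default, and token volume below is hypothetical.

```python
def lifetime_cost(n_params, train_tokens, prefill_tokens, decode_tokens,
                  kv_bytes_per_context_token, avg_context_len, batch_size,
                  hoi=600.0, bytes_per_param=1.0):
    """Estimate lifetime cost in FLOP-equivalents, split by phase."""
    train = 6.0 * n_params * train_tokens      # forward and backward passes
    prefill = 2.0 * n_params * prefill_tokens  # forward pass only

    # Decode: per generated token, take the larger of compute and memory cost.
    compute_per_token = 2.0 * n_params
    bytes_per_token = (kv_bytes_per_context_token * avg_context_len
                       + n_params * bytes_per_param / batch_size)  # assumption
    decode = decode_tokens * max(compute_per_token, bytes_per_token * hoi)

    return {"train": train, "prefill": prefill, "decode": decode,
            "total": train + prefill + decode}


# Illustrative inputs only: an 8B model with a GQA-style cache of
# 2 * 32 layers * 8 KV heads * 128 dims bytes per context token,
# and made-up token volumes for each phase.
costs = lifetime_cost(n_params=8e9, train_tokens=15e12,
                      prefill_tokens=1e12, decode_tokens=1e12,
                      kv_bytes_per_context_token=2 * 32 * 8 * 128,
                      avg_context_len=8192, batch_size=128)
for phase, value in costs.items():
    print(f"{phase}: {value:.2e} FLOP-equivalents")
```

Taking the `max` of the two decode terms reflects the "dominant cost" wording above; a model that overlaps compute and memory access imperfectly would sit somewhere between the max and the sum.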
Including inference costs in the total model cost alters the optimal balance between model size and training data. Sardana et al. (2023) showed that accounting for inference costs favors smaller models trained on more data than the Chinchilla scaling laws (Hoffmann et al., 2022) recommend; those laws suggest roughly $D \approx 20N$, where $N$ is the number of parameters and $D$ the number of training tokens. Besides training on more tokens, our approximation of the total computational cost can be used to explore architectural changes to the transformer model that spend more FLOPs to decrease memory bandwidth cost during inference.
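For reference, the training-only baseline can be written down directly. The sketch below assumes the commonly cited Chinchilla rule of thumb of about 20 training tokens per parameter together with the $6ND$ training cost, which gives $C = 6 \cdot N \cdot 20N = 120N^2$; the function name and the example budget are hypothetical.

```python
import math

TOKENS_PER_PARAM = 20.0  # assumed Chinchilla-style rule of thumb


def chinchilla_optimal(train_flops, tokens_per_param=TOKENS_PER_PARAM):
    """Solve C = 6 * N * D with D = tokens_per_param * N for N and D."""
    n_params = math.sqrt(train_flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens


# Example: the budget that this rule would assign to an 8B-parameter model.
budget = 6.0 * 8e9 * (20.0 * 8e9)
n_opt, d_opt = chinchilla_optimal(budget)
print(f"N = {n_opt:.3g} parameters, D = {d_opt:.3g} tokens")
```

An inference-aware selection would move away from this point toward a smaller $N$ and larger $D$ for the same quality, which is the direction the text attributes to Sardana et al. (2023).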
Strategies to Spend FLOPs to Save Memory Bandwidth
Our proposed cost metric allows for a comparison of transformer variants by considering both computational and memory-bandwidth demands. While memory-efficient attention mechanisms have been extensively researched, these techniques are designed to save both on FLOPs and on memory. Such free lunches are hard to find; we propose approaches which increase FLOPs cost in order to save memory, so long as the total lifetime cost of the model is improved. Our research agenda includes the following:
- KV Cache Compression Techniques
Extreme compression techniques for the KV cache, such as Multi-Query Attention (Shazeer, 2019), cross-layer KV cache sharing (Brandon et al., 2024; Character AI), and cross-token KV cache sharing (Mu et al., 2023) are often rejected due to quality degradations when compared on a FLOPs-neutral basis. However, on a total-cost-neutral basis, these are very likely massively profitable. Further techniques in this direction are worth exploring.
- Changing Model Shape
The typical practice for model shape (depth vs. width, and the balance among the individual width dimensions) is optimized for quality given a training budget. When optimizing for quality given a lifetime budget, the optimal shape is likely different:
It likely grows the feedforward network and shrinks the attention network, reducing the KV-cache memory bandwidth.
It likely grows the width and decreases the depth. This saves memory capacity: while keeping high efficiency, you can run a smaller batch size on the same number of chips or, equivalently, run the same batch size on a larger number of chips.
- Speculative Decoding for Large Batches and Contexts
While Speculative Decoding (Chen et al., 2023; Leviathan et al., 2023) is traditionally considered a small-batch-size optimization, it becomes beneficial for large batch sizes when considering KV cache memory bandwidth (Chen et al., 2024). Techniques that incorporate Speculative Decoding during pretraining, rather than as a post-training optimization, seem promising.
- Enhanced Query-Key Interactions
Dot-Product Attention is designed to make each query-key interaction highly computationally efficient, thus keeping training costs low. However, for inference, where memory fetches are more expensive than computation in query-key interactions, more computationally complex query-key interactions may be profitable. This could include algebraically different interactions (Kobayashi et al., 2020), or even deeper neural networks for these interactions.
By optimizing for a cost metric that aligns more closely with what we truly care about, we can pursue model designs that pay sufficient respect to the significant cost of the KV cache. Models that are 1.5x more expensive to train and 10x cheaper to serve at inference time might easily be possible.