
Optimize for inference too, not just training FLOPs

January 8, 2025. By MatX ML team.

Large Language Models (LLMs) have been shaped by scaling laws, as established by Kaplan et al. (2020) and Hoffmann et al. (2022). These laws guide us to design models that optimize for training cost but often overlook inference costs. Although training and inference costs are highly related, they diverge sharply in one critical area: attention key-value (KV) computation. During training, KV computation is usually cheaper than the rest of the model, but during inference, loading the KV cache becomes the dominant expense. This disconnect has led to models with ~70% efficiency during training but only ~10% efficiency during inference.

Perhaps a methodology that considers both inference and training cost would lead to more balanced models? This post sketches out some options in this direction.

Scaling Laws and Model Size Selection

Kaplan et al. (2020) demonstrated that transformer model quality improves predictably with increasing model parameters $N$ and training tokens $D$. They estimated the compute cost of training a transformer model at $6ND$ floating-point operations (FLOPs). Other resource requirements, such as loading weight matrices from memory, are amortized by using large training batches. Since training is compute-bound, the training FLOPs budget effectively determines the optimal model size and the number of training tokens needed to maximize model performance.
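As a rough illustration (our arithmetic, using the Chinchilla-style rule of thumb $D_{\text{train}} \approx 20N$ discussed later in this post), an 8-billion-parameter model trained on roughly 160 billion tokens would cost about:

C_{\text{train}} = 6ND_{\text{train}} \approx 6 \times (8 \times 10^{9}) \times (1.6 \times 10^{11}) \approx 7.7 \times 10^{21}\ \text{FLOPs}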

Inference and the Memory Bottleneck

During inference, we distinguish between two phases: prefill and decode.

Prefill initializes the KV cache by processing continuous input text, such as a question or background information. In this phase, the model processes all input tokens in parallel, similar to training. As a result, the cost of loading weight matrices is amortized over the sequence, and computation remains the dominant cost. The compute cost per token is approximately $2N$ FLOPs, since only the forward pass is performed.

In contrast, the decode phase generates text one token at a time. Each generated token still requires $2N$ FLOPs for the forward pass, but the model must also load all parameters and the KV cache associated with all previous tokens. While the loading of model parameters can be amortized by using larger batch sizes, the KV cache grows with both the batch size and the sequence length, making memory bandwidth a potential bottleneck (Pope et al., 2022).

Consider decoding using a transformer model with multi-head attention (MHA) (Vaswani et al., 2017), a batch size of 128, a context length of 8,192, and using a 1-byte numeric type. A typical model with 8 billion parameters will require loading 280 GB of KV cache, i.e., 30 times larger than the model parameters themselves! When running such a model on typical hardware, the arithmetic logic will spend much of its time waiting for the KV cache to load.
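For intuition, here is the back-of-the-envelope arithmetic behind that 280 GB figure, assuming a Llama-3-8B-like shape (32 layers, model dimension 4,096), where each token stores keys and values of size $d_{\text{model}}$ per layer:

\text{KV cache} \approx 2 \times n_{\text{layers}} \times d_{\text{model}} \times \text{context} \times \text{batch} \times 1\ \text{byte} = 2 \times 32 \times 4096 \times 8192 \times 128\ \text{bytes} \approx 275\ \text{GB}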

Llama 3 Models and Grouped Query Attention

The Llama 3 models (Llama Team, 2024) address the inference memory bottleneck by incorporating Grouped Query Attention (GQA; Ainslie et al., 2023), which significantly reduces the KV cache size by sharing key and value projections across multiple attention heads. For instance, in the 8B model, the KV cache size drops from 280 GB with MHA to 69 GB with GQA. While GQA may cause a slight decrease in model quality compared to MHA, this trade-off is acceptable given the substantial gain in inference speed. The use of GQA in Llama 3 models illustrates how architectural choices can balance compute and memory demands.

The size of the KV cache for different model sizes is shown in the following table:

KV Cache Size

| Model Configuration  | 8B Model | 70B Model | 405B Model |
| -------------------- | -------- | --------- | ---------- |
| Transformer with MHA | 280 GB   | 1,400 GB  | 4,300 GB   |
| Llama 3 (with GQA)   | 69 GB    | 170 GB    | 270 GB     |
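As a sanity check, the table can be reproduced (up to rounding) with a short script. This is a minimal sketch, assuming the published Llama 3 shapes (layer and head counts, head dimension 128, 8 KV heads under GQA) and the setting above: batch size 128, context length 8,192, 1-byte KV entries; the function name is ours.

```python
# Minimal sketch: KV cache size for MHA vs. GQA (batch 128, context 8,192, 1-byte entries).
# Model shapes follow the published Llama 3 configurations; head dimension is 128 throughout.

def kv_cache_bytes(n_layers, n_kv_heads, d_head=128, context=8192, batch=128, bytes_per_elem=1):
    """Total KV cache size in bytes: keys and values for every layer, head, token, and sequence."""
    return 2 * n_layers * n_kv_heads * d_head * context * batch * bytes_per_elem

# (name, layers, query heads, KV heads under GQA)
models = [("8B", 32, 32, 8), ("70B", 80, 64, 8), ("405B", 126, 128, 8)]

for name, layers, q_heads, kv_heads in models:
    mha = kv_cache_bytes(layers, q_heads)   # MHA: one KV head per query head
    gqa = kv_cache_bytes(layers, kv_heads)  # GQA: KV heads shared across query heads
    print(f"{name}: MHA {mha/1e9:,.0f} GB, GQA {gqa/1e9:,.0f} GB")
```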

FLOP-Equivalent

To compare computational FLOPs and memory accesses on equal footing, we use the Hardware Operational Intensity (HOI):

\text{HOI} = \frac{\text{Peak Compute Performance (FLOPs/s)}}{\text{Peak Memory Bandwidth (Bytes/s)}}

HOI indicates the hardware’s computational capacity relative to its memory bandwidth—it tells us how many floating-point operations can be performed per byte of memory access. By multiplying the KV cache size (in bytes) by HOI, we obtain the FLOP-equivalent cost required to load the KV cache, allowing us to directly compare memory operations with compute operations. An important aspect of HOI is that it tends to remain relatively stable across hardware generations.

For an Nvidia H100 GPU using FP8 arithmetic, HOI is approximately 600 FLOPs/byte. Let's revisit the Llama 3 8B model and consider the cost of decoding the next token in each of 128 sequences with a context length of 8,192.
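A back-of-the-envelope comparison (our arithmetic, using the 69 GB GQA KV cache from the table above):

\text{KV cache load} \approx 69\ \text{GB} \times 600\ \text{FLOPs/byte} \approx 4.1 \times 10^{13}\ \text{FLOP-equivalents}

\text{Forward pass} \approx 2N \times \text{batch} = 2 \times (8 \times 10^{9}) \times 128 \approx 2.0 \times 10^{12}\ \text{FLOPs}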

In this case, the cost of loading the KV cache in FLOP-equivalents is 20 times greater than the compute cost of performing the forward pass. This indicates that the 8B model is memory-bound during decoding in this setting.

For the Llama 3 405B model, the ratio of KV cache load cost to compute cost is approximately 1.5, indicating a closer balance between compute and memory demands. However, if we increase the context length, the KV cache size grows proportionally, making decode memory-bound again.

Rethinking Model Selection: Total Cost Optimization

Traditional model selection focuses on training FLOPs, defined by the cost $C_{\text{train}} = 6N D_{\text{train}}$, but this approach overlooks the significant costs associated with inference. To achieve a model of the same quality that is cheaper over its lifetime, we consider the total estimated lifetime cost, which includes the training, prefill, and decode phases:

C_{\text{total}} = \underbrace{6N D_{\text{train}}}_{\text{Training cost}} \ +\ \underbrace{2N D_{\text{prefill}}}_{\text{Prefill cost}} \ +\ \underbrace{\max\!\big( \underbrace{2N}_{\text{Forward pass}},\ \underbrace{\text{KV cache size} \times \text{HOI}}_{\text{KV cache loading}} \big) \times D_{\text{decode}}}_{\text{Decode cost}}

Including inference costs in the total model cost alters the optimal balance between model size and training data. Sardana et al. (2023) showed that accounting for inference costs leads to recommending smaller models trained on more data than the Chinchilla scaling laws (Hoffmann et al., 2022), which suggest $D_{\text{train}} \approx 20N$. Besides training on more tokens, our approximation of the total computational cost can be used to explore architectural changes to the transformer model that spend more FLOPs to decrease memory-bandwidth cost during inference.
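As a sketch of how this could be used in practice, the function below implements the lifetime-cost formula above. The function name, the default HOI of 600, and the example token counts are our own illustrative assumptions, not numbers from any particular deployment.

```python
# Minimal sketch of the lifetime-cost formula: training + prefill + decode, where the
# per-token decode cost is the max of forward-pass compute and FLOP-equivalent KV loading.

def lifetime_cost(n_params, d_train, d_prefill, d_decode, kv_bytes_per_sequence, hoi=600):
    """Estimated lifetime cost in FLOP-equivalents.

    kv_bytes_per_sequence: KV cache loaded to decode one token of one sequence.
    hoi: hardware operational intensity in FLOPs per byte (~600 for H100 FP8).
    """
    training = 6 * n_params * d_train
    prefill = 2 * n_params * d_prefill
    decode = max(2 * n_params, kv_bytes_per_sequence * hoi) * d_decode
    return training + prefill + decode

# Hypothetical example: 8B model, 15T training tokens, 10T prefill and 2T decode tokens,
# ~0.54 GB of KV cache per sequence (GQA at 8K context, i.e. 69 GB / 128 sequences).
print(f"{lifetime_cost(8e9, 15e12, 10e12, 2e12, 0.54e9):.2e} FLOP-equivalents")
```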

Strategies to Spend FLOPs to Save Memory Bandwidth

Our proposed cost metric allows for a comparison of transformer variants by considering both computational and memory-bandwidth demands. While memory-efficient attention mechanisms have been extensively researched, these techniques are designed to save both on FLOPs and on memory. Such free lunches are hard to find; we propose approaches which increase FLOPs cost in order to save memory, so long as the total lifetime cost of the model is improved. Our research agenda includes the following:

  1. KV Cache Compression Techniques

Extreme compression techniques for the KV cache, such as Multi-Query Attention (Shazeer, 2019), cross-layer KV cache sharing (Brandon et al., 2024; Character AI), and cross-token KV cache sharing (Mu et al., 2023), are often rejected due to quality degradation when compared on a FLOPs-neutral basis. However, on a total-cost-neutral basis, these techniques are very likely massively profitable (see the sketch after this list). Further techniques in this direction are worth exploring.

  2. Changing Model Shape

Model shape (depth vs. width; $d_{\text{model}}$ vs. $d_{\text{ff}}$ vs. $n_{\text{heads}}$) is typically optimized for quality given a training budget. When optimizing for quality given a lifetime budget, the optimal shape is likely different.

  3. Speculative Decoding for Large Batches and Contexts

While Speculative Decoding (Chen et al., 2023; Leviathan et al., 2023) is traditionally considered a small-batch-size optimization, it becomes beneficial for large batch sizes when considering KV cache memory bandwidth (Chen et al., 2024). Techniques that incorporate Speculative Decoding during pretraining, rather than as a post-training optimization, seem promising.

  4. Enhanced Query-Key Interactions

Dot-Product Attention is designed to make each query-key interaction highly computationally efficient, thus keeping training costs low. However, for inference, where memory fetches are more expensive than computation in query-key interactions, more computationally complex query-key interactions may be profitable. This could include algebraically different interactions (Kobayashi et al., 2020), or even deeper neural networks for these interactions.
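To make the first item concrete, here is a minimal sketch (our own, reusing the hardware and 8B model numbers from the examples above) of how KV cache compression lowers the per-token decode cost even though the forward-pass FLOPs stay the same:

```python
# Per-token decode cost (FLOP-equivalents) for a Llama-3-8B-like model
# (32 layers, head dimension 128, 8K context, 1-byte KV entries) under
# progressively stronger KV cache compression. Forward-pass FLOPs are unchanged.

HOI = 600   # FLOP-equivalents per byte (H100-class FP8, as above)
N = 8e9     # model parameters

def decode_cost_per_token(n_kv_heads, n_layers=32, d_head=128, context=8192):
    kv_bytes = 2 * n_layers * n_kv_heads * d_head * context  # per-sequence KV cache
    return max(2 * N, kv_bytes * HOI)

for name, kv_heads in [("MHA (32 KV heads)", 32), ("GQA (8 KV heads)", 8), ("MQA (1 KV head)", 1)]:
    print(f"{name}: {decode_cost_per_token(kv_heads)/1e9:,.0f} GFLOP-equivalents/token")
```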

By optimizing for a cost metric that aligns more closely with what we truly care about, we can pursue model designs that pay sufficient respect to the significant cost of the KV cache. Models that are 1.5x more expensive to train but 10x cheaper at inference time might easily be possible.