A transformer generates tokens autoregressively: each new token is produced from all previous inputs, and the process repeats until the response concludes. Throughout this generation, the key and value vectors of every earlier token are reused at each step, so it is computationally efficient to cache them rather than recompute them.
However, caching increases memory usage: the system must store the key and value vectors of every token, for every attention head, across all layers. The number of entries in the cache is 2 (one key and one value vector) × number of layers × heads per layer × head dimension × number of tokens:
| Variable | Description | Value (R1) |
|---|---|---|
| $n_l$ | Number of layers | 61 |
| $n_h$ | Number of attention heads per layer | 128 |
| $d_h$ | Dimension of attention head | 128 |
| $n_t$ | Input tokens | 100,000 |
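Plugging these numbers in makes the scale concrete. A quick back-of-the-envelope sketch (the 2-byte entry size assumes BF16/FP16 precision):

```python
# KV cache size for the R1 configuration in the table above.
n_layers = 61       # number of layers
n_heads = 128       # attention heads per layer
d_head = 128        # dimension of each attention head
n_tokens = 100_000  # input tokens

# One key vector and one value vector (the factor of 2) per token,
# per head, per layer, each with d_head entries.
entries = 2 * n_layers * n_heads * d_head * n_tokens
size_gb = entries * 2 / 1e9  # assuming 2 bytes per entry (BF16)
print(f"{size_gb:.0f} GB")   # → 400 GB
```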
In this setting, roughly 400GB of memory is required for the cache alone, making it infeasible. Several methods have been introduced to reduce this memory usage.
## Methods
### Multi-Query Attention
In multi-query attention, all attention heads share a single key and value matrix; only the query matrix differs between heads. This significantly reduces memory usage, but at the cost of each head's specialization.
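The sharing pattern can be seen in a minimal NumPy sketch (toy sizes and hypothetical names, not R1's configuration):

```python
import numpy as np

# Multi-query attention sketch: per-head queries, one shared key/value.
n_heads, d_model, d_head, seq = 4, 32, 8, 5
rng = np.random.default_rng(0)
x = rng.standard_normal((seq, d_model))

W_q = rng.standard_normal((n_heads, d_model, d_head))  # one query projection per head
W_k = rng.standard_normal((d_model, d_head))           # single shared key projection
W_v = rng.standard_normal((d_model, d_head))           # single shared value projection

k, v = x @ W_k, x @ W_v          # computed once, cached once, used by every head
outputs = []
for h in range(n_heads):
    q = x @ W_q[h]               # only the queries differ per head
    scores = q @ k.T / np.sqrt(d_head)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    outputs.append(weights @ v)
out = np.concatenate(outputs, axis=-1)   # (seq, n_heads * d_head)
```

The cache holds only `k` and `v`: 2 × d_head entries per token per layer, versus 2 × n_heads × d_head for full multi-head attention.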
### Grouped-Query Attention
In grouped-query attention, heads are divided into groups, and each group shares one key and value matrix. It is less destructive than multi-query attention, but still incurs a performance hit relative to full multi-head attention.
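A sketch of the grouping, again with toy sizes and hypothetical names:

```python
import numpy as np

# Grouped-query attention sketch: 8 heads sharing 2 key/value groups.
n_heads, n_groups, d_model, d_head, seq = 8, 2, 32, 8, 6
rng = np.random.default_rng(1)
x = rng.standard_normal((seq, d_model))

W_q = rng.standard_normal((n_heads, d_model, d_head))
W_k = rng.standard_normal((n_groups, d_model, d_head))  # one K per group, not per head
W_v = rng.standard_normal((n_groups, d_model, d_head))

heads_per_group = n_heads // n_groups
outputs = []
for h in range(n_heads):
    g = h // heads_per_group          # heads 0-3 -> group 0, heads 4-7 -> group 1
    q, k, v = x @ W_q[h], x @ W_k[g], x @ W_v[g]
    scores = q @ k.T / np.sqrt(d_head)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    outputs.append(w @ v)
out = np.concatenate(outputs, axis=-1)
```

The cache holds 2 × n_groups × d_head entries per token per layer. Setting `n_groups = 1` recovers multi-query attention; `n_groups = n_heads` recovers full multi-head attention.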
### Multi-Head Latent Attention
In multi-head latent attention, the input is projected into a low-dimensional latent space and then projected back up to the key and value matrices by learnable weight matrices that are unique to each attention head. Only the shared latent vector needs to be cached.
#### Training Process
In the training process, the weight matrices $W^{DKV}$, $W^{UK}$, and $W^{UV}$ are trained to effectively compress the input into the latent space and decompress the key and value matrices from it:

$$c_t^{KV} = W^{DKV} h_t, \qquad k_t = W^{UK} c_t^{KV}, \qquad v_t = W^{UV} c_t^{KV}$$

where $h_t$ is the hidden state of token $t$, $c_t^{KV}$ is its compressed latent vector, $W^{DKV}$ is the shared down-projection matrix, and $W^{UK}$, $W^{UV}$ are the per-head up-projection matrices for keys and values.
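A minimal sketch of the compress/decompress path (toy sizes, hypothetical names):

```python
import numpy as np

# MLA sketch: shared down-projection to a latent, per-head up-projections.
n_heads, d_model, d_head, d_latent, seq = 4, 32, 8, 4, 5
rng = np.random.default_rng(2)
h_states = rng.standard_normal((seq, d_model))

W_dkv = rng.standard_normal((d_model, d_latent))          # shared down-projection
W_uk = rng.standard_normal((n_heads, d_latent, d_head))   # per-head key up-projection
W_uv = rng.standard_normal((n_heads, d_latent, d_head))   # per-head value up-projection

c_kv = h_states @ W_dkv   # (seq, d_latent): the only tensor that must be cached
k = np.stack([c_kv @ W_uk[i] for i in range(n_heads)])  # (n_heads, seq, d_head)
v = np.stack([c_kv @ W_uv[i] for i in range(n_heads)])
```

Per-head keys and values are reconstructed on demand, so the cache holds `d_latent` entries per token per layer instead of 2 × n_heads × d_head.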
#### Inference Process
In the inference process, the same expression is rearranged using linear algebra to prevent redundant operations and save computational cost. With query $q_t = W^{Q} h_t$ and key $k_j = W^{UK} c_j^{KV}$, the attention score can be rewritten as

$$q_t^\top k_j = h_t^\top \left( (W^{Q})^\top W^{UK} \right) c_j^{KV}$$

so attention is computed directly against the cached latent vectors $c_j^{KV}$ without ever materializing the keys; the same rearrangement absorbs $W^{UV}$ into the output projection $W^{O}$.

Since the products $(W^{Q})^\top W^{UK}$ and $W^{O} W^{UV}$ don't depend on the input, each can be pre-calculated and used as a single matrix.
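The rearrangement is easy to verify numerically. A sketch (toy sizes, hypothetical names) checking that the absorbed matrix produces the same scores as first decompressing the keys:

```python
import numpy as np

# Verify: absorbing the key up-projection into the query path gives the same
# attention scores while touching only the cached latent vectors.
d_model, d_latent, d_head, seq = 32, 4, 8, 5
rng = np.random.default_rng(3)
h_t = rng.standard_normal(d_model)           # current token's hidden state
c_kv = rng.standard_normal((seq, d_latent))  # cached latents of earlier tokens

W_q = rng.standard_normal((d_model, d_head))   # query projection
W_uk = rng.standard_normal((d_latent, d_head)) # key up-projection

# Naive path: decompress every key, then score.
k = c_kv @ W_uk                       # (seq, d_head)
scores_naive = k @ (h_t @ W_q)        # (seq,)

# Absorbed path: one precomputed, input-independent matrix.
W_absorbed = W_q @ W_uk.T             # (d_model, d_latent), computed once
scores_fast = c_kv @ (h_t @ W_absorbed)

assert np.allclose(scores_naive, scores_fast)
```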
## Required Storage Comparison
| Attention Mechanism | KV Cache Entries per Token | KV Cache Size per Token |
|---|---|---|
| Multi-Head Attention (MHA) | 2 × 61 × 128 × 128 ≈ 2.0M | 4MB |
| Multi-Query Attention (MQA) | 2 × 61 × 128 ≈ 15.6K | 31KB |
| Grouped-Query Attention (GQA) | 2 × 61 × 16 × 128 ≈ 250K | 500KB (16 groups) |
| Multi-Head Latent Attention (MLA) | 61 × (512 + 64) ≈ 35.1K | 70KB |

Sizes assume 2 bytes (BF16) per entry.
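The sizes in the table above can be checked with a few lines of arithmetic. This sketch assumes 2-byte entries; the 16 GQA groups and the MLA per-layer width (a 512-dimensional latent plus a 64-dimensional decoupled RoPE key, matching DeepSeek's published configuration) are assumptions consistent with the stated sizes:

```python
# Per-token KV cache sizes for the R1 configuration, in bytes.
n_layers, n_heads, d_head = 61, 128, 128
B = 2  # bytes per entry (BF16)

mha = 2 * n_layers * n_heads * d_head * B  # one K and V per head per layer
mqa = 2 * n_layers * d_head * B            # one shared K and V per layer
gqa = 2 * n_layers * 16 * d_head * B       # assuming 16 KV groups
mla = n_layers * (512 + 64) * B            # assuming latent 512 + RoPE key 64

print(mha, mqa, gqa, mla)  # → 3997696 31232 499712 70272
```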