The Transformer architecture generates each new token sequentially, conditioned on all previous inputs, repeating the process until the response concludes. Throughout this autoregressive generation, the key and value vectors of every previous token are reused at each step, so it is computationally efficient to cache these vectors rather than recompute them.

However, caching increases memory usage: the system must store the key and value vectors of every token for all attention heads across all layers. The number of entries in the cache is $2 \times L \times n_h \times d_h \times t$ (the factor of 2 accounts for storing both keys and values), with:

| Variable | Description | Value (R1) |
| --- | --- | --- |
| $L$ | Number of layers | 61 |
| $n_h$ | Number of attention heads per layer | 128 |
| $d_h$ | Dimension of attention head | 128 |
| $t$ | Input tokens | 100,000 |

In this setting, the cache holds $2 \times 61 \times 128 \times 128 \times 100{,}000 \approx 2 \times 10^{11}$ entries, or about 400GB at 2 bytes per entry (FP16), making it impractical to serve. To reduce this memory usage, multiple methods have been introduced.
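The arithmetic above can be sketched as a small helper. The 2-bytes-per-entry figure assumes FP16/BF16 storage; the function name is illustrative:

```python
def kv_cache_bytes(n_layers, n_heads, d_head, n_tokens, bytes_per_entry=2):
    """Total bytes needed to cache keys AND values (hence the factor of 2)
    for every token, head, and layer. Assumes FP16/BF16 (2 bytes/entry)."""
    entries = 2 * n_layers * n_heads * d_head * n_tokens
    return entries * bytes_per_entry

# The R1-like setting from the table above.
total = kv_cache_bytes(n_layers=61, n_heads=128, d_head=128, n_tokens=100_000)
print(f"{total / 1e9:.1f} GB")  # ≈ 400 GB
```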

Methods

Multi-Query Attention

In multi-query attention, all attention heads share a single key matrix and a single value matrix; the only difference between heads is the query matrix. This significantly reduces memory usage, but at the cost of each head’s specialization.
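A minimal single-token NumPy sketch of MQA: each head has its own query, but all heads attend over one shared key/value cache. The dimensions are illustrative, not taken from any specific model:

```python
import numpy as np

n_heads, d_head, seq = 8, 16, 32
rng = np.random.default_rng(0)

q = rng.standard_normal((n_heads, d_head))  # per-head queries
k = rng.standard_normal((seq, d_head))      # ONE shared key cache for all heads
v = rng.standard_normal((seq, d_head))      # ONE shared value cache for all heads

scores = q @ k.T / np.sqrt(d_head)          # (n_heads, seq)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)  # softmax over positions
out = weights @ v                            # (n_heads, d_head)
print(out.shape)  # (8, 16)
```

The cache stores `seq * d_head` entries per layer for keys instead of `n_heads * seq * d_head`, which is where the memory saving comes from.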


Grouped-Query Attention

In grouped-query attention, heads are divided into groups, and each group shares the same key and value matrices. It is less destructive than multi-query attention, but still takes a performance hit relative to full multi-head attention.
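The same single-token sketch extended to GQA: each group of heads shares one K/V cache. Setting `n_groups = 1` recovers MQA and `n_groups = n_heads` recovers full MHA; the dimensions are illustrative:

```python
import numpy as np

n_heads, n_groups, d_head, seq = 8, 2, 16, 32
rng = np.random.default_rng(0)

q = rng.standard_normal((n_heads, d_head))
k = rng.standard_normal((n_groups, seq, d_head))  # one K cache per group
v = rng.standard_normal((n_groups, seq, d_head))  # one V cache per group

heads_per_group = n_heads // n_groups
k_full = np.repeat(k, heads_per_group, axis=0)    # broadcast to (n_heads, seq, d_head)
v_full = np.repeat(v, heads_per_group, axis=0)

scores = np.einsum("hd,hsd->hs", q, k_full) / np.sqrt(d_head)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = np.einsum("hs,hsd->hd", weights, v_full)    # (n_heads, d_head)
print(out.shape)  # (8, 16)
```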


Multi-Head Latent Attention

In multi-head latent attention, the input is projected into a low-dimensional latent space, and the key and value matrices are then reconstructed from this latent vector by learnable up-projection weight matrices that are unique to each attention head. Only the small latent vector needs to be cached.

Training Process

In the training process, the weight matrices $W^{DKV}$, $W^{UK}$, and $W^{UV}$ are trained to effectively compress the input into the latent space and decompress the key and value matrices from it:

$$c_t = W^{DKV} h_t, \qquad k_t = W^{UK} c_t, \qquad v_t = W^{UV} c_t$$

where $h_t$ is the hidden state of token $t$, $c_t$ is its compressed latent vector, $W^{DKV}$ is the shared down-projection matrix, and $W^{UK}$, $W^{UV}$ are the per-head up-projection matrices for keys and values.
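The projections above for a single token, sketched in NumPy. The dimensions are illustrative (not DeepSeek's actual sizes); note that only `c_t` would be cached:

```python
import numpy as np

d_model, d_c, n_heads, d_head = 64, 8, 4, 16  # illustrative sizes, d_c << n_heads * d_head
rng = np.random.default_rng(0)

W_dkv = rng.standard_normal((d_c, d_model))         # shared down-projection
W_uk = rng.standard_normal((n_heads, d_head, d_c))  # per-head key up-projection
W_uv = rng.standard_normal((n_heads, d_head, d_c))  # per-head value up-projection

h_t = rng.standard_normal(d_model)   # hidden state of token t
c_t = W_dkv @ h_t                    # compressed latent -- the only thing cached
k_t = np.einsum("hdc,c->hd", W_uk, c_t)  # per-head keys, (n_heads, d_head)
v_t = np.einsum("hdc,c->hd", W_uv, c_t)  # per-head values, (n_heads, d_head)
print(c_t.shape, k_t.shape)  # (8,) (4, 16)
```

Caching `c_t` costs `d_c` entries per token per layer instead of `2 * n_heads * d_head`.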

Inference Process

In the inference process, the same expression is rearranged using linear algebra to prevent redundant operations and save computational cost. For the attention score between the query at position $t$ and a cached latent $c_s$:

$$q_t^\top k_s = (W^{Q} h_t)^\top (W^{UK} c_s) = h_t^\top \left(W^{Q\top} W^{UK}\right) c_s$$

where $W^{Q}$ is the head's query projection. Since $W^{Q\top} W^{UK}$ (and, on the output side, $W^{O} W^{UV}$) do not depend on the input, each product can be pre-calculated and applied as a single matrix, so the keys and values never have to be materialized at inference time.
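A quick numerical check of this weight-absorption identity, with illustrative dimensions: the score computed by materializing the key equals the score computed through the pre-absorbed matrix:

```python
import numpy as np

d_model, d_c, d_head = 64, 8, 16
rng = np.random.default_rng(0)

W_q = rng.standard_normal((d_head, d_model))   # query projection of one head
W_uk = rng.standard_normal((d_head, d_c))      # key up-projection of the same head
h = rng.standard_normal(d_model)               # query-side hidden state
c = rng.standard_normal(d_c)                   # cached latent of an earlier token

score_naive = (W_q @ h) @ (W_uk @ c)  # materializes q and k explicitly
W_absorbed = W_q.T @ W_uk             # input-independent: precompute once
score_fast = h @ W_absorbed @ c       # never forms the key vector
print(np.allclose(score_naive, score_fast))  # True
```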


Required Storage Comparison

| Attention Mechanism | KV Cache Entries per Token | KV Cache Size per Token |
| --- | --- | --- |
| Multi-Head Attention (MHA) | $2 \times 61 \times 128 \times 128 \approx 2.0\text{M}$ | 4MB |
| Multi-Query Attention (MQA) | $2 \times 61 \times 128 \approx 15.6\text{K}$ | 31KB |
| Grouped-Query Attention (GQA) | $2 \times 61 \times 16 \times 128 \approx 250\text{K}$ | 500KB (16 groups) |
| Multi-Head Latent Attention (MLA) | $61 \times (512 + 64) \approx 35.1\text{K}$ | 70KB |
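The sizes in the table can be reproduced directly, assuming 2 bytes per entry (FP16/BF16), the R1-like dimensions from earlier (61 layers, 128 heads, head dim 128), 16 GQA groups, and an MLA cache of 512 latent dimensions plus a 64-dimensional decoupled RoPE key per layer (DeepSeek's published configuration):

```python
L, H, D, BYTES = 61, 128, 128, 2  # layers, heads, head dim, bytes per entry

mha = 2 * L * H * D * BYTES   # keys and values for every head
mqa = 2 * L * 1 * D * BYTES   # one shared K/V head
gqa = 2 * L * 16 * D * BYTES  # 16 groups of heads
mla = L * (512 + 64) * BYTES  # latent vector + decoupled RoPE key

for name, size in [("MHA", mha), ("MQA", mqa), ("GQA", gqa), ("MLA", mla)]:
    print(f"{name}: {size / 1e3:.0f} KB per token")
```

MLA ends up roughly 57x smaller than full MHA while, unlike MQA, preserving distinct per-head key/value behavior through the up-projections.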