In multi-head latent attention (MLA), the input is first projected into a low-dimensional latent space; the key and value matrices are then reconstructed from this latent representation by learnable up-projection weight matrices, with a separate pair of weights for each attention head.

Training Process

In the training process, the weight matrices $W^{DKV}$, $W^{UK}$, $W^{UV}$ are trained to effectively compress the input into the latent space and decompress the key and value matrices from it:

$c^{KV}_t = W^{DKV} h_t, \qquad k_t = W^{UK} c^{KV}_t, \qquad v_t = W^{UV} c^{KV}_t,$

where $h_t$ is the input hidden state at position $t$, $c^{KV}_t$ is its low-dimensional latent representation, $W^{DKV}$ is the shared down-projection matrix, and $W^{UK}$, $W^{UV}$ are the per-head up-projection matrices for the keys and values.
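The compression and decompression above can be sketched numerically. This is a minimal illustration, not an implementation of any particular model: the dimensions (`d`, `d_c`, `d_h`) and the random weights standing in for trained parameters are assumptions for the example.

```python
import numpy as np

# Hypothetical dimensions: model width d, latent width d_c (d_c << d),
# per-head key/value width d_h. Real models use much larger values.
d, d_c, d_h = 64, 8, 16
rng = np.random.default_rng(0)

# Random stand-ins for the trained weight matrices:
W_dkv = rng.standard_normal((d_c, d))    # down-projection: input -> latent
W_uk = rng.standard_normal((d_h, d_c))   # up-projection: latent -> key (one head)
W_uv = rng.standard_normal((d_h, d_c))   # up-projection: latent -> value (one head)

h = rng.standard_normal((d,))            # hidden state of one token

c_kv = W_dkv @ h                         # latent vector: all the KV cache stores
k = W_uk @ c_kv                          # key decompressed from the latent
v = W_uv @ c_kv                          # value decompressed from the latent

print(c_kv.shape, k.shape, v.shape)      # (8,) (16,) (16,)
```

Caching the 8-dimensional latent vector instead of the two 16-dimensional key and value vectors is the source of the memory savings.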

Inference Process

In the inference process, the same expression is rearranged using the associativity of matrix multiplication to prevent redundant operations and save computational cost:

$q_t^\top k_s = q_t^\top \left(W^{UK} c^{KV}_s\right) = \left((W^{UK})^\top q_t\right)^\top c^{KV}_s,$

where $q_t$ is the query at position $t$ and $c^{KV}_s$ is the cached latent vector at position $s$, so the keys never need to be reconstructed explicitly; analogously, $W^{UV}$ can be folded into the output projection.

Since the matrices $W^{UK}$ and $W^{UV}$ do not depend on the input, each product with its neighboring projection (the query projection and the output projection, respectively) can be pre-calculated once and used as a single matrix during inference.
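The rearrangement can be checked directly: folding the key up-projection into the query projection ahead of time yields exactly the same attention score as first reconstructing the key. The dimensions and random weights below are illustrative assumptions.

```python
import numpy as np

d, d_c, d_h = 64, 8, 16
rng = np.random.default_rng(1)
W_q = rng.standard_normal((d_h, d))      # query projection (one head)
W_dkv = rng.standard_normal((d_c, d))    # down-projection: input -> latent
W_uk = rng.standard_normal((d_h, d_c))   # up-projection: latent -> key

h_t = rng.standard_normal((d,))          # current token's hidden state
h_s = rng.standard_normal((d,))          # earlier token; only its latent is cached
c_s = W_dkv @ h_s

# Naive: reconstruct the key, then take its dot product with the query.
score_naive = (W_q @ h_t) @ (W_uk @ c_s)

# Absorbed: pre-multiply the input-independent matrices once, then
# score the query directly against the cached latent vector.
W_q_absorbed = W_uk.T @ W_q              # pre-computed, reused every step
score_fast = (W_q_absorbed @ h_t) @ c_s

assert np.allclose(score_naive, score_fast)
```

Both paths compute $q_t^\top W^{UK} c^{KV}_s$; the absorbed form simply reassociates the products so the per-step work involves only the small latent vector.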