Scaled dot-product attention
Head mapping variants
Standard multi-head scaled-dot-product attention uses the same number of query, key, and value heads. The i-th head of the queries always interacts with the i-th head of the keys to compute an attention matrix, and that attention matrix is applied to the i-th head of the values to compute new representations.
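A minimal sketch of this 1:1:1 mapping in PyTorch; the sizes are arbitrary and only illustrate the shapes:

```python
import torch

# Illustrative sizes: batch, heads, sequence length, head dimension.
batch, n_heads, seq, d_head = 2, 8, 16, 64

# Standard multi-head attention: queries, keys, and values all have n_heads heads.
query = torch.randn(batch, n_heads, seq, d_head)
key = torch.randn(batch, n_heads, seq, d_head)
value = torch.randn(batch, n_heads, seq, d_head)

# Head i of the queries attends over head i of the keys...
attn = (query @ key.transpose(-2, -1) / d_head**0.5).softmax(dim=-1)
# ...and the attention matrix of head i is applied to head i of the values.
hidden = attn @ value  # (batch, n_heads, seq, d_head)
```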
Multi-Query Attention
Two approaches are commonly used to move away from this 1:1:1 mapping. Multi-Query Attention (Shazeer, 2019) uses just a single key head and a single value head. When computing attention, the single key head is broadcast across all query heads. The snippet below simulates the shapes.
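This is a sketch in PyTorch rather than the original implementation; the sizes are illustrative, and the broadcast falls out of the batched matrix multiplication:

```python
import torch

batch, n_heads, seq, d_head = 2, 8, 16, 64

# Multi-Query Attention: n_heads query heads, but one key head and one value head.
query = torch.randn(batch, n_heads, seq, d_head)
key = torch.randn(batch, 1, seq, d_head)
value = torch.randn(batch, 1, seq, d_head)

# The single key head is broadcast across all query heads by the batched matmul.
attn = (query @ key.transpose(-2, -1) / d_head**0.5).softmax(dim=-1)
# The single value head is broadcast in the same way.
hidden = attn @ value  # (batch, n_heads, seq, d_head)
```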
The single value head is broadcast in the same manner. The main motivation for reducing the number of key/value heads is that decoding is memory bandwidth-bound: every decoding step has to read the cached keys and values from memory, so a single key/value head cuts down the amount of data transferred.
Grouped-Query Attention
Grouped-Query Attention (Ainslie et al., 2023) is a similar optimization, but rather than using a single key/value head, it uses a 1:m mapping from key/value heads to query heads. For instance, if a model uses 8 query heads and 4 key/value heads, each key/value head is broadcast to 2 query heads. This is a compromise between the memory bandwidth savings of Multi-Query Attention and the modeling capacity of full Multi-Head Attention.
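A minimal sketch under the same assumptions as above (PyTorch, illustrative sizes), using repeat_interleave to broadcast each key/value head to its group of query heads:

```python
import torch

batch, seq, d_head = 2, 16, 64
n_query_heads, n_kv_heads = 8, 4
group_size = n_query_heads // n_kv_heads  # each key/value head serves 2 query heads

query = torch.randn(batch, n_query_heads, seq, d_head)
key = torch.randn(batch, n_kv_heads, seq, d_head)
value = torch.randn(batch, n_kv_heads, seq, d_head)

# Broadcast every key/value head to its group of query heads.
key = key.repeat_interleave(group_size, dim=1)      # (batch, n_query_heads, seq, d_head)
value = value.repeat_interleave(group_size, dim=1)  # (batch, n_query_heads, seq, d_head)

attn = (query @ key.transpose(-2, -1) / d_head**0.5).softmax(dim=-1)
hidden = attn @ value  # (batch, n_query_heads, seq, d_head)
```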
Logit transformations
Softcapping
Some models (e.g., Gemma 2 and Grok) use soft-capping of attention logits. See Logit soft-capping.
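The usual formulation squashes the logits with a scaled tanh, cap * tanh(logits / cap), applied before the softmax. A minimal PyTorch sketch; the cap of 50.0 mirrors Gemma 2's attention logit cap and is only illustrative:

```python
import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Squash the logits smoothly into the range (-cap, cap).
    return cap * torch.tanh(logits / cap)

scores = torch.randn(2, 8, 16, 16)  # (batch, heads, query len, key len), illustrative
capped = softcap(scores, cap=50.0)  # capped logits go into the softmax
```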