Multi-Query & Grouped-Query Attention

Multi-Head Self-Attention

Recall that the standard self-attention mechanism in Transformer models [1] computes an attention score between every pair of tokens in an input sequence of text. Internally, each token in the input sequence, represented by an embedding (that is, a vector of floating point values), is first projected into a query and key vector.

For example, consider an input sequence of 5 tokens, such as the phrase Hope is a waking dream, with each token represented by an embedding of dimension 8, forming a \(5 \times 8\) matrix.

This matrix of input embeddings is passed separately through the Query Layer and Key Layer (each a linear layer), producing a query and key vector for each token:

(Figure: the embeddings for Hope, is, a, waking, dream are multiplied by the Query Layer weights and the Key Layer weights to produce the Query Vectors and Key Vectors.)

Attention scores are then computed by multiplying the query vectors by the transpose of the key vectors; or, put another way, the attention scores are computed by taking the dot product of each query vector with each key vector.

(Figure: the Query Vectors are multiplied by the transpose of the Key Vectors to produce the Attention Scores.)

The resulting matrix of attention scores can be thought of as representing the strength of some relationship between each pair of tokens in the sequence.

(Figure: a \(5 \times 5\) grid of attention scores with the tokens Hope, is, a, waking, dream along both the rows and the columns.)

The scores above are representative of a bidirectional attention mechanism, where each token can “attend” to every other token in the input sequence.

The focus of this page is on models that employ a unidirectional or autoregressive attention mechanism, where each token can only attend to those that came before it in the input sequence. Most transformer models used for text generation tasks use an autoregressive attention mechanism.

Thus, the attention scores matrix for an autoregressive decoder model will look like the one below, where green squares represent non-zero values, and grey cells represent zero values.

(Figure: the same \(5 \times 5\) attention scores grid with a causal mask: cells on or below the diagonal are non-zero (green), and cells above the diagonal are zero (grey).)
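
To make this computation concrete, here is a minimal sketch using a randomly initialized Query Layer and Key Layer (not trained weights) and a toy \(5 \times 8\) embedding matrix. Real implementations also scale the scores and pass each row through a softmax so that it sums to 1 (which is why the GPT2 attention scores printed later behave that way); the sketch includes that step along with the causal mask.

import torch

torch.manual_seed(0)

seq_len, embed_dim, head_dim = 5, 8, 4

# random embeddings standing in for the tokens: Hope, is, a, waking, dream
embeddings = torch.randn(seq_len, embed_dim)

# randomly initialized query and key projections (purely illustrative)
query_layer = torch.nn.Linear(embed_dim, head_dim, bias=False)
key_layer = torch.nn.Linear(embed_dim, head_dim, bias=False)

query_vectors = query_layer(embeddings)  # shape: (5, 4)
key_vectors = key_layer(embeddings)      # shape: (5, 4)

# attention scores: dot product of every query vector with every key vector
scores = query_vectors @ key_vectors.T   # shape: (5, 5)

# causal mask: each token may only attend to itself and to earlier tokens
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = scores.masked_fill(~causal_mask, float("-inf"))

# scale and normalize so that each row sums to 1
attention_scores = torch.softmax(scores / head_dim**0.5, dim=-1)
print(attention_scores)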

Transformer models typically contain several unique attention heads (hence the name “multi-head attention”), each of which computes its own attention scores matrix. The full picture, then, looks something like this:

(Figure: separate attention scores matrices for Head 1, Head 2, ..., Head n.)

We can see the multi-head self-attention mechanism in action using the HuggingFace Transformers [2] library.

The code below loads GPT2 [3], an autoregressive transformer model, and prints the attention scores computed by the first layer on the input string Hope is a waking dream:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_printoptions(sci_mode=False, precision=4)

gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")

gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")

with torch.no_grad():
    inputs = gpt2_tokenizer("Hope is a waking dream", return_tensors="pt", add_special_tokens=False)
    attentions = gpt2(inputs.input_ids, output_attentions=True).attentions

    # get the attention scores computed by the first layer
    # for the first input sequence in the batch
    first_layer_attentions = attentions[0][0]

    # print attention scores from the first head
    print("GPT2 Attention Scores (Head 1):")
    print(first_layer_attentions[0])
GPT2 Attention Scores (Head 1):
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9605, 0.0395, 0.0000, 0.0000, 0.0000],
        [0.8410, 0.1055, 0.0534, 0.0000, 0.0000],
        [0.4271, 0.1821, 0.1070, 0.2838, 0.0000],
        [0.1741, 0.0415, 0.0308, 0.7006, 0.0530]])

For each of the five tokens in the input, notice how attention scores are computed between that token and all tokens that precede it. Consider the fourth input token, waking, corresponding to the fourth row in the attention scores matrix above. Attention scores are computed between waking and all other tokens in the input sequence except for dream:

(Figure: the fourth row of the attention scores matrix: waking attends to Hope (0.43), is (0.18), a (0.11), and itself (0.28), but not to dream.)

Importantly, this means that the attention scores between waking and all the tokens that precede it (i.e. the fourth row in the matrix) will be the same regardless of what word or words come after waking in the input sequence.

For example, if we compute attention scores for the input sequence Hope is a waking sandwich, the first four rows of the attention scores matrix will be the same as the first four rows of the matrix above:

inputs = gpt2_tokenizer("Hope is a waking sandwich", return_tensors="pt", add_special_tokens=False)
attentions = gpt2(inputs.input_ids, output_attentions=True).attentions

# get the attention scores computed by the first layer
# for the first input sequence in the batch
first_layer_attentions = attentions[0][0]

# print attention scores from the first head
print("GPT2 Attention Scores (Head 1):")
print(first_layer_attentions[0])
GPT2 Attention Scores (Head 1):
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9605, 0.0395, 0.0000, 0.0000, 0.0000],
        [0.8410, 0.1055, 0.0534, 0.0000, 0.0000],
        [0.4271, 0.1821, 0.1070, 0.2838, 0.0000],
        [0.3681, 0.2516, 0.1510, 0.0601, 0.1692]])

Furthermore, if we add another token to the end of the input sequence, a new row and column will be added to the attention scores matrix, and all of the preceding rows will remain unchanged (except for the addition of a zero in the new column):

inputs = gpt2_tokenizer("Hope is a waking sandwich!", return_tensors="pt", add_special_tokens=False)
attentions = gpt2(inputs.input_ids, output_attentions=True).attentions

# get the attention scores computed by the first layer
# for the first input sequence in the batch
first_layer_attentions = attentions[0][0]

# print attention scores from the first head
print("GPT2 Attention Scores (Head 1):")
print(first_layer_attentions[0])
GPT2 Attention Scores (Head 1):
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9605, 0.0395, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8410, 0.1055, 0.0534, 0.0000, 0.0000, 0.0000],
        [0.4271, 0.1821, 0.1070, 0.2838, 0.0000, 0.0000],
        [0.3681, 0.2516, 0.1510, 0.0601, 0.1692, 0.0000],
        [0.3929, 0.0782, 0.0524, 0.1585, 0.2443, 0.0737]])

This characteristic of autoregressive attention is important to keep in mind when we consider the memory requirements of transformer models.

Attention in the Text Generation Loop

Autoregressive Transformer models generate text one token at a time. The code below implements a very basic text generation loop for GPT2. At each iteration of the loop, we call the model, take the highest probability token using torch.argmax, and append that token to the original input sequence.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")

with torch.no_grad():

    num_new_tokens = 3

    # tokenize the original input sentence
    inputs = gpt2_tokenizer("Hope is a", return_tensors="pt", add_special_tokens=False)

    # we execute this loop until we generate num_new_tokens new tokens
    for i in range(num_new_tokens):
        print("-"*50)
        print(f"Iteration {i+1}")
        print(f"Generating the next token for sequence: '{gpt2_tokenizer.decode(inputs.input_ids[0])}'")
        
        outputs = gpt2(inputs.input_ids, output_attentions=True)

        # take the highest scoring token as the next token
        logits = outputs.logits
        next_token = torch.argmax(logits[:, -1, :])

        # concatenate the new token with the rest of the input
        inputs.input_ids = torch.cat([inputs.input_ids, next_token.reshape([1,1])], dim=-1)
        
        print("Generated token:", gpt2_tokenizer.decode(next_token))
    
    print("-"*50)
    print(f"Final generated sequence: '{gpt2_tokenizer.decode(inputs.input_ids[0])}'")
--------------------------------------------------
Iteration 1
Generating the next token for sequence: 'Hope is a'
Generated token:  great
--------------------------------------------------
Iteration 2
Generating the next token for sequence: 'Hope is a great'
Generated token:  way
--------------------------------------------------
Iteration 3
Generating the next token for sequence: 'Hope is a great way'
Generated token:  to
--------------------------------------------------
Final generated sequence: 'Hope is a great way to'

If we examine the attention scores computed by any attention head at each iteration of the loop, we will see that each new token results in a new row and column in the attention scores matrix; and, as illustrated above, all of the preceding rows remain unchanged.

for i in range(num_new_tokens):
    
    ...

    # get the attention scores computed by the first layer
    # for the first input sequence in the batch
    first_layer_attentions = outputs.attentions[0][0]
    print("Attention Scores (Head 1):")
    print(first_layer_attentions[0])
--------------------------------------------------
Iteration 1
Generating the next token for sequence: 'Hope is a'
Generated token:  great
Attention Scores (Head 1):
tensor([[1.0000, 0.0000, 0.0000],
        [0.9605, 0.0395, 0.0000],
        [0.8410, 0.1055, 0.0534]])
--------------------------------------------------
Iteration 2
Generating the next token for sequence: 'Hope is a great'
Generated token:  way
Attention Scores (Head 1):
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.9605, 0.0395, 0.0000, 0.0000],
        [0.8410, 0.1055, 0.0534, 0.0000],
        [0.6815, 0.1267, 0.1025, 0.0894]])
--------------------------------------------------
Iteration 3
Generating the next token for sequence: 'Hope is a great way'
Generated token:  to
Attention Scores (Head 1):
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9605, 0.0395, 0.0000, 0.0000, 0.0000],
        [0.8410, 0.1055, 0.0534, 0.0000, 0.0000],
        [0.6815, 0.1267, 0.1025, 0.0894, 0.0000],
        [0.7019, 0.0803, 0.0513, 0.1099, 0.0566]])
--------------------------------------------------
Final generated sequence: 'Hope is a great way to'

What this means is that, at each iteration of the loop, there is no need to recompute the attention scores produced by the previous iteration. All we really need to do is compute the last row of the attention scores matrix, which can then be appended to the attention scores matrix from the previous iteration.
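
We can see this reuse in action through the transformers API itself. The following is a minimal sketch that runs the model once over a prompt and then feeds only the newly generated token back in, together with the returned past_key_values; the logits it produces match a full forward pass over the whole sequence (up to floating point error).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")

with torch.no_grad():
    inputs = gpt2_tokenizer("Hope is a", return_tensors="pt", add_special_tokens=False)

    # full forward pass over the prompt, keeping the cached keys and values
    outputs = gpt2(inputs.input_ids, use_cache=True)
    next_token = torch.argmax(outputs.logits[:, -1, :]).reshape(1, 1)

    # next step: feed ONLY the new token, along with the cache from the previous step
    cached_outputs = gpt2(next_token, past_key_values=outputs.past_key_values, use_cache=True)

    # compare against a full forward pass over the entire sequence
    full_outputs = gpt2(torch.cat([inputs.input_ids, next_token], dim=-1))
    print(torch.allclose(cached_outputs.logits[:, -1, :],
                         full_outputs.logits[:, -1, :], atol=1e-4))  # True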

The Key Cache

Recall that attention scores are computed by taking the dot products of query vectors and key vectors. To compute only the last row of the attention scores matrix, we need only compute the dot products of the last query vector with all of the key vectors.

We can visualize this as below, with empty squares representing the parts of the computation that can be reused from previous iterations:

(Figure: only the last query vector is multiplied with the Key Vectors to produce the last row of the Attention Scores; the empty squares are values that can be reused from previous iterations.)

From this, we can see that only the last query vector and all of the key vectors are needed to compute the last row of the attention scores matrix. The key vectors themselves are computed by multiplying the input embeddings by the key layer weights as we saw earlier:

(Figure: the Input Embeddings are multiplied by the Key Layer weights to produce the Key Vectors.)

Thus, at each iteration, we need only compute the last key vector (as this depends on the input embedding of the most recent token), while all others can be reused from the previous iteration.

We can avoid a lot of this redundant computation by maintaining a Key Cache, which stores the key vectors computed at each iteration. First, we compute just a single query vector and a single key vector for the newest token:

(Figure: only the embedding of the newest token, way, is multiplied by the Query Layer weights and the Key Layer weights, producing a single new query vector and a single new key vector.)

Then, we pull the previously computed key vectors from the Key Cache and compute the last row of the attention scores matrix as the dot product of the new query vector with each key vector:

(Figure: the new query vector is combined with the Key Vectors pulled from the Key Cache to produce the last row of the Attention Scores.)
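
Here is a toy sketch of that step, in the same spirit as the earlier example: random embeddings and randomly initialized projection layers stand in for a real model. Only one new query vector and one new key vector are computed; the rest of the key vectors come from the cache.

import torch

torch.manual_seed(0)
embed_dim, head_dim = 8, 4

query_layer = torch.nn.Linear(embed_dim, head_dim, bias=False)
key_layer = torch.nn.Linear(embed_dim, head_dim, bias=False)

# key vectors for the first four tokens, computed (and cached) on earlier iterations
key_cache = key_layer(torch.randn(4, embed_dim))

# a new token arrives: compute a single query vector and a single key vector for it
new_embedding = torch.randn(1, embed_dim)
new_query = query_layer(new_embedding)
new_key = key_layer(new_embedding)

# append the new key vector to the Key Cache
key_cache = torch.cat([key_cache, new_key], dim=0)  # shape: (5, 4)

# the last row of the attention scores matrix: the new query vector
# dotted with every cached key vector
last_row = torch.softmax(new_query @ key_cache.T / head_dim**0.5, dim=-1)
print(last_row)  # shape: (1, 5)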

The Value Cache

So far, we have seen how caching key vectors can eliminate redundant attention score computation at each iteration of a text generation loop. The value vectors in multi-head self-attention can similarly be cached at each iteration.

Recall that the value vectors are computed by sending the input embeddings through a third linear layer (which we’ll unimaginatively call the Value Layer):

(Figure: the Input Embeddings are multiplied by the Value Layer weights to produce the Value Vectors.)

We then “re-weight” the value vectors based on the attention scores by multiplying the two matrices together:

(Figure: the Attention Scores are multiplied by the Value Vectors to produce the Reweighted Value Vectors.)

Just like with key vectors, only the last (i.e. the most recent) value vector needs to be computed at each iteration. All the other value vectors can be pulled from a Value Cache and reused:

(Figure: only the newest value vector is computed; the rest are pulled from the Value Cache before being re-weighted by the Attention Scores.)
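
Continuing the toy sketch from the Key Cache section, the value side looks like this: only the newest value vector is computed, the rest are pulled from the Value Cache, and the newest row of attention scores re-weights them.

import torch

torch.manual_seed(0)
embed_dim, head_dim, seq_len = 8, 4, 5

value_layer = torch.nn.Linear(embed_dim, head_dim, bias=False)

# value vectors for the first four tokens, pulled from the Value Cache
value_cache = value_layer(torch.randn(seq_len - 1, embed_dim))

# only the newest value vector needs to be computed
new_value = value_layer(torch.randn(1, embed_dim))
value_cache = torch.cat([value_cache, new_value], dim=0)  # shape: (5, 4)

# re-weight the cached value vectors with the newest row of attention scores
# (a random stand-in here; in practice it comes from the Key Cache step above)
last_attention_row = torch.softmax(torch.randn(1, seq_len), dim=-1)
new_attention_output = last_attention_row @ value_cache  # shape: (1, 4)
print(new_attention_output)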

The KV Cache, Computation, and Memory

The Key Cache and Value Cache are referred to collectively as the KV Cache, and offer such a significant reduction in computation that they are enabled by default in the transformers library.

The KV Cache can be disabled by setting use_cache=False when loading the model. The code below demonstrates the difference in computation time between the two settings. Even for the small 124M parameter GPT2 model, the KV Cache offers a ~6x speedup:

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_tokens(use_kv_cache):

    gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", use_cache=use_kv_cache)
    gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")

    with torch.no_grad():

        num_new_tokens = 500

        # tokenize the original input sentence
        inputs = gpt2_tokenizer("Hope is a", return_tensors="pt", add_special_tokens=False)

        start_time = time.time()
        gpt2.generate(**inputs, max_new_tokens=num_new_tokens, min_new_tokens=num_new_tokens)
        end_time = time.time()

        print(f"Time taken to generate {num_new_tokens} tokens: {end_time - start_time:.4f} seconds")
        print(f"Time taken per token: {(end_time - start_time)/num_new_tokens:.4f} seconds")


# measure latency with key-value caching disabled
print("Without key-value caching:")
generate_tokens(use_kv_cache=False)

# measure latency with key-value caching enabled
print("\nWith key-value caching:")
generate_tokens(use_kv_cache=True)
Without key-value caching:
Time taken to generate 500 tokens: 62.2500 seconds
Time taken per token: 0.1245 seconds

With key-value caching:
Time taken to generate 500 tokens: 10.3225 seconds
Time taken per token: 0.0206 seconds

We can examine the state of the KV Cache at each iteration of the text generation loop by printing the past_key_values attribute of the model outputs:

from transformers import AutoModelForCausalLM, AutoTokenizer

gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")

gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = gpt2_tokenizer("Hope is a waking dream", return_tensors="pt", add_special_tokens=False)
outputs = gpt2(inputs.input_ids, output_attentions=True)

# get the key and value caches of the first decoder layer
key_cache, value_cache = outputs.past_key_values[0]

# print the shape of the key and value caches
print("Key cache shape:", key_cache.shape)
print("Value cache shape:", value_cache.shape)
Key cache shape: torch.Size([1, 12, 5, 64])
Value cache shape: torch.Size([1, 12, 5, 64])

A separate Key Cache and Value Cache are maintained for each layer in the model, and the four dimensions of each cache correspond to the batch size, the number of attention heads, the number of tokens in the input sequence, and the size of the key and value vectors, respectively.

The total size in bytes of the KV Cache, therefore, is given by:

$$l \times b \times n \times h \times s \times 2 \times 2$$

Given the following:

$$\begin{aligned} l &: \text{number of layers} \\ b &: \text{batch size} \\ n &: \text{number of attention heads} \\ h &: \text{attention head size} \\ s &: \text{sequence length} \end{aligned}$$

where the final terms, \(2\) and \(2\), correspond to the number of caches per layer (i.e. one for keys, one for values) and the number of bytes per value (assuming the keys and values are stored as 16-bit floating point numbers), respectively.

For GPT2, this comes out to a modest ~36 MB assuming we use the max sequence length of 1024 tokens and a batch size of 1. For much larger models, however, the KV Cache can take up a massive amount of memory.
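
As a quick sanity check of that figure, GPT2 has 12 layers, 12 attention heads, and an attention head size of 64:

# 12 layers x batch size 1 x 12 heads x head size 64 x 1024 tokens
# x 2 caches per layer x 2 bytes per value
kv_cache_bytes = 12 * 1 * 12 * 64 * 1024 * 2 * 2
print(f"{kv_cache_bytes / 1024**2:.0f} MB")  # 36 MB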

For example, the table below lists the various versions of GPT-3 [4]; the size of the KV Cache for each depends on the chosen batch size and sequence length, and a rough calculation for a few of the variants is sketched after the table.

Model          Parameter Count    KV Cache Size
GPT-3 Small    125M               -
GPT-3 Medium   350M               -
GPT-3 Large    760M               -
GPT-3 XL       1.3B               -
GPT-3 2.7B     2.7B               -
GPT-3 6.7B     6.7B               -
GPT-3 13B      13B                -
GPT-3 175B     175B               -
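
As a rough sketch of how these sizes scale, the function below plugs the architecture hyperparameters reported in the GPT-3 paper (layer count, head count, and head size) into the formula above, assuming a batch size of 1, the full 2048-token context length, and 16-bit values:

def kv_cache_size_mha(num_layers, num_heads, head_size, seq_len, batch_size=1, bytes_per_value=2):
    """KV Cache size in bytes for standard Multi-Head Attention."""
    return num_layers * batch_size * num_heads * head_size * seq_len * 2 * bytes_per_value

# (layers, heads, head size) as reported in the GPT-3 paper
gpt3_configs = {
    "GPT-3 Small": (12, 12, 64),
    "GPT-3 13B": (40, 40, 128),
    "GPT-3 175B": (96, 96, 128),
}

for name, (num_layers, num_heads, head_size) in gpt3_configs.items():
    size = kv_cache_size_mha(num_layers, num_heads, head_size, seq_len=2048)
    print(f"{name}: {size / 1024**3:.2f} GiB")

For the 175B model this works out to roughly 9 GiB per 2048-token sequence, and the total grows linearly with both batch size and sequence length.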

As model sizes get larger and sequence lengths get longer, the amount of memory consumed by the KV Cache continues to grow. Multi-Query Attention [5] (MQA) and Grouped-Query Attention (GQA) are two techniques that combat this problem.

Multi-Query Attention (MQA)

In Multi-Head Attention, each attention head computes its own unique set of query, key, and value vectors. In Multi-Query Attention, only the query vectors are unique to each head, while the key and value vectors are shared across all heads.

To better illustrate this distinction, consider first the following naive implementation of standard Multi-Head Attention:

import torch
import torch.nn as nn

class MultiHeadAttentionScores(nn.Module):

    def __init__(self, hidden_size, num_attention_heads, attention_head_size):
        super(MultiHeadAttentionScores, self).__init__()
        self.num_attention_heads = num_attention_heads
        
        # Create a query, key, and value projection layer
        # for each attention head.
        self.query_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size) 
            for _ in range(num_attention_heads)
        ])
        
        self.key_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size) 
            for _ in range(num_attention_heads)
        ])
        
        self.value_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size) 
            for _ in range(num_attention_heads)
        ])

    def forward(self, hidden_states):
        # Create a list to store the outputs of each attention head
        all_attention_outputs = []

        for i in range(self.num_attention_heads):
            query_vectors = self.query_layers[i](hidden_states)
            key_vectors = self.key_layers[i](hidden_states)
            value_vectors = self.value_layers[i](hidden_states)
            
            attention_scores = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
            attention_outputs = torch.matmul(attention_scores, value_vectors)
            all_attention_outputs.append(attention_outputs)

        return all_attention_outputs

At initialization, we create separate Query, Key, and Value layers for each attention head; then, in the forward pass, we iterate over each attention head and compute the outputs for that head.

Contrast this with a similarly naive (but hopefully illustrative) implementation of Multi-Query Attention:

class MultiQueryAttention(nn.Module):

    def __init__(self, hidden_size, num_attention_heads, attention_head_size):
        super(MultiQueryAttention, self).__init__()
        self.num_attention_heads = num_attention_heads
        
        # Create a query layer for each attention head.
        self.query_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size) 
            for _ in range(num_attention_heads)
        ])
        
        # Create a single key layer and a single value layer
        # that will be shared by all attention heads.
        self.key_layer = nn.Linear(hidden_size, attention_head_size)
        self.value_layer = nn.Linear(hidden_size, attention_head_size)

    def forward(self, hidden_states):
        
        # Create a list to store the outputs of each attention head
        all_attention_outputs = []

        for i in range(self.num_attention_heads):
            query_vectors = self.query_layers[i](hidden_states)
            key_vectors = self.key_layer(hidden_states)
            value_vectors = self.value_layer(hidden_states)
            
            attention_scores = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
            attention_outputs = torch.matmul(attention_scores, value_vectors)
            all_attention_outputs.append(attention_outputs)

        return all_attention_outputs

Here, we create only one Key Layer and one Value Layer, which are shared by all attention heads. The key vectors and value vectors computed during the forward pass are therefore identical across every head.
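
As a quick check of the difference (assuming the two classes above are defined in the same script, and using GPT2-like dimensions purely for illustration), we can compare parameter counts and confirm that both modules still produce one output tensor per attention head:

import torch

hidden_size, num_heads, head_size = 768, 12, 64

mha = MultiHeadAttentionScores(hidden_size, num_heads, head_size)
mqa = MultiQueryAttention(hidden_size, num_heads, head_size)

def num_params(module):
    return sum(p.numel() for p in module.parameters())

print(f"MHA parameters: {num_params(mha):,}")  # 12 query + 12 key + 12 value layers
print(f"MQA parameters: {num_params(mqa):,}")  # 12 query layers + 1 key + 1 value layer

# both return a list with one output tensor per attention head
outputs = mqa(torch.randn(5, hidden_size))
print(len(outputs), outputs[0].shape)  # 12 torch.Size([5, 64])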

Because each head uses the same Key and Value Vectors, we can cut down the size of the KV Cache by a factor of \(n\), where \(n\) is the number of attention heads. The size of the KV Cache, when using Multi-Query Attention, is computed as:

$$l \times b \times s \times h \times 2 \times 2$$

Given:

$$\begin{aligned} l &: \text{number of layers} \\ b &: \text{batch size} \\ h &: \text{attention head size} \\ s &: \text{sequence length} \end{aligned}$$

Let’s look again at the size of the KV Cache for the various versions of GPT-3, this time with an additional column for the size of the KV Cache when using Multi-Query Attention:

Model          Parameter Count    KV Cache Size (MHA)    KV Cache Size (MQA)
GPT-3 Small    125M               -                      -
GPT-3 Medium   350M               -                      -
GPT-3 Large    760M               -                      -
GPT-3 XL       1.3B               -                      -
GPT-3 2.7B     2.7B               -                      -
GPT-3 6.7B     6.7B               -                      -
GPT-3 13B      13B                -                      -
GPT-3 175B     175B               -                      -
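
Reusing the kv_cache_size_mha helper from the earlier sketch (and the same assumptions: GPT-3 paper hyperparameters, batch size 1, 2048 tokens, 16-bit values), the MQA cache simply drops the head-count factor:

def kv_cache_size_mqa(num_layers, head_size, seq_len, batch_size=1, bytes_per_value=2):
    """KV Cache size in bytes when a single set of keys/values is shared by all heads."""
    return num_layers * batch_size * head_size * seq_len * 2 * bytes_per_value

# GPT-3 175B: 96 layers, 96 heads, head size 128
mha_size = kv_cache_size_mha(96, 96, 128, seq_len=2048)
mqa_size = kv_cache_size_mqa(96, 128, seq_len=2048)
print(f"MHA: {mha_size / 1024**3:.2f} GiB")  # ~9 GiB
print(f"MQA: {mqa_size / 1024**2:.0f} MiB")  # 96 MiB, a 96x reduction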

The Pitfalls of Multi-Query Attention

Multi-Query Attention offers a major reduction in KV Cache size and was even adopted in the PaLM [6] model from Google Research. This technique, however, has since been surpassed in popularity by the closely related Grouped-Query Attention (more on this in a moment), primarily for two reasons.

1. Model Performance Degradation

As with any reduction in the number of model parameters, the switch from standard Multi-Head Attention to Multi-Query Attention comes with some degradation in model performance.

Ainslie et al. [7] show the performance degradation of T5-XXL [8] on various summarization tasks when using Multi-Query Attention instead of standard Multi-Head Attention.

Even when the overall model parameter count is held constant, there is evidence that Multi-Query Attention still underperforms standard Multi-Head Attention.

In experiments with their Llama 2 model, Touvron et al. [9] compare MHA and MQA variants of the model, where the size of the feed-forward networks in the MQA variant is increased to compensate for the reduction in parameters.

They still observe lower accuracy on average across several benchmark datasets when using Multi-Query Attention.

2. Inefficient Parallelization

Many large transformer models require more memory than can fit on a single GPU, and so some form of model parallelism is required for training and inference. Tensor Parallelism [10] is the technique of choice for parallelizing large text generation models.

In brief, Tensor Parallelism works by assigning different attention heads to different GPUs. For example, the 175B GPT-3 model has 96 attention heads. We could parallelize this model across 8 GPUs by assigning 12 attention heads to each.

Each GPU performs the entire attention computation for its assigned heads, then all GPUs exchange the results of their computations.

With Multi-Query Attention, each attention head operates on the same set of Key Vectors and Value Vectors. This means that the same Key and Value Vectors end up being computed and cached redundantly on every GPU, wasting both computation and memory.

Grouped-Query Attention

Grouped-Query Attention [7] (GQA) is essentially a generalization that encompasses both standard Multi-Head Attention and Multi-Query Attention.

In Multi-Head Attention, the number of unique Key and Value vectors is equal to the number of attention heads; in Multi-Query Attention, the number of unique Key and Value vectors is equal to 1.

In Grouped-Query Attention, the number of unique Key and Value vectors is equal to a hyperparameter \(G\), the number of groups. For example, if the number of attention heads is 4 and \(G = 2\), then there are two unique sets of Key and Value vectors, each of which is shared by two attention heads.
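
To make this concrete, here is a similarly naive (and purely illustrative) sketch in the style of the Multi-Query Attention example above, with one Key Layer and one Value Layer per group:

import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):

    def __init__(self, hidden_size, num_attention_heads, attention_head_size, num_groups):
        super(GroupedQueryAttention, self).__init__()
        assert num_attention_heads % num_groups == 0
        self.num_attention_heads = num_attention_heads
        self.num_groups = num_groups

        # Create a query layer for each attention head, as before.
        self.query_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size)
            for _ in range(num_attention_heads)
        ])

        # Create one key layer and one value layer per group; each is shared
        # by (num_attention_heads / num_groups) attention heads.
        self.key_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size)
            for _ in range(num_groups)
        ])
        self.value_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size)
            for _ in range(num_groups)
        ])

    def forward(self, hidden_states):
        all_attention_outputs = []
        heads_per_group = self.num_attention_heads // self.num_groups

        for i in range(self.num_attention_heads):
            group = i // heads_per_group
            query_vectors = self.query_layers[i](hidden_states)
            key_vectors = self.key_layers[group](hidden_states)
            value_vectors = self.value_layers[group](hidden_states)

            attention_scores = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
            attention_outputs = torch.matmul(attention_scores, value_vectors)
            all_attention_outputs.append(attention_outputs)

        return all_attention_outputs

Note that setting num_groups equal to num_attention_heads recovers the standard Multi-Head Attention implementation above, while num_groups = 1 recovers Multi-Query Attention.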

Using GQA with 8 groups, Ainslie et al. [7] measure the performance of T5-XXL on the same summarization tasks as above, comparing the results to both Multi-Head Attention and Multi-Query Attention.

GQA offers a slight improvement in performance over Multi-Query Attention while still providing a significant reduction in KV Cache size. In fact, when operating in a multi-GPU environment with tensor parallelism, we can essentially get these gains for free by setting \(G\) equal to the number of GPUs, since each GPU then only needs to compute and cache the Key and Value Vectors for its own group.

The table below again lists the variants of GPT-3, this time with a column for GQA included:

Model          Parameter Count    KV Cache Size (MHA)    KV Cache Size (MQA)    KV Cache Size (GQA)
GPT-3 Small    125M               -                      -                      -
GPT-3 Medium   350M               -                      -                      -
GPT-3 Large    760M               -                      -                      -
GPT-3 XL       1.3B               -                      -                      -
GPT-3 2.7B     2.7B               -                      -                      -
GPT-3 6.7B     6.7B               -                      -                      -
GPT-3 13B      13B                -                      -                      -
GPT-3 175B     175B               -                      -                      -
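
Extending the earlier sketches once more, with \(G\) groups the head-count factor in the formula is replaced by the group count (same assumptions as before):

def kv_cache_size_gqa(num_layers, num_groups, head_size, seq_len, batch_size=1, bytes_per_value=2):
    """KV Cache size in bytes with G groups of shared keys/values."""
    return num_layers * batch_size * num_groups * head_size * seq_len * 2 * bytes_per_value

# GPT-3 175B with G = 8 groups: 96 layers, head size 128
gqa_size = kv_cache_size_gqa(96, 8, 128, seq_len=2048)
print(f"GQA (G=8): {gqa_size / 1024**2:.0f} MiB")  # 768 MiB: 1/12th of MHA, 8x MQA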

Wrapping Up

When running large text generation models, preserving GPU memory is crucial, as it allows us to process larger batches of data and/or longer sequences of text. One of the biggest memory hogs in transformer models is the Key-Value Cache, which stores the key and value vectors computed at each iteration of the text generation loop.

Multi-Query Attention and its more general form, Grouped-Query Attention, are two ways to reduce the size of the KV Cache with minimal impact on model performance. Since its introduction in 2023, GQA has been adopted in several popular open-source transformer models:

  • The Falcon [11] series of models uses Grouped-Query Attention with 1 group (i.e. Multi-Query Attention) for its 7B parameter model, and 8 groups for its 40B and 180B parameter models.

  • The Llama 2 [9] models use standard Multi-Head Attention in the 7B and 13B variants, and GQA with 8 groups in the 70B variant.

  • The Mistral [12] 7B model and Mixtral 8x7B [13] each use GQA with 8 groups.


References

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł. and Polosukhin, I., 2017. Attention is all you need. Advances in neural information processing systems, 30.

  2. Wolf, T., Debut, L., Sanh, V., Chaumond, J., Delangue, C., Moi, A., Cistac, P., Ma, C., Jernite, Y., Plu, J., Xu, C., Le Scao, T., Gugger, S., Drame, M., Lhoest, Q., & Rush, A. M. (2020). Transformers: State-of-the-Art Natural Language Processing [Conference paper]. 38–45. https://www.aclweb.org/anthology/2020.emnlp-demos.6

  3. Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., and Sutskever, I. 2019. Language Models are Unsupervised Multitask Learners.

  4. Brown, T.B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A. and Agarwal, S., 2020. Language models are few-shot learners. arXiv preprint arXiv:2005.14165.

  5. Shazeer, N., 2019. Fast transformer decoding: One write-head is all you need. arXiv preprint arXiv:1911.02150.

  6. Chowdhery, A., Narang, S., Devlin, J., Bosma, M., Mishra, G., Roberts, A., Barham, P., Chung, H.W., Sutton, C., Gehrmann, S. and Schuh, P., 2023. Palm: Scaling language modeling with pathways. Journal of Machine Learning Research, 24(240), pp.1-113.

  7. Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F. and Sanghai, S., 2023. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv preprint arXiv:2305.13245.

  8. Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W. and Liu, P.J., 2020. Exploring the limits of transfer learning with a unified text-to-text transformer. The Journal of Machine Learning Research, 21(1), pp.5485-5551.

  9. Touvron, H., Martin, L., Stone, K., Albert, P., Almahairi, A., Babaei, Y., Bashlykov, N., Batra, S., Bhargava, P., Bhosale, S. and Bikel, D., 2023. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288.

  10. Shoeybi, M., Patwary, M., Puri, R., LeGresley, P., Casper, J. and Catanzaro, B., 2019. Megatron-lm: Training multi-billion parameter language models using model parallelism. arXiv preprint arXiv:1909.08053.

  11. Almazrouei, E., Alobeidli, H., Alshamsi, A., Cappelli, A., Cojocaru, R., Debbah, M., Goffinet, É., Hesslow, D., Launay, J., Malartic, Q. and Mazzotta, D., 2023. The Falcon Series of Open Language Models. arXiv preprint arXiv:2311.16867.

  12. Jiang, A.Q., Sablayrolles, A., Mensch, A., Bamford, C., Chaplot, D.S., Casas, D.D.L., Bressand, F., Lengyel, G., Lample, G., Saulnier, L. and Lavaud, L.R., 2023. Mistral 7B. arXiv preprint arXiv:2310.06825.

  13. Mistral AI, 2023. Mixtral of experts: A high quality Sparse Mixture-of-Experts.