Attention in Transformers
Part II of An Interpretability Guide to Language Models
At the end of the previous part of this series, we trained a zero-layer transformer[1] which consisted of an embedding operation followed by an unembedding operation, and we saw that this type of model is severely limited by its inability to consider more than just pairs of tokens, or bigrams, in isolation.
That is, given some token \(t_{i}\), our zero-layer model could predict the next most likely token, \(t_{i+1}\); but any tokens that came before \(t_{i}\) could not be taken into account. This limitation caused the model to produce such nonsense as this:
'Once upon a big, "I want to play with'
Recall that internally the model represents each input token by a vector with \(d\) dimensions, so that when some text composed of \(N\) tokens is provided, the full input to the model can be represented by an \( N \times d \) matrix.
Conceptually, we can think of the model input as having two axes, the feature axis (running horizontally below), and the temporal axis (running vertically).
As we walk along the feature axis, we encounter more information about one particular token; if we walk along the temporal axis, we are stepping from one token to another. An effective language model should be able to incorporate information along both axes when predicting the next token. Our zero-layer model was only able to use the current token’s embedding, and thus lacked any way to “step back” along the temporal axis to access information about previous tokens.
The temporal axis is a particularly tricky beast, as it varies from one input to the next, depending on the number of tokens provided; yet it is exceedingly important in language modeling, as language naturally has many temporal dependencies (i.e. words depend on the words preceding them, which depend on the words preceding them, and so on).
The ability of neural networks to deal with these temporal dependencies is the subject of this page. Of course, based on the title, you can probably guess that transformers achieve this with an attention mechanism. But before we inevitably discuss attention, I think it is worth taking a few moments to examine how we got here, and to see some of the other ways that temporal dependencies have been incorporated into language models to extend their capabilities beyond the world of bigrams.
\(n\)-gram Models
Some of the notable early work applying neural networks to language modeling, by Bengio et al.[2], exemplifies perhaps the simplest and most intuitive way to break out of the world of bigrams – namely, by designing a model specifically to make its predictions based on trigrams, 4-grams, or, generally, \(n\)-grams.
A high-level picture of the model architecture is shown below:
There are two major differences between this and the zero-layer transformer models we saw before. First, after the input token IDs are mapped to their corresponding embedding vectors, these token embedding vectors are aggregated into n-gram embedding vectors by concatenating the token embeddings.
The figure below illustrates the mapping from token embeddings to bigram embeddings. For illustration purposes, each element of each vector is marked with the index of the token to which it corresponds.
The other difference is the addition of a linear layer applied to each of the n-gram embedding vectors, which allows information about each of the tokens in that n-gram to be exchanged. Thus, the vector that ultimately gets passed in to the unembedding operation no longer represents just a single token, but a series of n tokens.
An example PyTorch implementation of such an \(n\)-gram language model is shown below:
import torch


class NGramLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size: int, model_dimension: int, hidden_dimension: int, ngram_order: int, bos_token_id: int):
        super().__init__()
        self.model_dimension = model_dimension
        # An n-gram model predicts one token based on the previous n-1 tokens,
        # so the number of input tokens processed in each window is n-1.
        self.window_size = ngram_order - 1
        # The model uses a bos token to pad the input sequence.
        self.bos_token_id = bos_token_id
        self.embed = torch.nn.Embedding(vocab_size, model_dimension)
        self.H = torch.nn.Linear(self.window_size * model_dimension, hidden_dimension)
        self.unembed = torch.nn.Linear(hidden_dimension, vocab_size, bias=False)

    def forward(self, token_ids):
        batch_size, seq_length = token_ids.shape
        # Pad the input sequence so that the number of (n-1)-grams is equal to the number of tokens.
        padded_token_ids = torch.nn.functional.pad(token_ids, (self.window_size - 1, 0), value=self.bos_token_id)
        embeddings = self.embed(padded_token_ids)
        # Aggregate the token embeddings into (n-1)-gram embeddings by concatenating
        # the embeddings inside each sliding window.
        n_gram_embeddings = embeddings.unfold(1, self.window_size, 1).transpose(-1, -2).reshape(batch_size, seq_length, -1)
        hidden = torch.tanh(self.H(n_gram_embeddings))
        return self.unembed(hidden)
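To make the windowing step more concrete, here is a minimal sketch of how left-padding followed by `unfold` produces one fixed-size window per token. The token IDs, the padding ID of 0, and the names `toy_ids` and `windows` are made up for illustration and are not taken from the trained model.

```python
import torch

# Toy batch containing one sequence of four made-up token IDs.
toy_ids = torch.tensor([[11, 12, 13, 14]])
window_size = 3   # a 4-gram model looks at 3 input tokens per prediction
bos_token_id = 0  # hypothetical padding ID, chosen only for this example

# Left-pad so that every position has a full window ending at that position.
padded = torch.nn.functional.pad(toy_ids, (window_size - 1, 0), value=bos_token_id)

# Slide a window of length `window_size` along the sequence axis with step 1.
windows = padded.unfold(1, window_size, 1)

print(windows.shape)  # torch.Size([1, 4, 3])
print(windows[0])
```

Each row of `windows[0]` is the context available when predicting the token that follows its last element, so the first row is mostly padding and the last row is `[12, 13, 14]`. In the model above the same operation is applied to embedding vectors rather than raw IDs.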
If we train a model of this type, we can see that, depending on the value of \(n\) provided, the model’s output is a bit more coherent than the zero-layer transformer’s output. For example, the following output came from a 5-gram model given the input text "Once":
'Once upon a time, there was a little girl named Lily. She loved to play'
These types of models, however, are limited by the small, fixed context window which they are able to observe (The authors of the original paper only go as far as \(n=3\)).
Recurrent Neural Nets
Later efforts to extend the context window of neural network language models saw success using Recurrent Neural Networks (RNNs). While there are many variations on the theme of RNNs, one of the most basic implementations is what came to be called the “Elman Network”, introduced by Elman[3]. Here again, we have the same familiar notion of embedding and unembedding operations at either end of the model (although Elman uses the terms Input and Output), and a hidden layer in the middle similar to the \(n\)-gram model from before.
Rather than concatenating the token embedding vectors into \(n\)-gram embedding vectors, RNNs solve the problem of temporal dependencies by processing each token one at a time, starting with the first token in the sequence. The output of the hidden layer for each token (which is shown below as \(h_{i}\), or the hidden state of token \(i\)) is passed as input when processing the next token, and so on, so that the input to the hidden layer is always the current token embedding, \(t_{i}\), as well as the previous token’s hidden state, \(h_{i-1}\).
Thus, an RNN can be thought of (and implemented as) a for loop over the tokens in the input sequence, updating the hidden_state at each iteration:
import torch


class RNNLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size: int, model_dimension: int, hidden_dimension: int):
        super().__init__()
        self.model_dimension = model_dimension
        self.hidden_dimension = hidden_dimension
        self.embed = torch.nn.Embedding(vocab_size, model_dimension)
        self.hidden_state_layer = torch.nn.Linear(model_dimension, model_dimension)
        self.unembed = torch.nn.Linear(model_dimension, vocab_size, bias=False)

    def forward(self, token_ids):
        embeddings = self.embed(token_ids)
        sequence_length = embeddings.shape[1]
        outputs = []
        # The initial hidden state is all zeros; its leading dimension of 1 broadcasts across the batch.
        hidden_state = torch.zeros(1, self.model_dimension, device=embeddings.device, dtype=embeddings.dtype)
        for i in range(sequence_length):
            # Combine the current token embedding with the previous hidden state.
            hidden_state = torch.tanh(embeddings[:, i] + self.hidden_state_layer(hidden_state))
            outputs.append(self.unembed(hidden_state))
        return torch.stack(outputs, dim=1)
The ability of an RNN to handle any length of input sequence without slicing it up into isolated \(n\)-grams is great, but these models come with their own baggage: they are difficult to train efficiently due to their inherently sequential nature, and they struggle to capture long-term dependencies, as the state of the entire input sequence has to be crammed into a single hidden state vector, \(h_{i}\).
Of course, there are several ways of augmenting RNNs to improve their ability to model long sequences of text (for example, by using LSTMs in place of the Elman-style RNN described above[4]), but our focus here is on an altogether different mechanism for capturing temporal dependencies.
Attention in RNNs
Attention mechanisms and RNNs are not mutually exclusive. In fact, the first attention mechanisms applied to natural language tasks were used in tandem with RNN models[5] in the domain of Neural Machine Translation (NMT) (that is, the use of neural networks to translate text from one language to another).
These models were in vogue in 2014, as significant progress was being made over earlier statistical methods. NMT models in this era typically implemented an encoder-decoder architecture, where both the encoder and the decoder were recurrent neural nets. The encoder RNN takes the sentence in the source language and encodes it into a fixed-length context vector; then the decoder takes that context vector and attempts to map it to the equivalent text in the target language.
A major bottleneck in this approach is the context vector, which is meant to encode information about the entire input sequence into a fixed number of dimensions. As the input sequence grows larger, this becomes increasingly difficult.
A better option (proposed by [5]) would be to use a unique context vector for each output token, which contains only the information necessary for producing that particular token. But how to create such a targeted context vector? We cannot simply do this based on position; for example, we can’t just use the hidden state of input token \( t_i \) as the context vector for computing the output token \(y_i\), as word order is not necessarily preserved between languages.
The authors propose to do this by training (alongside the encoder and decoder) an alignment model (in this case, a single-layer feedforward neural network), whose job is to compute an alignment score, \(e_{ij}\), between the token currently being decoded (given by the decoder RNN hidden state \(s_{i-1}\)), and an encoded input token (given by an encoder RNN hidden state \(h_j\)).
Once we have computed the alignment scores between the token being decoded and each of the input tokens, we apply a softmax function so that the resulting weights sum to 1, and then compute the unique context vector, \(c_i\), for the output token at index \(i\) as a weighted sum of the encoder hidden states, with each hidden state weighted by its normalized alignment score.
The resulting token-specific context vector, then, will be some linear combination of the encoder hidden states; and, importantly, it will be most similar to those hidden states which have the highest alignment scores. It is this ability of the decoder to dynamically “choose” to incorporate (or to discard) information from the various encoder hidden states which we refer to as an attention mechanism.
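The following is a minimal sketch of this computation for a single decoding step, assuming the additive alignment model form \(e_{ij} = v_a^{T}\tanh(W_a s_{i-1} + U_a h_j)\) used by [5]. All dimensions and weights are random stand-ins chosen for illustration, and the variable names are not taken from any particular implementation.

```python
import torch

torch.manual_seed(0)
num_inputs, enc_dim, dec_dim, align_dim = 6, 8, 8, 16

h = torch.randn(num_inputs, enc_dim)  # encoder hidden states h_1 .. h_N
s_prev = torch.randn(dec_dim)         # decoder hidden state s_{i-1}

# Alignment model parameters (random stand-ins for learned weights).
W_a = torch.randn(align_dim, dec_dim)
U_a = torch.randn(align_dim, enc_dim)
v_a = torch.randn(align_dim)

# Alignment scores: one score per encoder hidden state.
e = torch.tanh(s_prev @ W_a.T + h @ U_a.T) @ v_a  # shape: (num_inputs,)

# Softmax turns the scores into weights that sum to 1.
alpha = torch.softmax(e, dim=0)

# Context vector: weighted sum of the encoder hidden states.
c = alpha @ h                                     # shape: (enc_dim,)

print(alpha.sum().item(), c.shape)
```

Because `alpha` is recomputed from the current decoder state at every step, each output token gets its own context vector, dominated by whichever encoder states score highest.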
This kind of architecture – encoder-decoder RNNs with attention – was quite successful and would ultimately come to power Google Translate[6] for several years before the advent and adoption of transformers.
Attention in Transformers
The transformer architecture[7] came about as an extension (or perhaps a simplification) of the idea of encoder-decoder RNNs with attention. Specifically, the original transformer employs the same encoder-decoder structure as above, but removes the recurrent nature of both encoder and decoder, and instead relies entirely on attention mechanisms to move information along the temporal axis.
The Attention operation as it exists in transformer models can be considered as consisting of two stages:
- Attention Pattern Computation - This is akin to the computation of alignment scores in the previous example.
- Attention Output Computation - This is akin to computation of the context vector and the use of that context vector in producing the next token.
Attention Pattern Computation
This first stage is typically described as follows:
- Map each input vector \(x_{i}\) to a query vector \(q_{i}\) using the learned Query Projection weight \(W_{Q}\).
- Map each input vector \(x_{i}\) to a key vector \(k_{i}\) using the learned Key Projection weight \(W_{K}\).
- Compute the attention pattern, \(A\), by taking the dot product of every query vector with every key vector and then applying a softmax function.
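Mathematically, these three steps can be expressed like this, where \(x\) is the matrix whose rows are the input vectors \(x_{i}\) and the softmax is applied to each row:
\[ q_{i} = x_{i}W_{Q}, \qquad k_{i} = x_{i}W_{K}, \qquad A = \mathrm{softmax}\left( (xW_{Q})(xW_{K})^{T} \right) \]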
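As a minimal sketch, the three steps listed above might look like this in PyTorch for a single attention head. The dimensions and random weights are placeholders for illustration; scaling and masking are left out, since the description above does not include them.

```python
import torch

torch.manual_seed(0)
N, d_model, d_head = 5, 16, 4

x = torch.randn(N, d_model)         # one row per input token
W_Q = torch.randn(d_model, d_head)  # stand-in for the learned query projection
W_K = torch.randn(d_model, d_head)  # stand-in for the learned key projection

q = x @ W_Q                         # step 1: query vectors, shape (N, d_head)
k = x @ W_K                         # step 2: key vectors, shape (N, d_head)

scores = q @ k.T                    # dot product of every query with every key
A = torch.softmax(scores, dim=-1)   # step 3: softmax over each row

print(A.shape, A.sum(dim=-1))
```

Row \(i\) of `A` holds the weights that token \(i\) assigns to every token in the sequence, which is why each row sums to 1.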
A brief note on this idea of query and key vectors, which has become ubiquitous in descriptions of attention: at first glance, this might seem quite different from how attention worked in the NMT encoder-decoder model from earlier, where an alignment model was used to compute a score between an encoder hidden state and a decoder hidden state. But these are simply two different ways of talking about the same thing; we could just as easily say that attention in transformers relies on an alignment model defined as:
\[ e_{ij} = (x_{i}W_{Q})(x_{j}W_{K})^{T} \]
Note that this is not all that different from the alignment model used by [5], which was defined as:
\[ e_{ij} = v_{a}^{T}\tanh\left( W_{a}s_{i-1} + U_{a}h_{j} \right) \]
where \(v_a\), \(W_a\), and \(U_a\) are the learned weights of the alignment model.
One difference worth noting between these two “alignment models” is that, in transformers, both parameters, \(x_i\) and \(x_j\), come from the same sequence of input vectors, whereas in the NMT RNN model, one parameter comes from the decoder, and the other from the encoder. This is why you will often hear the attention mechanism in transformers referred to as “self-attention”, as it computes scores between different positions in the same sequence.
Returning now to the formula for attention pattern computation: the resulting attention pattern matrix, \(A\), is an \( N \times N \) matrix, where each value \( a_{i,j} \) represents the degree to which the \(i^{th}\) token in the sequence “attends” to the \(j^{th}\) token in the sequence (if this notion of one token “attending” to another feels a bit too magical, don’t worry, we will remedy this shortly).
We can rearrange the above formula for attention pattern computation to make things a little simpler to interpret. Specifically, we can eliminate the need for separate query and key projection weights by combining them into a single weight matrix which we will refer to as \(W_{QK}\):
\[ A = \mathrm{softmax}\left( xW_{QK}x^{T} \right), \qquad W_{QK} = W_{Q}W_{K}^{T} \]
Attention Output Computation
The second stage of the attention operation can be described as follows:
- Map each input vector \(x_{i}\) to a value vector \(v_{i}\) using the learned Value Projection weight \(W_{V}\).
- Multiply the value vectors by the previously computed attention pattern, \(A\).
- Multiply the result of step 2 by the learned Output Projection weight \(W_{O}\).
Taking these three steps together, we can define the attention output computation as:
\[ Z = AxW_{V}W_{O} \]
Here again the computation of the attention output, \(Z\), can be simplified by defining a single learned weight matrix \( W_{VO} = W_{V}W_{O} \).
The entire attention operation can thus be described by the following two equations, each parameterized by a single matrix of weights:
\[ A = \mathrm{softmax}\left( xW_{QK}x^{T} \right) \]
\[ Z = AxW_{VO} \]
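The sketch below checks numerically that the combined-matrix form gives the same result as the four separate projections. It assumes the row-vector convention \(W_{QK} = W_{Q}W_{K}^{T}\) and \(W_{VO} = W_{V}W_{O}\), uses random placeholder weights for a single head, and omits scaling and masking.

```python
import torch

torch.manual_seed(0)
N, d_model, d_head = 5, 16, 4
dtype = torch.float64  # double precision keeps the comparison tight

x = torch.randn(N, d_model, dtype=dtype)
W_Q = torch.randn(d_model, d_head, dtype=dtype)
W_K = torch.randn(d_model, d_head, dtype=dtype)
W_V = torch.randn(d_model, d_head, dtype=dtype)
W_O = torch.randn(d_head, d_model, dtype=dtype)

# Four-matrix form: separate query, key, value and output projections.
A_separate = torch.softmax((x @ W_Q) @ (x @ W_K).T, dim=-1)
Z_separate = A_separate @ (x @ W_V) @ W_O

# Two-matrix form: one combined matrix per stage.
W_QK = W_Q @ W_K.T  # shape (d_model, d_model)
W_VO = W_V @ W_O    # shape (d_model, d_model)
A = torch.softmax(x @ W_QK @ x.T, dim=-1)
Z = A @ x @ W_VO

print(torch.allclose(A, A_separate), torch.allclose(Z, Z_separate))
```

Although `W_QK` and `W_VO` are `d_model x d_model` matrices, each is a product of two thin matrices, so its rank is at most `d_head`.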
Transformer models almost all employ some form of multi-head attention, meaning that each attention layer in the model consists of several independent attention operations running in parallel. This fact does not materially change the description of the two stages of attention above; we just need to recall that each attention head has its own independent copy of \(W_{QK}\) and \( W_{VO} \).
To account for this, we will use the following notation:
Attention-Only Transformers & The Residual Stream
The output of the attention operation, \(Z\), is a matrix of shape \(N \times d_{model} \); that is, the same shape as the input to the operation. This allows for residual connections to be introduced around each attention operation in the model.
For example, consider a one-layer attention-only transformer model (the next step up in complexity from the zero-layer model we saw in the previous article). This model consists of an embedding operation, a single attention layer, and an unembedding operation.
Importantly, because of the residual connection, the input to the unembedding operation is the sum of the embeddings and the output of each attention head:
This use of residual connections around each successive attention operation is of particular interest in interpretability research, and has given rise to the concept of the residual stream.
The term residual stream refers specifically to the vectors which serve as input to each attention operation in the model, and to which each operation’s output is added. Each token, then, has its own residual stream, which is at first initialized to the embedding corresponding to that token, but which is modified by each successive operation.
Below is the basic forward pass of an attention-only transformer which illustrates how the residual stream fits in:
def forward(self, input_ids):
    input_embeds = self.embed(input_ids)
    # initialize the residual stream to the input embeddings
    residual_stream = input_embeds
    for attention in self.attention_layers:
        # execute the attention operation, passing the residual stream as input
        attention_output = attention(residual_stream)
        # "write" the output of the attention operation back to the residual stream
        residual_stream = attention_output + residual_stream
    return self.unembed(residual_stream)
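For readers who want something they can run end to end, here is a self-contained sketch of a small attention-only model built around the same forward pass. The class names `ToyAttentionLayer` and `ToyAttentionOnlyTransformer` are hypothetical, the weights are randomly initialized, and details such as positional information, scaling, and biases are deliberately left out, so this should not be read as the exact model used in this series.

```python
import torch


class ToyAttentionLayer(torch.nn.Module):
    """Multi-head causal self-attention without biases (illustrative only)."""

    def __init__(self, model_dimension: int, num_heads: int):
        super().__init__()
        head_dimension = model_dimension // num_heads
        scale = model_dimension ** -0.5
        # One set of projection weights per head, indexed as W_Q[head_idx], etc.
        self.W_Q = torch.nn.Parameter(scale * torch.randn(num_heads, model_dimension, head_dimension))
        self.W_K = torch.nn.Parameter(scale * torch.randn(num_heads, model_dimension, head_dimension))
        self.W_V = torch.nn.Parameter(scale * torch.randn(num_heads, model_dimension, head_dimension))
        self.W_O = torch.nn.Parameter(scale * torch.randn(num_heads, head_dimension, model_dimension))

    def forward(self, x):
        # x has shape (batch, sequence, model_dimension)
        seq_length = x.shape[1]
        q = torch.einsum("bnd,hde->bhne", x, self.W_Q)
        k = torch.einsum("bnd,hde->bhne", x, self.W_K)
        v = torch.einsum("bnd,hde->bhne", x, self.W_V)
        scores = q @ k.transpose(-1, -2)  # (batch, heads, sequence, sequence)
        # Hide later positions so each token can only read from itself and earlier tokens.
        future = torch.triu(torch.ones(seq_length, seq_length, device=x.device), diagonal=1).bool()
        pattern = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)
        # Project each head's output back to model_dimension and sum over heads.
        return torch.einsum("bhne,hed->bnd", pattern @ v, self.W_O)


class ToyAttentionOnlyTransformer(torch.nn.Module):
    def __init__(self, vocab_size: int, model_dimension: int, num_layers: int, num_heads: int):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, model_dimension)
        self.attention_layers = torch.nn.ModuleList(
            [ToyAttentionLayer(model_dimension, num_heads) for _ in range(num_layers)]
        )
        self.unembed = torch.nn.Linear(model_dimension, vocab_size, bias=False)

    def forward(self, input_ids):
        # Initialize the residual stream to the input embeddings.
        residual_stream = self.embed(input_ids)
        for attention in self.attention_layers:
            # Read from the residual stream, then write the result back by addition.
            residual_stream = residual_stream + attention(residual_stream)
        return self.unembed(residual_stream)


toy_model = ToyAttentionOnlyTransformer(vocab_size=100, model_dimension=32, num_layers=1, num_heads=4)
toy_logits = toy_model(torch.randint(0, 100, (2, 7)))
print(toy_logits.shape)  # torch.Size([2, 7, 100])
```

Summing the per-head outputs inside the layer means the residual stream ends up holding the embeddings plus the contribution of every attention head, which is the quantity passed to the unembedding.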
Note that the terms embeddings or hidden states are often used elsewhere to refer to the same thing. Here, we will use the term residual stream, which highlights its function as an information channel, from which each operation “reads” its input, and to which the result of each operation is “written”. In this sense, we can think of the residual stream as a sort of buffer or communication channel that allows operations in early layers to pass information to be used by later layers.
Attention Interpretability
With the terminology and background on attention out of the way, we can start to explore what function attention actually serves in a trained language model. At a high level, the attention operation moves information from the residual stream of one token to the residual stream of another; or, in other words, it copies information from one place to another along the temporal axis.
Analyzing Attention Patterns
Any copy operation needs both a source and a destination location. The attention pattern matrix, \(A\), can be viewed as specifying these parameters. Some example attention pattern matrices computed by GPT2 are shown below (these were computed by providing the text "Beauty will save the world"):
Side Note: You may notice that these attention pattern matrices contain only 0s above the main diagonal. This is not by chance. In text generation models, the attention pattern is typically masked to zero out all values above the main diagonal. This is done to ensure that information can only be copied from earlier to later tokens in the input sequence, and not vice versa. This is the difference between so-called “decoder-only” transformers, and “encoder-only” transformers, where the attention operation can copy information bidirectionally.
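One common way to implement this masking, sketched below with random scores, is to set the scores above the main diagonal to negative infinity before the softmax rather than zeroing the pattern afterwards. That way the masked positions receive exactly zero weight and each row still sums to 1. The variable names here are illustrative.

```python
import torch

torch.manual_seed(0)
N = 5
scores = torch.randn(N, N)  # stand-in for pre-softmax attention scores

# True above the main diagonal: the positions a token must not attend to.
future = torch.triu(torch.ones(N, N), diagonal=1).bool()

# Masked scores become -inf, which the softmax maps to a weight of exactly 0.
masked_scores = scores.masked_fill(future, float("-inf"))
A = torch.softmax(masked_scores, dim=-1)

print(A)
```

The first row always comes out as a 1 followed by 0s, because the first token has nothing earlier than itself to attend to.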
The attention pattern matrix contains a row and a column for each token in the input, and values ranging from 0.0 to 1.0 (due to the softmax function). We can think of each element \(a_{i,j}\) in this matrix as specifying “how much” to copy from token \(j\) (the source) to token \(i\) (the destination).
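For example, imagine an extreme scenario where the attention pattern contains 1s in the first column and 0s everywhere else, like this:
\[ A = \begin{bmatrix} 1 & 0 & \cdots & 0 \\ 1 & 0 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 1 & 0 & \cdots & 0 \end{bmatrix} \]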
When we go to compute the attention output using the equation \(Z = AxW_{VO}\) the first thing we do is multiply the attention pattern \(A\) and the input \(x\), which contains the residual streams of each token. This multiplication results in the first token’s residual stream being “copied” to every other position. This is demonstrated in the Python code below:
import torch
from transformers import AutoModel, AutoTokenizer
# load gpt2 from huggingface
model = AutoModel.from_pretrained("openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
# tokenize the input text
text = " Beauty will save the world"
input_ids = tokenizer.encode(text, return_tensors="pt")
# get the token embeddings
embeddings = model.wte(input_ids)
# create an attention pattern with 1s in the first column
fake_attention_pattern = torch.zeros(5, 5)
fake_attention_pattern[:, 0] = 1.0
print("Attention pattern (A):")
print(fake_attention_pattern)
print("\nResidual streams (x):")
print(embeddings)
print("\nMultiplication result (Ax):")
print(fake_attention_pattern @ embeddings)
Attention pattern (A):
tensor([[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.]])
Residual streams (x):
tensor([[[ 0.0762, -0.0892, -0.0347, ..., 0.0968, 0.0225, 0.0920],
[ 0.0641, 0.0195, 0.0691, ..., -0.0154, -0.1325, -0.1044],
[-0.0497, 0.0060, 0.1576, ..., -0.0856, 0.0492, -0.0557],
[-0.0393, 0.0050, 0.0421, ..., -0.0477, 0.0670, -0.0471],
[-0.1488, 0.1519, 0.0056, ..., -0.3107, 0.2073, 0.0377]]])
Multiplication result (Ax):
tensor([[[ 0.0762, -0.0892, -0.0347, ..., 0.0968, 0.0225, 0.0920],
[ 0.0762, -0.0892, -0.0347, ..., 0.0968, 0.0225, 0.0920],
[ 0.0762, -0.0892, -0.0347, ..., 0.0968, 0.0225, 0.0920],
[ 0.0762, -0.0892, -0.0347, ..., 0.0968, 0.0225, 0.0920],
[ 0.0762, -0.0892, -0.0347, ..., 0.0968, 0.0225, 0.0920]]])
Of course, in most cases the attention pattern will not contain just 1s and 0s, but still I find it is helpful to think of \(A\) as describing a copy operation, even if it can be a quite messy copy, where information is partially copied from multiple different residual streams.
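A tiny hand-written example may help picture such a "messy" copy. The numbers below are invented for illustration: each row of the result is a blend of the rows of `x`, mixed according to the weights in the corresponding row of `A`.

```python
import torch

# Three made-up residual streams, one per row.
x = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [4.0, 4.0]])

# A causal attention pattern with fractional weights (each row sums to 1).
A = torch.tensor([[1.00, 0.00, 0.00],
                  [0.50, 0.50, 0.00],
                  [0.25, 0.25, 0.50]])

mixed = A @ x
print(mixed)
# tensor([[1.0000, 0.0000],
#         [0.5000, 0.5000],
#         [2.2500, 2.2500]])
```

The second row is an even blend of the first two streams, and the third row is dominated by the third stream because it carries the largest weight.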
Attention patterns can be quite visually appealing, and a lot of interesting research has been done to analyze and interpret real attention patterns generated by trained transformer models. For example, [8] examined the attention patterns generated by BERT[9] and observed 5 distinct “classes” of attention pattern, examples of which are illustrated below (note that BERT is an encoder-only model, so non-zero values above the main diagonal are allowed):
Tools like the open-source bertviz[10] provide additional ways to visualize attention patterns in transformer models. Despite their visual appeal, however, it is up for debate how much we can truly glean from analyzing attention patterns alone. To my knowledge, there has not been any similar identification of repeating attention pattern “classes” in larger, more recently trained models; and subsequent research[11] has suggested that attention patterns do not provide meaningful explanations of end-to-end model behavior.
Circuits in Toy Models - \( W_{QK} \) and \( W_{VO} \)
If the attention pattern, \(A\), can be thought of as describing a copy operation, how should we think about the two learned weight matrices involved in self-attention, \( W_{QK} \) and \( W_{VO} \)? In this section, we will examine these weight matrices in a one-layer attention-only language model.
The first weight matrix, \(W_{QK}\), is involved in the computation of the attention pattern through the equation \(A = \mathrm{softmax}(xW_{QK}x^{T})\).
Imagine passing into our one-layer model a huge sequence containing one of every possible token in the model’s vocabulary. The input matrix \(x\), then, contains an embedding vector for every possible token; in other words, \(x\) is equivalent to the entire embedding layer weight matrix, \(W_{E}\). The attention pattern for this huge imaginary sequence would be the \(V \times V\) matrix given by:
\[ A = \mathrm{softmax}\left( W_{E}W_{QK}W_{E}^{T} \right) \]
This attention pattern is particularly interesting because it contains an attention score between 0 and 1 for every possible pair of tokens in the vocabulary. This is what [1] terms the Query-Key Circuit, which governs the degree to which each destination token will copy information from every other source token.
The second weight matrix, \(W_{VO}\), is involved in the computation of the attention output through \(Z = AxW_{VO}\).
Recall that in a one-layer attention-only transformer, after the attention output is computed, it is added to the residual stream, and then unembedding is performed to convert the residual stream to logits.
If we consider again the imaginary sequence containing every possible token, the entire flow from tokens to logits can be defined by
\[ \mathrm{logits} = \left( W_{E} + AW_{E}W_{VO} \right)W_{U} \]
By distributing the multiplication by \(W_{U}\), we can see that the logits in a one-layer attention-only model are, perhaps unsurprisingly, computed as the sum of two separate terms:
\[ \mathrm{logits} = W_{E}W_{U} + AW_{E}W_{VO}W_{U} \]
We saw this first term, \( W_{E}W_{U} \), already when looking at the zero-layer transformer, which was capable only of predicting the next token based on bigram statistics. The second term, \( AW_{E}W_{VO}W_{U} \), can be viewed as the matrix product of two different \(V \times V\) matrices, \(A\) and \(W_{E}W_{VO}W_{U}\).
We already know that \(A\) defines the amount of information to copy from each source token to each destination token. The second \(V \times V\) matrix, \( W_{E}W_{VO}W_{U} \), which [1] refers to as the Output-Value Circuit, thus governs the attention pattern’s overall effect on the logits. That is, the Output-Value Circuit answers questions such as “If I copy from the residual stream of some source token \(t_{src}\) to some destination token \(t_{dst}\), how will this affect the logits for some third token \(t_{out}\)?”
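The split of the logits into a direct term and an attention term can be checked numerically. The sketch below uses a made-up vocabulary of 10 tokens, a single head, and random stand-ins for \(W_E\), \(W_{VO}\), \(W_U\), and \(A\); none of these values come from a trained model.

```python
import torch

torch.manual_seed(0)
V, d_model = 10, 6
dtype = torch.float64

W_E = torch.randn(V, d_model, dtype=dtype)         # embedding matrix
W_VO = torch.randn(d_model, d_model, dtype=dtype)  # combined value-output matrix
W_U = torch.randn(d_model, V, dtype=dtype)         # unembedding matrix
A = torch.softmax(torch.randn(V, V, dtype=dtype), dim=-1)  # stand-in attention pattern

# Full path: residual stream (embeddings plus attention output), then unembedding.
logits = (W_E + A @ W_E @ W_VO) @ W_U

# The same logits, split into the two terms.
direct_path = W_E @ W_U          # bigram-like term from the zero-layer model
OV_circuit = W_E @ W_VO @ W_U    # V x V Output-Value Circuit
attention_path = A @ OV_circuit  # how the attention pattern mixes rows of the OV circuit

print(torch.allclose(logits, direct_path + attention_path))
```

Reading the last line of the computation row by row: the attention contribution to the logits at a destination token is a weighted mix of the OV-circuit rows of the source tokens it attends to.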
Query-Key and Output-Value Circuits in Action
All these equations can be a bit difficult to wrap one’s head around, so let’s see an example of the QK and OV Circuits in action. For these examples, we will use a toy model with one attention layer and four attention heads.
First we will load the model and tokenizer:
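import safetensors.torch
from transformers import AutoTokenizer

# AttentionOnlyTransformer is the toy model class used in this series;
# it must be defined or imported before this point.
model = AttentionOnlyTransformer(vocab_size=50257, model_dimension=128, num_layers=1, num_heads=4, max_length=128)
model.load_state_dict(safetensors.torch.load_file("model.safetensors"))
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")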
Then we will extract all of the relevant weight matrices involved in the QK and OV Circuits:
W_E = model.embed.weight
W_U = model.unembed.weight.T
# W_QK and W_VO weights are unique to each attention head
W_QK = []
W_VO = []
QK_circuit = []
OV_circuit = []
for head_idx in range(4):
    # Get the query and key weights for the current attention head
    W_Q = model.layers[0].W_Q[head_idx]
    W_K = model.layers[0].W_K[head_idx]
    # Get the value and output weights for the current attention head
    W_V = model.layers[0].W_V[head_idx]
    W_O = model.layers[0].W_O[head_idx]
    # Compute the QK and OV circuits for the current attention head
    W_QK.append(W_Q @ W_K.T)
    QK_circuit.append(W_E @ W_QK[head_idx] @ W_E.T)
    W_VO.append(W_V @ W_O)
    OV_circuit.append(W_E @ W_VO[head_idx] @ W_U)
print("QK circuit:", QK_circuit[0].shape)
print("OV circuit:", OV_circuit[0].shape)
QK circuit: torch.Size([50257, 50257])
OV circuit: torch.Size([50257, 50257])
This toy model has a vocabulary size of 50,257, thus each “circuit” matrix is of shape \( 50257 \times 50257 \). But what do these matrices mean exactly?
Let’s first consider a single entry in the QK-Circuit matrix, corresponding to the tokens "world" and "beauty":
print("beauty:", tokenizer.encode(" beauty"))
print("world:", tokenizer.encode(" world"))
print("QK[world, beauty]:", QK_circuit[0][995, 8737].item())
beauty: [8737]
world: [995]
QK[world, beauty]: 13.76384162902832
If we then compute the attention pattern (prior to applying softmax) for a sequence that contains both of these tokens, we will see that the attention score between the token "world" (referred to as the query or destination token) and the token "beauty" (referred to as the key or source token) is the same as the entry in the QK-Circuit matrix (see the bottom-leftmost attention score in the output below):
def get_attention_scores(text, head_idx=0):
    input_ids = tokenizer.encode(text, return_tensors="pt")[0]
    x = model.embed(input_ids)
    q = x @ model.layers[0].W_Q[head_idx]
    k = x @ model.layers[0].W_K[head_idx]
    attention_scores = q @ k.T
    return attention_scores
print(get_attention_scores(" beauty will save the world"))
tensor([[-11.9188, -4.7118, -5.0762, -23.4143, -9.2328],
[ -3.5618, 5.2131, 5.5276, -7.5997, 2.7818],
[ 4.9187, 9.5050, 1.9835, 23.4580, -2.7324],
[-11.2116, 13.7217, -15.7884, -1.0675, -10.5773],
[ -6.0198, -2.1840, -6.1825, -26.1244, 7.7341]])
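The relationship between the QK-Circuit matrix and the scores for a particular sequence can also be checked on a small scale. The sketch below uses a made-up vocabulary of 12 tokens and random weights, and it assumes that the attention input is exactly the token embedding, with no positional information added and no softmax applied. If a model adds anything else to the residual stream before attention, the looked-up values and the computed scores will generally differ.

```python
import torch

torch.manual_seed(0)
V, d_model, d_head = 12, 8, 4
dtype = torch.float64

W_E = torch.randn(V, d_model, dtype=dtype)
W_Q = torch.randn(d_model, d_head, dtype=dtype)
W_K = torch.randn(d_model, d_head, dtype=dtype)

# V x V table of pre-softmax scores for every (destination, source) token pair.
QK_circuit = W_E @ (W_Q @ W_K.T) @ W_E.T

# Pre-softmax attention scores for one particular sequence of made-up token IDs.
token_ids = torch.tensor([3, 7, 7, 1, 9])
x = W_E[token_ids]
scores = (x @ W_Q) @ (x @ W_K).T

# Looking up the same token pairs in the circuit matrix gives the same scores.
lookup = QK_circuit[token_ids][:, token_ids]

print(torch.allclose(scores, lookup))
```

In other words, `scores[i, j]` equals `QK_circuit[token_ids[i], token_ids[j]]`, with the destination (query) token indexing the row and the source (key) token indexing the column.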
Now what about the OV-Circuit? We claimed earlier that the OV-Circuit tells us how the logits will be affected if a given source token is attended to (i.e. copied from). Let’s consider another example sequence, "The little dog was named".
If we pass this sequence to our model, but bypass the attention layer, we essentially simulate a zero-layer transformer. The logits for the top 10 predicted tokens without attention are shown below:
text = "The little dog was named"
# Compute the logits without attention
x = model.embed(tokenizer.encode(text, return_tensors="pt"))
logits = model.unembed(x)
# Check the tokens corresponding to the top 10 logits
top_k = torch.topk(logits[0, -1], 10)
for (token_id, logit) in zip(top_k.indices, top_k.values):
    print(f"{tokenizer.decode(token_id):<8} {logit.item():.3f}")
Sam 3.393
Sara 3.297
Tim 3.207
Lucy 3.143
Bob 3.092
Tom 3.072
Sue 2.908
's 2.673
Lily 2.432
Anna 2.406
Most of these predictions make sense in the world of bigrams. Because the previous token was "named", the model can guess that the next token will likely be a name. Let’s now pass the same sequence in, this time executing the attention layer as well:
text = "The little dog was named"
input_ids = tokenizer.encode(text, return_tensors="pt")
# Compute the logits with attention
logits = model(input_ids)
# Check the tokens corresponding to the top 10 logits
top_k = torch.topk(logits[0, -1], 10)
for (token_id, logit) in zip(top_k.indices, top_k.values):
    print(f"{tokenizer.decode(token_id):<8} {logit.item():.3f}")
Max 16.601
Spot 15.735
Buddy 15.300
Bow 14.026
Luna 12.550
Spike 12.227
Ben 12.138
Tiny 11.497
Bob 11.394
Sam 11.318
The model again predicts that the next token will be a name; this time, however, the top predictions are names more typically given to dogs, like "Spot" or "Buddy".
If we look at the attention pattern generated for this sequence (the attention pattern computed by attention head 3 is shown below), we will see that the last token, "named", attends strongly to the token "dog":
print(get_attention_scores("The little dog was named", head_idx=3))
tensor([[ 0.7997, -0.4273, 14.6899, 0.2682, 3.6517],
[ 4.4186, -1.2725, -4.3667, -2.7908, 1.5465],
[ 6.3078, -3.9695, -3.5616, -6.0423, 4.6949],
[ 1.6469, -1.2187, -2.3154, -0.4845, 1.8959],
[ 5.5127, 5.5340, 18.4230, 0.2630, 5.5345]])
⭡
"dog"⭢"named"
This means that, after applying this attention pattern, the residual stream of the token "named" will most closely resemble the residual stream of the token "dog". Now let’s turn to the OV-Circuit.
The string " dog" corresponds to token ID 3290 in this model’s vocabulary:
tokenizer.encode(" dog")
[3290]
Thus, row 3290 in the OV-Circuit matrix will show how the logits of every other token will be affected when the token "dog" is attended to. For example, the code below prints the 15 largest values in row 3290 of the OV-Circuit matrix for attention head 3:
top_k = torch.topk(OV_circuit[3][3290, :], 15)
for (token_id, score) in zip(top_k.indices, top_k.values):
    print(f"{tokenizer.decode(token_id):<8} {score.item():.3f}")
dog 7.058
Spot 6.756
Max 6.219
bark 6.062
collar 5.735
Buddy 5.667
bone 5.561
Max 5.416
Spot 5.250
leash 4.937
mouth 4.750
owner 4.492
heel 4.459
barking 4.383
tail 4.300
When the attention pattern (governed by the QK-Circuit) copies information from the source token "dog", the tokens in the list above are the ones whose logits are increased the most. Included in this list are a couple of the “dog names” that we saw earlier, as well as some other “dog-related” tokens like "bark", "bone", and "tail". By analyzing the QK-Circuit and OV-Circuit matrices, then, we can get a pretty complete idea of how the attention layer is functioning internally, and of how it will impact the final predictions made by our one-layer model.
In models with more than one layer, however, this picture becomes more complicated, as some attention heads may copy information between residual streams not to directly impact the logits for some token, but so that this information can be used “downstream” by attention heads in subsequent layers. Along with this increased complexity of multi-layer models, there remains still a major component of the transformer architecture that we have not discussed: multi-layer perceptrons, or MLPs.
These topics, however, we will leave for later parts of this series. Thanks for reading.
References
- [3] Elman, J.L., 1990. Finding structure in time. Cognitive Science, 14(2), pp.179-211.
- [7] Vaswani, A. et al., 2017. Attention is all you need. Advances in Neural Information Processing Systems.
- [8] Kovaleva, O. et al., 2019. Revealing the Dark Secrets of BERT. arXiv preprint arXiv:1908.08593.
- [11] Jain, S. and Wallace, B.C., 2019. Attention is not explanation. arXiv preprint arXiv:1902.10186.