Zero-Layer Transformers
Part I of An Interpretability Guide to Language Models
This is the first part in a series of posts I am planning on the topic of language model interpretability. I don’t claim any particular expertise or authority on this subject: these pages, rather than serving as any comprehensive survey or review, contain the notes and early thoughts of someone beginning to explore the subject of interpretability. I hope you may find them helpful.
Interpretability as a discipline aims to explain or reverse engineer the behaviors of deep neural networks, in some ways analogous to how neuroscience aims to explain the many behaviors of biological brains (for more on the merits and limitations of this analogy, see Chris Olah’s post[1] on this).
The notion of “interpreting” or “reverse engineering” the behaviors of artificial neural networks may at first seem odd – after all, these networks are designed and trained by humans. The implication, then, is that the humans designing and training these networks do not really understand how they work.
In one sense – that is, an architectural sense – the engineers and scientists working on artificial neural networks know perfectly well how these models work: they know how many trainable parameters the model has, how many layers it has, which mathematical operations are performed in what order by each of the layers, and so on.
But, in another sense – what we may call a behavioral sense – we are still far from a foolproof understanding of all the behaviors exhibited by trained language models. I find this limitation particularly well exemplified by how we typically try to mitigate strange or undesired outputs from LMs: through a sort of “trial and error” process of tweaking our prompts or system prompts, or playing with sampling parameters.
The well-studied phenomenon of Neural Scaling Laws[2] in language model training also illustrates how the capabilities of models have outpaced our ability to understand them deeply. What these Scaling Laws show us, empirically, is that more trainable parameters (i.e. larger models) and more training data lead to better performance; what they do not tell us is precisely why.
Language Models
Large Language Models are, when treated as black box systems, essentially “next-word predictors”. That is, given a sequence of words, the model predicts which word is most likely to come next.
For example, given the words “I opened the”, a model might predict, with some probability, that the most likely next word is “door”. Other possible words, like “box” and “package”, may also be given some relatively high probability.
The probabilities below are from GPT-2[3]:
From the outside, looking at the model as a black box, we could imagine that it makes its predictions based on some large joint probability distribution; that is, we may imagine that the model internally consists of a large table of probabilities (one entry of which would be “the probability of the word door occurring after the words I opened the”).
| Conditional probability | Value |
| --- | --- |
| . . . | . . . |
| \(P(\text{door} \mid \text{I opened the})\) | 0.26 |
| \(P(\text{box} \mid \text{I opened the})\) | 0.086 |
| \(P(\text{package} \mid \text{I opened the})\) | 0.027 |
| . . . | . . . |
Indeed, with such a table, one could theoretically implement a very effective language model, which we might term a “Chinese Room Language Model” after Searle’s famous thought experiment[4]. If we really had such a Chinese Room Language Model, interpretability would be quite a different – and arguably less interesting – problem: we could simply open up the model, extract the probability table, and see how the model would behave in every possible scenario without actually ever executing it. We could even quite effortlessly tweak the model to respond in a particular way in a particular scenario by modifying the probabilities.
Sadly, though, we have no such Chinese Room Language Model, and when we consider the logistics of building such a thing, we can quickly see that this approach to language modeling is infeasible. For example, a typical language model might have a vocabulary of around 50,000 unique words, or tokens. To be able to predict the next token even for tiny sentences of 10 tokens each, our table of probabilities (a joint distribution over every possible 10-token sentence) would need to contain \(50,000^{10} - 1\) entries: one per possible sentence, minus one, since the probabilities must sum to 1.
Assuming we could represent each probability with a single byte, this would require roughly \(9.8 \times 10^{46}\) bytes (or around \(9.8 \times 10^{16}\) quettabytes, if one believes in such a thing).
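The arithmetic above is easy to check in a few lines of Python. The constants simply mirror the assumptions in the text (a 50,000-token vocabulary, 10-token sentences, one byte per probability); the only extra fact used is that the SI prefix quetta denotes \(10^{30}\).

```python
# Back-of-the-envelope size of a "Chinese Room" probability table.
VOCAB_SIZE = 50_000     # approximate vocabulary size assumed in the text
SEQ_LEN = 10            # length of the tiny sentences considered
BYTES_PER_ENTRY = 1     # assume each probability fits in a single byte
QUETTABYTE = 10 ** 30   # SI prefix "quetta" = 10^30

# One entry per possible 10-token sentence, minus one because the
# probabilities must sum to 1 (the last entry is implied by the others).
num_entries = VOCAB_SIZE ** SEQ_LEN - 1
total_bytes = num_entries * BYTES_PER_ENTRY

print(f"entries:     {num_entries:.2e}")               # 9.77e+46
print(f"bytes:       {total_bytes:.2e}")               # 9.77e+46
print(f"quettabytes: {total_bytes / QUETTABYTE:.2e}")  # 9.77e+16
```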
Real language models, then, must have some tricks up their proverbial sleeves which allow them to simulate this giant joint probability distribution without actually taking up the enormous amount of space that it would naively require. To begin to explore these tricks, we will first turn to the Transformer Architecture to understand the underlying structure of a real world language model.
Transformer Architecture
The world of transformer model interpretability research is broad and changing quickly. What I will aim to do here is to step one at a time through the different operations which together comprise the transformer architecture, and, with each step, to explore some of the findings and hypotheses relevant to the interpretation of each operation’s high-level function.
At a high level, Transformers consist of four different types of operations:
- Embedding: Maps integer token IDs to embedding vectors.
- Self-Attention: Mixes information along the “time” or “sequence” dimension by allowing each token to selectively “attend” to prior tokens in the sequence.
- Multi-Layer Perceptron (MLP): Feedforward neural network that operates independently on each token.
- Unembedding: Maps embedding vectors to logits.
These definitions are admittedly hand-wavy. Later on, we will look closely at each operation and describe as best we can, from an interpretability perspective, how each contributes to the end-to-end capabilities of the model.
The Self-Attention and MLP operations together comprise what we will call a Transformer Block. Transformer models typically consist of several Transformer Blocks stacked sequentially such that the output of one block becomes the input to the next.
The figure below illustrates a transformer model with a single Transformer Block:
In this first chapter, we will limit our examination to the Embedding and Unembedding operations and to the so-called “Zero-Layer Transformer Models”[5] that we can construct from these two operations.
Embedding
Before a language model can make some prediction about a piece of text (e.g. predicting the next word), this text must first be presented in a form the model can work with. The Embedding operation maps each token in the model’s vocabulary to a learned vector of length \(d\), the Model Dimension.
Throughout, we will use the term dimension to refer to each individual component or element of a vector. The embedding vectors in a transformer model can thus be referred to as \(d\)-dimensional vectors.
In the same way that a 2-dimensional vector or a 3-dimensional vector can be thought of as describing a direction and magnitude in 2-dimensional or 3-dimensional space, a \(d\)-dimensional vector can be thought of (although not easily visualized) as describing a direction and magnitude in \(d\)-dimensional space.
Embedding vectors are distributed representations of the tokens in the model’s vocabulary, as information about the token is distributed across all dimensions of \(d\)-dimensional space. This is in contrast to local representations like one-hot encoding, where each dimension is reserved for a particular token.
The use of distributed representations for words provides some benefits over local representations, such as allowing the embedding vectors of similar words to point in similar directions in \(d\)-dimensional space and vice versa, which in turn improves a model’s ability to generalize beyond the data it saw during training (see [6] and [7] for more on the merits of distributed word representations).
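One common way to make “pointing in similar directions” concrete is cosine similarity: the cosine of the angle between two vectors, which ignores their magnitudes. The sketch below applies PyTorch's `torch.nn.functional.cosine_similarity` to three made-up 4-dimensional vectors; the vectors and the words attached to them are purely hypothetical, chosen only to illustrate the measure, and do not come from any trained model.

```python
import torch
import torch.nn.functional as F

# Hypothetical 4-dimensional "embeddings", invented for illustration only.
cat = torch.tensor([0.9, 0.1, 0.4, -0.2])
dog = torch.tensor([0.8, 0.2, 0.5, -0.1])
car = torch.tensor([-0.3, 0.9, -0.6, 0.7])

# Cosine similarity compares direction only:
# 1.0 = same direction, 0.0 = orthogonal, -1.0 = opposite direction.
sim_cat_dog = F.cosine_similarity(cat, dog, dim=0)
sim_cat_car = F.cosine_similarity(cat, car, dim=0)

print(f"cat vs dog: {sim_cat_dog.item():.2f}")  # close to 1: similar directions
print(f"cat vs car: {sim_cat_car.item():.2f}")  # negative: quite different directions
```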
One challenge that distributed representations present from an interpretability standpoint, however, is that, when examining embedding vectors, we have no reason to assume that any one dimension has any meaning on its own.
For example, we may think that there is something a little special about a vector like \([2.0, 0.0, 0.0]\), because it runs just along what we might call the x-axis. But to a transformer model, this may be just another direction in \(d\)-dimensional space, no more special than, say, \([-1.2, 0.3, 2.1]\).
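One way to see why no single dimension needs to be special: in the zero-layer model built later in this post (an embedding matrix followed directly by an unembedding matrix, with nothing nonlinear in between), we can transform every embedding vector by any orthogonal matrix \(Q\) (a rotation or reflection), undo that transformation inside the unembedding, and the model's outputs do not change. The sketch below checks this numerically with random matrices of arbitrary toy sizes. The argument relies on the model being purely linear, so it should be read as an illustration for this simple setting rather than a statement about every transformer component.

```python
import torch

torch.manual_seed(0)
V, d = 9, 4  # small, arbitrary vocabulary size and model dimension

W_E = torch.randn(V, d)  # embedding matrix: one row per token
W_U = torch.randn(d, V)  # unembedding matrix: maps d-dim vectors to V logits

# The Q factor of a QR decomposition of a square matrix is orthogonal (Q @ Q.T == I).
Q, _ = torch.linalg.qr(torch.randn(d, d))

# Transform every embedding, then apply the inverse (Q.T) before unembedding.
W_E_rotated = W_E @ Q
W_U_rotated = Q.T @ W_U

# The individual embedding coordinates look completely different...
print(torch.allclose(W_E, W_E_rotated))  # False
# ...but the token-to-logits mapping W_E @ W_U is unchanged (up to float error).
print(torch.allclose(W_E @ W_U, W_E_rotated @ W_U_rotated, atol=1e-4))  # True
```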
Relationships between Embeddings
Lots of interesting research on embeddings – some of which predates the transformer architecture altogether – has focused on the relationships between learned word embedding vectors. For example, [8] observed that meaningful semantic relationships between words are represented as constant vector offsets between word pairs, making the now somewhat famous observation:
$$\text{vec}(\text{king}) - \text{vec}(\text{man}) + \text{vec}(\text{woman}) \approx \text{vec}(\text{queen})$$
These interesting relationships between embedding vectors emerge only in models more complex and capable than the kind we will build in this first part of the series, so we will defer further examination of this topic for now.
Embeddings as a Matrix Multiplication
Although the Embedding operation is often described as a mapping from integer Token IDs to embedding vectors, we can also think of it as a matrix multiplication between a one-hot-encoded token vector \(t\), and the learned embedding weight matrix \(W_{E}\).
From here on, we will refer to this embedding matrix as \(W_{E}\), where each row is a token embedding vector; and we will use \(t \times W_{E}\) to refer to the process of mapping a one-hot encoded token \(t\) to its corresponding embedding vector.
Unembedding
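The equivalence between the lookup view and the matrix-multiplication view can be checked directly in PyTorch. The sketch below assumes a toy vocabulary of 9 tokens and \(d = 4\) (both arbitrary), builds one-hot row vectors with `torch.nn.functional.one_hot`, and compares \(t \times W_E\) against plain row indexing into \(W_E\) and against `torch.nn.Embedding` itself.

```python
import torch
import torch.nn.functional as F

vocab_size, d = 9, 4  # toy sizes, chosen arbitrarily
embed = torch.nn.Embedding(vocab_size, d)
W_E = embed.weight    # shape (vocab_size, d): row i is the embedding of token i

token_ids = torch.tensor([3, 0, 7])

# One-hot rows t (3 x vocab_size) multiplied by W_E (vocab_size x d).
t = F.one_hot(token_ids, num_classes=vocab_size).float()
via_matmul = t @ W_E

# Multiplying by a one-hot vector simply selects one row of W_E.
via_indexing = W_E[token_ids]
via_lookup = embed(token_ids)

print(via_matmul.shape)  # torch.Size([3, 4])
print(torch.allclose(via_matmul, via_indexing), torch.allclose(via_matmul, via_lookup))  # True True
```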
Just as the Embedding operation is responsible for converting text into \(d\)-dimensional vectors, the Unembedding operation is unsurprisingly responsible for converting \(d\)-dimensional vectors back into text.
Specifically, Unembedding can be viewed as a single matrix multiplication between the learned Unembedding Weight matrix, \(W_{U}\), and a \(d\)-dimensional vector, which we will call \(x\).
The result, then, is a \(V\)-dimensional vector of logits, where \(V\) is the size of the vocabulary and each element \(v_{i}\) of the vector indicates how likely it is that the next token will be the token at index \(i\) in the vocabulary. The figure below shows how we can interpret this logits vector, using a toy vocabulary of 9 different tokens:
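A minimal sketch of this step in PyTorch, using a hypothetical 9-token vocabulary and random weights in place of learned ones, so the particular prediction is meaningless. What it illustrates is the shapes: a \(d\)-dimensional vector times a \(d \times V\) matrix gives \(V\) logits, and the index of the largest logit picks out a vocabulary entry.

```python
import torch

torch.manual_seed(0)
# Hypothetical 9-token vocabulary: index i of the logits vector refers to vocab[i].
vocab = ["the", "a", "cat", "dog", "sat", "ran", "on", "mat", "."]
V, d = len(vocab), 4

W_U = torch.randn(d, V)  # unembedding matrix (random stand-in for learned weights)
x = torch.randn(d)       # some d-dimensional vector, e.g. a token embedding

logits = x @ W_U         # shape (V,): one logit per vocabulary entry
best = torch.argmax(logits).item()

print(logits.shape)      # torch.Size([9])
print("highest-scoring token:", vocab[best])
```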
As a side note, some language models have set \(W_{U}\) to be equal to the transpose of the Embedding weight matrix \(W_{E}\) based on the findings from [9], but this practice seems to have fallen out of favor in recent years, perhaps in part due to later work like [10].
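In PyTorch, one common way to tie the weights is to let the output layer share the embedding's parameter tensor. The sketch below borrows the layer names used in this post's `ZeroLayerTransformer` (`embed`, `unembed`) but is a standalone illustration with arbitrary toy sizes. It works because `torch.nn.Linear(d, V)` stores its weight with shape \((V, d)\), the same shape as `torch.nn.Embedding(V, d)`, and computes \(x W^{T}\); sharing the tensor therefore gives \(W_U = W_E^{T}\).

```python
import torch

vocab_size, d = 9, 4  # toy sizes, chosen arbitrarily
embed = torch.nn.Embedding(vocab_size, d)
unembed = torch.nn.Linear(d, vocab_size, bias=False)

# Tie the weights: both layers now share a single (vocab_size x d) parameter.
unembed.weight = embed.weight

x = torch.randn(d)
# Linear computes x @ weight.T, i.e. x @ W_E.T, so W_U is the transpose of W_E.
print(torch.allclose(unembed(x), x @ embed.weight.T, atol=1e-6))  # True
print(unembed.weight is embed.weight)                             # True
```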
Logits
We referred to the output of the Unembedding operation as a vector of logits, but what exactly is a logit? And what is it about Unembedding that causes it to produce logits as output, rather than just any old vectors?
A quick search online for logits will probably yield some answer like “The Logit is a type of function that maps probability values to real numbers”. But we get the feeling that this is not exactly what we are looking for.
In statistics, the Logit Function is a function that maps probabilities (i.e. values between 0 and 1) to real numbers (ranging from \( -\infty \) to \( \infty \)), and is defined as:
$$\operatorname{logit}(p) = \ln\left(\frac{p}{1 - p}\right)$$
where the input \(p\) is a probability, and the output is a real number that we refer to as a Logit. Now, the Logit Function happens to be the inverse of the Sigmoid function, which maps real numbers to probabilities:
$$\sigma(x) = \frac{1}{1 + e^{-x}}$$
Thus, the relationship between logits and probabilities can be described by the following diagram:
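As a numerical companion to this relationship, the sketch below uses PyTorch's built-in `torch.logit` and `torch.sigmoid` (with \(\operatorname{logit}(p) = \ln(p / (1 - p))\) and \(\sigma(x) = 1 / (1 + e^{-x})\)) on a few arbitrary probabilities, and confirms that applying the sigmoid to the logits recovers the original probabilities.

```python
import torch

p = torch.tensor([0.1, 0.5, 0.9, 0.99])  # arbitrary probabilities in (0, 1)

logits = torch.logit(p)            # ln(p / (1 - p)): maps (0, 1) to (-inf, inf)
manual = torch.log(p / (1 - p))    # the same formula, written out by hand
recovered = torch.sigmoid(logits)  # 1 / (1 + e^-x): maps back to (0, 1)

print(logits)  # p = 0.5 maps to 0; p near 1 maps to a large positive logit
print(torch.allclose(logits, manual, atol=1e-6))  # True
print(torch.allclose(recovered, p, atol=1e-6))    # True
```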
But what does this have to do with language models and the Unembedding operation? We saw much earlier how the output of a language model is a set of probabilities: one probability for each possible token in the model’s vocabulary. Thus, as the final step in computing the model output, we need to apply some function to map the raw model outputs (i.e. the result of the Unembedding operation) to a set of probabilities between 0 and 1.
Unfortunately, we cannot use the Sigmoid function, as this only works in cases where there are two possible outcomes (such as 0 or 1, true or false, and so on), whereas in our case we have \(V\) possible outcomes. Thus, in language modeling we use the Softmax function, which can be viewed as a generalization of the Sigmoid function to any number of mutually exclusive outcomes.
Since the inputs to a Sigmoid function can be rightly called Logits (based on the definition of Logits above), it has become convention in machine learning literature, for better or worse, to refer also to the inputs to a Softmax function as Logits, even though the precise connection with the original Logit function no longer holds.
But enough on this business of nomenclature: from here on, we will simply follow convention and refer to the outputs of the Unembedding operation as logits, which, when passed through a Softmax function, become probabilities.
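For a logits vector \(z\), the Softmax function is \(\mathrm{softmax}(z)_i = e^{z_i} / \sum_j e^{z_j}\). One way to see the sense in which it generalizes the Sigmoid: with exactly two outcomes whose logits are \(x\) and \(0\), Softmax gives the first outcome probability \(e^{x} / (e^{x} + 1) = \sigma(x)\). The PyTorch sketch below checks this numerically on a few arbitrary values, then shows that Softmax over a longer (arbitrary) logits vector yields positive values summing to 1.

```python
import torch

x = torch.tensor([-3.0, -0.5, 0.0, 2.0])

# Two outcomes with logits [x, 0]: the Softmax probability of the first
# outcome equals sigmoid(x), since e^x / (e^x + e^0) = 1 / (1 + e^-x).
two_way = torch.stack([x, torch.zeros_like(x)], dim=-1)  # shape (4, 2)
p_first = torch.softmax(two_way, dim=-1)[:, 0]
print(torch.allclose(p_first, torch.sigmoid(x), atol=1e-6))  # True

# With V outcomes, Softmax turns arbitrary real-valued logits into a distribution.
logits = torch.tensor([2.0, -1.0, 0.5, 0.0, 3.0])
probs = torch.softmax(logits, dim=-1)
print(probs.sum().item())       # ~1.0
print(bool((probs > 0).all()))  # True
```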
The Zero-Layer Transformer
With just these two operations, Embedding and Unembedding, we are now able to build a naive sort of language model, which, following [5], we will call a Zero-Layer Transformer, as it can be viewed as a transformer model with zero transformer blocks.
In PyTorch, the Zero-Layer transformer would look something like this:
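import torch

class ZeroLayerTransformer(torch.nn.Module):
    def __init__(self, vocab_size: int, model_dimension: int):
        super().__init__()
        self.model_dimension = model_dimension
        # Maps token IDs to learned d-dimensional vectors (W_E)
        self.embed = torch.nn.Embedding(vocab_size, model_dimension)
        # Maps d-dimensional vectors to one logit per vocabulary entry (W_U)
        self.unembed = torch.nn.Linear(model_dimension, vocab_size, bias=False)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) -> embeddings: (batch, seq_len, d)
        embeddings = self.embed(token_ids)
        # logits: (batch, seq_len, vocab_size)
        return self.unembed(embeddings)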
Of course, we could choose to model the Embedding operation as another Linear layer (i.e. as a matrix multiplication) rather than using PyTorch’s Embedding layer. The only difference would be that, instead of providing integer token_ids as input, we would need to provide a one-hot vector for each token. Aside from this, the result would be the same.
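This Zero-Layer model consists of nothing but two matrix multiplications, and mathematically is as simple as:
$$y = t \times W_{E} \times W_{U}$$
where \(t\) is a one-hot encoded token and \(y\) is the resulting \(V\)-dimensional vector of logits.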
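To check that the forward pass really is just \(t \times W_E \times W_U\), the sketch below re-creates the `ZeroLayerTransformer` class from above (so the snippet runs on its own), feeds it a few arbitrary token IDs, and compares its output with explicit one-hot vectors multiplied by the two weight matrices. Note that `torch.nn.Linear` stores its weight as a \(V \times d\) matrix, so \(W_U\) corresponds to `model.unembed.weight.T`. The toy sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

class ZeroLayerTransformer(torch.nn.Module):
    def __init__(self, vocab_size: int, model_dimension: int):
        super().__init__()
        self.model_dimension = model_dimension
        self.embed = torch.nn.Embedding(vocab_size, model_dimension)
        self.unembed = torch.nn.Linear(model_dimension, vocab_size, bias=False)

    def forward(self, token_ids):
        return self.unembed(self.embed(token_ids))

torch.manual_seed(0)
vocab_size, d = 9, 4  # toy sizes
model = ZeroLayerTransformer(vocab_size, d)

token_ids = torch.tensor([[2, 5, 5, 8]])  # shape (batch=1, seq_len=4)
W_E = model.embed.weight                   # (V, d)
W_U = model.unembed.weight.T               # (d, V)

t = F.one_hot(token_ids, num_classes=vocab_size).float()  # (1, 4, V)
with torch.no_grad():
    expected = t @ W_E @ W_U               # (1, 4, V)
    actual = model(token_ids)

print(actual.shape)                               # torch.Size([1, 4, 9])
print(torch.allclose(actual, expected, atol=1e-5))  # True
```

Positions 1 and 2 hold the same token (5), so they also receive identical logits: nothing but the token at each position feeds into its prediction.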
Training a Zero-Layer Model
We can train a Zero-Layer Transformer model to perform next token prediction using a cross entropy loss function. The example script below demonstrates a single training step, with several print statements thrown in to hopefully make things clear:
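import torch
from tabulate import tabulate  # assumption: only used by the stand-in helper below
from transformers import AutoTokenizer

def pretty_print_tokens(token_ids, tokenizer):
    # Minimal stand-in for a display helper whose definition isn't shown here:
    # prints each token ID above the string it decodes to.
    ids = token_ids.reshape(-1).tolist()
    tokens = [tokenizer.decode(i) for i in ids]
    print(tabulate([["Token IDs", *ids], ["Tokens", *tokens]], tablefmt="fancy_grid"))

# Initialize the model, tokenizer and optimizer
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = ZeroLayerTransformer(vocab_size=tokenizer.vocab_size, model_dimension=128)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # illustrative optimizer choice

# Tokenize some training data
text = "Mistakes are the portals of discovery"
token_ids = tokenizer.encode(text, return_tensors="pt")
print("Token IDs:")
pretty_print_tokens(token_ids, tokenizer)

# Compute the logits
logits = model(token_ids)
print(f"\nLogits Shape: [ {'×'.join(str(x) for x in logits.shape)} ] (Batch Size x Sequence Length x Vocab Size)")
print(f"Logits:\n{logits.detach().numpy()}")

# Shift the logits to the right by one position
# (i.e. drop the logits at the last position, which has no next token to predict)
shifted_logits = logits[:, :-1, :]
print(f"\nShifted Logits Shape: [ {'×'.join(str(x) for x in shifted_logits.shape)} ]")
print(f"Shifted Logits:\n{shifted_logits.detach().numpy()}")

# Shift the Token IDs to the left by one position
# (i.e. drop the first token, which no position is asked to predict)
shifted_token_ids = token_ids[..., 1:]
print("\nShifted Token IDs:")
pretty_print_tokens(shifted_token_ids, tokenizer)

loss = torch.nn.functional.cross_entropy(
    shifted_logits.reshape(-1, logits.size(-1)),
    shifted_token_ids.reshape(-1),
)
print(f"\nLoss: {loss.item():.2f}")

# Update the weights: this is what turns the loss computation into a training step
optimizer.zero_grad()
loss.backward()
optimizer.step()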
The output of this script is shown below:
╒═══════════╤═══════╤══════╤═════╤═════╤═════════╤═════╤═══════════╕
│ Token IDs │ 49370 │ 1124 │ 389 │ 262 │ 42604   │ 286 │ 9412      │
├───────────┼───────┼──────┼─────┼─────┼─────────┼─────┼───────────┤
│ Tokens    │ Mist  │ akes │ are │ the │ portals │ of  │ discovery │
╘═══════════╧═══════╧══════╧═════╧═════╧═════════╧═════╧═══════════╛
Logits Shape: [ 1×7×50257 ] (Batch Size x Sequence Length x Vocab Size)
Logits:
[[[ 0.05 0.3 -0.01 0.44 ... 0.43 -0.68 -0.82 1.02]
[ 0.58 -0.81 -0.14 -0.41 ... -0.22 -0.98 0.29 0.59]
[-0.28 0.46 0.11 -0.35 ... 0.39 -0.02 -1.06 0.31]
[-0.96 -0.05 0.11 -0.07 ... 0.29 -0.36 -0.38 -0.22]
[-0.07 0.15 -0.12 0.22 ... -0.5 0.33 -0.84 -0.67]
[-0.19 0.05 0.34 0.23 ... -0.51 -0.19 -0.01 0.19]
[-0.35 0.95 1.21 1.11 ... 0.19 0.21 0.88 -0.26]]]
Shifted Logits Shape: [ 1×6×50257 ]
Shifted Logits:
[[[ 0.05 0.3 -0.01 0.44 ... 0.43 -0.68 -0.82 1.02]
[ 0.58 -0.81 -0.14 -0.41 ... -0.22 -0.98 0.29 0.59]
[-0.28 0.46 0.11 -0.35 ... 0.39 -0.02 -1.06 0.31]
[-0.96 -0.05 0.11 -0.07 ... 0.29 -0.36 -0.38 -0.22]
[-0.07 0.15 -0.12 0.22 ... -0.5 0.33 -0.84 -0.67]
[-0.19 0.05 0.34 0.23 ... -0.51 -0.19 -0.01 0.19]]]
Shifted Token IDs:
╒═══════════╤══════╤═════╤═════╤═════════╤═════╤═══════════╕
│ Token IDs │ 1124 │ 389 │ 262 │ 42604   │ 286 │ 9412      │
├───────────┼──────┼─────┼─────┼─────────┼─────┼───────────┤
│ Tokens    │ akes │ are │ the │ portals │ of  │ discovery │
╘═══════════╧══════╧═════╧═════╧═════════╧═════╧═══════════╛
Loss: 10.77
It is worth taking a moment to understand this business of shifting the logits and token IDs before computing the loss. What we ultimately want to grade the model on is its ability to predict the correct token, \( t_{i+1} \), given all of the preceding tokens, \( t_{0:i} \) (that is, \(t_0\) through \(t_i\)).
If we were to simply compute the cross-entropy between the logits and the token IDs without any shifting, this would be equivalent to training the model to predict the token \( t_{i} \) given the tokens \( t_{0:i} \), which include the very token it’s trying to predict (this is a bit like giving out a test where every question includes its own answer).
Visually, we can imagine drawing an arrow between each logits vector \( y_{i} \) and its corresponding input token \(t_{i}\):
The arrows in this case indicate that loss is computed based on how well the logits (in the top row) predict the corresponding token (in the bottom row). If we went on training like this, we would end up with a model that simply spits out the same text that we fed in, which is pretty useless.
Instead, what we want is to grade the model on how well it predicts the next token (for example, the loss should be computed between the logits \(y_{0}\) and the token akes, and so on). By shifting the logits to the right, and the tokens to the left, we get the correct loss calculations (you could also consider this not as shifting, but as truncating the last logits vector and the first token ID):
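The resulting pairing can be spelled out with plain Python lists. The token strings below are those of the example sentence, with the leading spaces that GPT-2's tokenizer includes when decoding word-initial tokens; no model is needed, since the point is only which logits position is graded against which token.

```python
# Tokens of the example sentence, as decoded by the GPT-2 tokenizer.
tokens = ["Mist", "akes", " are", " the", " portals", " of", " discovery"]

# Keep the logits at every position except the last (y_0 ... y_5) and every
# token except the first (t_1 ... t_6): position i is graded against t_{i+1}.
positions = list(range(len(tokens) - 1))
targets = tokens[1:]

for i, target in zip(positions, targets):
    print(f"y_{i} (computed at {tokens[i]!r}) is graded against {target!r}")
```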
The Limitations of Zero-Layer Models
The script above performs a single training step. We can repeat this process (say, a few thousand times) on lots of different text data to get a trained Zero-Layer transformer model. We will quickly see, however, that the output of this kind of model is largely unintelligible.
For example, here is some output from a Zero-Layer transformer model when I provide as input the word “Once”:
'Once upon a big, "I want to play with'
Note that here I am performing “Greedy Sampling” by simply selecting, at each step, the token with the highest probability assigned by the model.
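Greedy sampling itself only takes a few lines. The sketch below is a generic greedy loop, not necessarily the exact code used to produce the output above. It assumes a model with the same interface as `ZeroLayerTransformer` (token IDs of shape batch × sequence in, logits of shape batch × sequence × \(V\) out), and since Softmax preserves the ordering of the logits, it takes the argmax of the logits directly. To stay self-contained it uses a tiny untrained stand-in model, so the generated IDs are meaningless; with a trained model and the GPT-2 tokenizer, the IDs would be turned back into text with `tokenizer.decode`.

```python
import torch

@torch.no_grad()
def greedy_generate(model, token_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    """Repeatedly append the single highest-scoring next token (greedy sampling)."""
    for _ in range(max_new_tokens):
        logits = model(token_ids)                           # (batch, seq_len, vocab_size)
        next_logits = logits[:, -1, :]                      # scores for the token after the last one
        next_id = next_logits.argmax(dim=-1, keepdim=True)  # (batch, 1)
        token_ids = torch.cat([token_ids, next_id], dim=1)
    return token_ids

# Untrained stand-in with the same input/output shapes as ZeroLayerTransformer.
torch.manual_seed(0)
vocab_size, d = 9, 4
model = torch.nn.Sequential(
    torch.nn.Embedding(vocab_size, d),
    torch.nn.Linear(d, vocab_size, bias=False),
)

prompt = torch.tensor([[3]])
out = greedy_generate(model, prompt, max_new_tokens=5)
print(out.tolist())  # the prompt ID followed by 5 generated token IDs
```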
One thing we notice in this odd output is that each pair of adjacent words in the string does make sense (e.g. “Once upon”, “upon a”, “play with”, and so on), but the output sequence as a whole makes no sense. This is because, when predicting the next token \(t_{i}\), the model only has access to information in the most recent token \(t_{i-1}\); in other words, this Zero-Layer transformer is only capable of predicting the most likely bigrams.
Although admittedly not at all useful, the Zero-Layer transformer is very interpretable. The entire model can be described by two weight matrices: the embedding matrix \(W_E\) and the unembedding matrix \(W_U\). The product of these weight matrices, \(W_{E}W_{U}\), is a \(V \times V\) matrix containing, for each token, the logits assigned to every possible next token.
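This dependence on only the most recent token is easy to confirm: in a zero-layer model, the logits at each position are computed from that position's token alone, so changing any earlier token leaves the final position's logits untouched. A small check, again with an untrained toy stand-in (an embedding layer followed by an unembedding layer) and arbitrary token IDs:

```python
import torch

torch.manual_seed(0)
vocab_size, d = 9, 4
model = torch.nn.Sequential(
    torch.nn.Embedding(vocab_size, d),
    torch.nn.Linear(d, vocab_size, bias=False),
)

# Two sequences that differ everywhere except in their last token (6).
seq_a = torch.tensor([[1, 2, 3, 6]])
seq_b = torch.tensor([[8, 0, 4, 6]])

with torch.no_grad():
    last_a = model(seq_a)[:, -1, :]
    last_b = model(seq_b)[:, -1, :]

# The next-token prediction depends only on the last token: a bigram model.
print(torch.allclose(last_a, last_b))  # True
```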
zero_layer_weights = model.embed.weight @ model.unembed.weight.T
print(zero_layer_weights.shape)
torch.Size([50257, 50257])
Thus, just by examining this one matrix, we can understand every behavior of the model. For example, the string token “Once” maps to token ID 7454 in this model’s vocabulary. We can look at row 7454 of the matrix \(W_EW_U\) and get the indices of the columns with the highest values like so:
top_k_indices = torch.topk(zero_layer_weights[7454], k=10).indices
print(top_k_indices.numpy())
[2402 612 531 373 2933 13 3290 3049 2497 1625]
Then, if we map these indices back to their corresponding string tokens in the model’s vocabulary, we will see that these 10 column indices represent the 10 most likely tokens that follow after the word “Once”:
top_k_tokens = [ tokenizer.decode(i) for i in top_k_indices ]
print(top_k_tokens)
[' upon', ' there', ' said', ' was', ' boy', '.', ' dog', ' fast', ' saw', ' came']
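One practical note: the full \(W_E W_U\) product holds \(50257^2 \approx 2.5 \times 10^9\) entries, which, assuming PyTorch's default float32 (4 bytes per value), comes to roughly 10 GB. When only one token's row is needed, it can be computed on its own as that token's embedding vector times \(W_U\). The sketch below wraps this in a hypothetical helper, `top_next_tokens`, and checks it against the full product on small random matrices (affordable only because \(V\) is tiny here).

```python
import torch

@torch.no_grad()
def top_next_tokens(embed_weight, unembed_weight, token_id: int, k: int = 10):
    """Hypothetical helper: top-k next-token IDs for `token_id`.

    Computes only row `token_id` of W_E @ W_U instead of the full V x V matrix.
    embed_weight:   (V, d) tensor, e.g. model.embed.weight (this is W_E)
    unembed_weight: (V, d) tensor, e.g. model.unembed.weight (nn.Linear layout, i.e. W_U transposed)
    """
    row = embed_weight[token_id] @ unembed_weight.T  # shape (V,)
    return torch.topk(row, k=k).indices

# Toy check against the full product.
torch.manual_seed(0)
V, d = 50, 8
W_E = torch.randn(V, d)
W_U_linear = torch.randn(V, d)  # laid out like nn.Linear's weight, so W_U = W_U_linear.T
full = W_E @ W_U_linear.T       # the full (V, V) matrix

row_7 = W_E[7] @ W_U_linear.T
print(torch.allclose(row_7, full[7], atol=1e-5))  # True
print(top_next_tokens(W_E, W_U_linear, token_id=7, k=5))
```

With the trained model from this post, `top_next_tokens(model.embed.weight, model.unembed.weight, 7454)` should yield the same ten indices as the full-matrix version, without materializing the \(50257 \times 50257\) matrix.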
In the next chapter, we will examine transformer models with Self-Attention and see how this both improves the model’s capabilities and makes it more difficult to fully understand the model’s behavior.