Language Model Fine-Tuning with LoRA

This page explores Low-Rank Adaptation (LoRA) (Hu, Edward J., et al., 2021) as a method for fine-tuning pre-trained language models, and demonstrates how to apply the method using the open-source HuggingFace PEFT library.

Language Model Pretraining

Language models like BERT, BART, and GPT are pre-trained on vast amounts of unlabelled text data, such as the entirety of English Wikipedia, every question and answer on the StackExchange network, or The Pile (Gao, Leo, et al., 2020), a massive 825GB corpus of English text which combines the previous two datasets with data gathered from 20 other sources.

Pre-training is performed on large clusters of costly, specialized hardware like GPUs and TPUs, and can take days or weeks to complete. The table below, which shows the hardware used to pre-train a few different language models, gives some idea of the cost and scale of this process:

Model             Pre-Training Cluster       Estimated Daily Cost
XLM-RoBERTa[4]    500 Nvidia V100 GPUs       $29,760
OPT-175B[5]       992 Nvidia A100 GPUs       $93,500
LLaMA[6]          2,048 Nvidia A100 GPUs     $193,000
PaLM[7]           6,144 Google v4 TPUs       $475,000

Note: daily cost estimates are based on the public GPU and TPU costs published by Google Cloud as of the time of this writing.

During this pre-training process, each model is given a generalized objective to test its ability to understand language, such as predicting the next word in a sentence or predicting a randomly masked word within a piece of text.


[Figure: Next word prediction. Given a string of words ("The cat jumped on the..."), the model predicts a probability distribution over candidate next words.]

[Figure: Masked language modeling. Given a string of words with one word masked ("He spilled the [MASK] on the floor"), the model predicts the masked word.]
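
To make the masked language modeling objective more concrete, here is a small sketch (not part of the fine-tuning code used later) that uses the transformers fill-mask pipeline with a pre-trained BERT checkpoint; the example sentence is arbitrary and the exact scores will vary:

from transformers import pipeline

# a pre-trained BERT checkpoint behind a fill-mask pipeline
fill_mask = pipeline('fill-mask', model='bert-base-cased')

# the model predicts a distribution over candidates for the masked word
predictions = fill_mask("He spilled the [MASK] on the floor.")
for prediction in predictions:
    print(f"{prediction['token_str']}: {prediction['score']:.3f}")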

Most production applications, however, have a more specific objective, like deciding whether a review left by a customer is positive or negative, summarizing a news article, or answering a user’s question about a given piece of text.

While it is theoretically possible to train a language model from scratch (i.e. starting from randomly initialized weights) using the application-specific task as the training objective, this is rarely done in practice. Instead, a language model which has first been pre-trained on a large corpus of text is “fine-tuned” to perform the task.

Models trained using this approach (i.e. pre-training + fine-tuning) have been shown to outperform models trained from scratch on a variety of different tasks (Radford et al.). This approach makes sense intuitively as well: any task involving natural language, like sentiment analysis or summarization, requires some fundamental understanding of the language itself, irrespective of any particular context in which it is used. The purpose of pre-training is to instill in the model this fundamental understanding.

Fine-Tuning

With this brief overview of pre-training aside, let's now turn to fine-tuning. As already mentioned, when fine-tuning we begin with a set of weights learned during pre-training, then adapt these weights to an application-specific task by training the model on a small (relative to the pre-training dataset) amount of labeled data.

For example, the code below uses the pre-trained weights of the BERT model and adapts them to the task of sentiment analysis by fine-tuning the model on the TweetEval dataset, which contains tweets labeled as either positive, negative, or neutral.

from datasets import load_dataset
from transformers import BertForSequenceClassification, BertTokenizer, Trainer

model = BertForSequenceClassification.from_pretrained(
    'bert-base-cased',
    num_labels=3
)

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

# load the TweetEval sentiment dataset and tokenize each example
tweet_eval_dataset = load_dataset('tweet_eval', name='sentiment')

tweet_eval_dataset = tweet_eval_dataset.map(
    lambda sample: tokenizer(sample['text'], padding='max_length', truncation=True)
)

train_dataset = tweet_eval_dataset['train'].with_format('torch')
test_dataset = tweet_eval_dataset['test'].with_format('torch')

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset
)

trainer.train()

The above code is just a minimal demonstration. The full fine-tuning script is available here: https://github.com/gnovack/fine-tuning-with-lora

After fine-tuning BERT for a single epoch on the ~45,000 example tweets in the TweetEval training set (which takes around 44 minutes to complete on my MacBook), the model can classify a tweet as positive, negative, or neutral with 69% precision. Compare this to the 56% precision achieved by training BERT from scratch (i.e. starting from random weights) on the same dataset.

Drawbacks of Fine-Tuning

Fine-tuning BERT for a specific task is cheap and effective. In the world of language models, however, BERT is on the small side, with around 110M trainable parameters. Many pre-trained language models have billions or even hundreds of billions of parameters.

When fine-tuning BERT, we had to store in memory not only the model's weights, but also their gradients and associated optimizer states during each training step. This means that a huge amount of memory is required to fine-tune much larger models in this way: for example, around 1.2TB of memory is required to fine-tune the 175B parameter variant of GPT-3[1].

Partial Fine-Tuning

To reduce these memory requirements, one simple modification we can make to the fine-tuning procedure is to adapt only a subset of the model’s weights. As transformer models are typically composed of individual encoder and/or decoder layers chained together in sequence, a common approach is to only retrain the weights of the final one or two layers of the model.

For example, the code below adds to the previous BERT fine-tuning code, and demonstrates how to retrain only the top one or two encoder layers by freezing all other model weights:

# freeze all but the final encoder
if args.fine_tuning == 'top':
    for param in model.bert.parameters():
        param.requires_grad = False
    
    for param in model.bert.encoder.layer[-1].parameters():
        param.requires_grad = True

# freeze all but the last two encoders
elif args.fine_tuning == 'top2':
    for param in model.bert.parameters():
        param.requires_grad = False
    
    for param in model.bert.encoder.layer[-1].parameters():
        param.requires_grad = True
    
    for param in model.bert.encoder.layer[-2].parameters():
        param.requires_grad = True
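
As a quick sanity check (this snippet is my own addition, not part of the original script), we can count how many parameters remain trainable under a given setting before calling trainer.train():

# count parameters that will actually receive gradient updates
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} of {total_params:,}")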

In my experiments, the partially fine-tuned models perform slightly worse compared to the fully fine-tuned model, with the Top-2 model outperforming the Top-1 model.

The Top-2 and Top-1 models can be fine-tuned much faster, however, compared to the fully fine-tuned model:

[Figure: Time taken to fine-tune BERT on the TweetEval dataset for 1 epoch, measured in minutes.]

Understanding LoRA

The fine-tuning techniques shown so far all involve retraining some or all of the model’s pre-trained weights to adapt the model for some downstream task. Low-Rank Adaptation (LoRA) is a decidedly different approach to fine-tuning introduced by (Hu, Edward J., et al., 2021).

LoRA was inspired by the observation that pre-trained language models have a low “intrinsic dimensionality” (Aghajanyan et al., 2020), meaning that the number of parameters required to effectively perform a given task (such as sentiment analysis of tweets) is much smaller than the total number of parameters in the model.

What follows from this observation is that, in order to fine-tune any particular layer in a pre-trained model, we do not need to modify every one of that layer’s weights. We can instead learn a lower-dimensional, task-specific representation of the layer’s weights.

Specifically, LoRA works by injecting these lower-dimensional representations at various points in the pre-trained model's computation graph.

As an example, consider a typical fully-connected (i.e. Linear) layer with \(m\) input units and \(n\) output units. The weights of this layer can be described by a matrix \(W\) of shape \(m \times n\). Given an input \(x\) (a row vector with \(m\) entries), the output of the layer is given by:

$$y = xW$$

When fine-tuning with LoRA, the weights in \(W\) are frozen, and two new trainable matrices, \(A\) and \(B\), are introduced. The output of the layer is then given by:

$$y = xW + xAB$$

At first glance, it is not obvious how the introduction of two new weight matrices could improve the efficiency of fine-tuning. If anything, it seems we have only made our layer more complicated, but assigning some concrete numbers to the equation makes the benefit of this change more evident.

Imagine that \( m = 768 \) and \( n = 3072 \), so that the weight matrix \(W\) has shape \(768 \times 3072 \), which amounts to \(2{,}359{,}296\) unique weights. Following LoRA, we add two new matrices, \(A\) and \(B\), but what are the shapes of these matrices? Based on the equation above, we know that the shape of \(AB\) must match the shape of \(W\) so that we can add \(xW\) to \(xAB\) when computing the output of the layer.

This means that \(A\) must have the shape \(768 \times r\), and \(B\) must have the shape \(r \times 3072\), where \(r\) can be any number we choose, as it does not affect the shape of \(AB\). In the case of \(r = 1\), the total number of weights in \(A\) and \(B\) together is \( (768 \times 1) + (1 \times 3072) = 3{,}840\).

If we were to fine-tune this layer using the approach shown earlier, we would be retraining all \(2{,}359{,}296\) weights in \(W\). With LoRA, however, we need only retrain the \(3{,}840\) weights in \(A\) and \(B\).
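
To make this concrete, below is a minimal, illustrative PyTorch sketch of a LoRA-augmented linear layer. It follows the shapes used above and the \( \frac{\alpha}{r} \) scaling discussed further below, rather than the actual PEFT implementation, and the initialization is simplified:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, r=1, alpha=1):
        super().__init__()
        # frozen pre-trained weight W with shape (in_features, out_features)
        self.weight = nn.Parameter(torch.randn(in_features, out_features), requires_grad=False)
        # trainable low-rank matrices A (in_features x r) and B (r x out_features)
        self.lora_A = nn.Parameter(torch.randn(in_features, r))
        self.lora_B = nn.Parameter(torch.zeros(r, out_features))
        self.scaling = alpha / r

    def forward(self, x):
        # y = xW + xAB * (alpha / r)
        return x @ self.weight + (x @ self.lora_A @ self.lora_B) * self.scaling

layer = LoRALinear(768, 3072, r=1)
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # 3840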

Using LoRA with HuggingFace PEFT

With this conceptual understanding of LoRA, let’s see how to use LoRA with HuggingFace’s Parameter-Efficient Fine-Tuning (PEFT) library. We start by creating a new LoraConfig object:

from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS, r=1, lora_alpha=1, lora_dropout=0.1
)

The task_type parameter specifies the type of task for which we are fine-tuning the model, r specifies the shared inner dimension of \(A\) and \(B\) (the rank \(r\) described above), and lora_alpha is a scaling factor that controls the relative importance of the weights in \(A\) and \(B\) compared to the original parameters in the model: before \(xW\) and \(xAB\) are summed, \(xAB\) is scaled by \( \frac{\alpha}{r} \), so by setting \(\alpha = r\), we disable this scaling.
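
As an aside, LoraConfig also accepts a target_modules argument that controls which Linear layers receive \(A\) and \(B\) matrices. The variant below is purely hypothetical (it is not used in the experiments here) and shows a larger rank with \(\alpha \neq r\), so the low-rank update would be scaled by \( \frac{\alpha}{r} = 2 \):

# hypothetical alternative config, not used in the experiments below
lora_config_alt = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=16,                      # update scaled by alpha / r = 2
    lora_dropout=0.1,
    target_modules=["query", "value"],  # BERT's self-attention query/value projections
)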

To inject the new \(A\) and \(B\) matrices into our model, we call the get_peft_model function:

from peft import get_peft_model
model = get_peft_model(model, lora_config)

By default, when calling get_peft_model on a BERT model, \(A\) and \(B\) matrices will be injected into the query and value projections (i.e. the Linear layers which transform the input embeddings into query and value vectors as part of the self-attention mechanism) in each encoder layer.

We can confirm this by examining the self-attention module in any one of the encoder layers:

print(model.bert.encoder.layer[0].attention.self)
BertSelfAttention(
  (query): Linear(
    in_features=768, out_features=768, bias=True
    (lora_dropout): ModuleDict(
      (default): Dropout(p=0.1, inplace=False)
    )
    (lora_A): ModuleDict(
      (default): Linear(in_features=768, out_features=1, bias=False)
    )
    (lora_B): ModuleDict(
      (default): Linear(in_features=1, out_features=768, bias=False)
    )
    (lora_embedding_A): ParameterDict()
    (lora_embedding_B): ParameterDict()
  )
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(
    in_features=768, out_features=768, bias=True
    (lora_dropout): ModuleDict(
      (default): Dropout(p=0.1, inplace=False)
    )
    (lora_A): ModuleDict(
      (default): Linear(in_features=768, out_features=1, bias=False)
    )
    (lora_B): ModuleDict(
      (default): Linear(in_features=1, out_features=768, bias=False)
    )
    (lora_embedding_A): ParameterDict()
    (lora_embedding_B): ParameterDict()
  )
  (dropout): Dropout(p=0.1, inplace=False)
)

The Linear layers query and value now have lora_A and lora_B weights as attributes. Notice that the value provided for r when creating the LoraConfig is the same as the number of out_features of lora_A, as well as the number of in_features of lora_B.
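
We can also check this programmatically; the attribute path below simply follows the module structure printed above, where PEFT stores the new matrices in ModuleDicts keyed by adapter name ("default" here):

query = model.bert.encoder.layer[0].attention.self.query

# the inner dimension of the adapter matches r from the LoraConfig
print(query.lora_A['default'].out_features)  # 1
print(query.lora_B['default'].in_features)   # 1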

We can also see how many new trainable weights were added to the model:

model.print_trainable_parameters()
trainable params: 36864 || all params: 108377668 || trainable%: 0.03401438753969129

With \(r = 1\), the total number of new trainable weights is 36,864, because the model contains 12 encoder layers, 2 pairs of \(A\) and \(B\) matrices per layer (one for the query projection and one for the value projection), and \((768 \times 1) + (1 \times 768) = 1{,}536\) weights per pair of \(A\) and \(B\) matrices: \(12 \times 2 \times 1{,}536 = 36{,}864\).
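
A quick back-of-the-envelope check of that figure (a throwaway calculation, not part of the fine-tuning code):

num_layers = 12          # BERT-base encoder layers
pairs_per_layer = 2      # one A/B pair each for the query and value projections
weights_per_pair = (768 * 1) + (1 * 768)
print(num_layers * pairs_per_layer * weights_per_pair)  # 36864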

Experiments with LoRA

BERT Sentiment Analysis

Similar to the fine-tuning experiments described above, I fine-tuned BERT for sentiment analysis using LoRA with various values of \(r\). As shown below, the model fine-tuned using LoRA with \(r = 8\) outperforms both the Top-1 and Top-2 encoder models. These results are impressive considering that the LoRA model has only ~300,000 trainable weights, while the Top-2 and Top-1 models have ~14 million and ~7 million, respectively.

Interestingly, increasing the value of \(r\) does not necessarily lead to better performance. In terms of both precision and F1 score, the model fine-tuned with \(r = 8\) narrowly outperforms models with \(r = 16\) and \(r = 32\).

           r=1    r=2    r=4    r=8    r=16   r=32
Precision  0.675  0.675  0.675  0.677  0.674  0.677
F1 Score   0.673  0.679  0.678  0.682  0.680  0.681

Saving and Loading LoRA Models

Once fine-tuning is complete, we typically save the entire fine-tuned model to disk for later use. However, when using LoRA, we only need to save the new weights which were added to the model (i.e. the \(A\) and \(B\) matrices), because the original pre-trained model weights remain unchanged.

This is handled automatically when we call save_pretrained() on an instance of PeftModel (i.e. the model object returned by get_peft_model):

output_model_path = "./bert-sentiment-analysis-lora/"
model.save_pretrained(output_model_path)

# check the size of the saved model
for file in os.listdir(output_model_path):
    file_size = os.path.getsize(output_model_path + file)
    print(f"File: {file}; Size: {file_size / 1024:.2f}KB")
File: adapter_config.json; Size: 0.32KB
File: adapter_model.bin; Size: 1178.20KB

The fine-tuned BERT model weights take up only ~1000KB of disk space, much less than the ~430MB of space occupied by the entire BERT model. This highlights another interesting benefit of LoRA: it keeps fine-tuned model weights small, making them cheap to store and easy to share, which could be especially useful when managing a large number of different fine-tuned models.
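
One way to take advantage of this is to keep a single copy of the base model in memory and attach several LoRA adapters to it, switching between them by name. The sketch below assumes a second adapter has already been fine-tuned and saved; the "-v2" directory is hypothetical:

# attach a second, hypothetical adapter to the same model
model.load_adapter("./bert-sentiment-analysis-lora-v2/", adapter_name="v2")

# switch between the original adapter (named "default") and the new one
model.set_adapter("default")
model.set_adapter("v2")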

When we are ready to load the model for inference, we first load the pre-trained model, and then inject the fine-tuned LoRA weights by calling PeftModel.from_pretrained(), passing in the pre-trained model and the path to the saved LoRA weights:

from transformers import BertForSequenceClassification
from peft import PeftModel

pretrained_model = BertForSequenceClassification.from_pretrained(
    "bert-base-cased", num_labels=3
)

model = PeftModel.from_pretrained(pretrained_model, "./bert-sentiment-analysis-lora/")

One detail worth noting is that, because \(AB\) has the same shape as \(W\), the LoRA weights for each layer can be merged (added) into the corresponding pre-trained weights before inference:

$$W' = W + AB$$

In this way, no additional latency is introduced by fine-tuning with LoRA. We can confirm this by comparing the number of parameters in the pre-trained model to the number of parameters in the fine-tuned model:

pretrained_model_parameters = sum([p.numel() for p in pretrained_model.parameters()])
print(f"Number of pre-trained model parameters: {pretrained_model_parameters:,}")

model_parameters = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters after LoRA: {pretrained_model_parameters:,}")
Number of pre-trained model parameters: 108,904,710
Number of model parameters after LoRA: 108,904,710
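
If we want to fold the adapter weights into the base model permanently (so it can be served without the PEFT wrapper at all), PEFT provides merge_and_unload() for LoRA models; a brief sketch, assuming the model was loaded as above:

# merge W' = W + AB into each adapted layer and drop the LoRA modules
merged_model = model.merge_and_unload()
print(type(merged_model).__name__)  # BertForSequenceClassification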

Platonic OPT

To further explore how the value of the rank \(r\) impacts model performance (and to have a little fun), I decided to apply LoRA to text generation using the 1.3B parameter OPT model (Zhang, Susan, et al., 2022).

The model was fine-tuned using LoRA for 3 epochs on 13 different works by Plato gathered from the Project Gutenberg Online Library, and then evaluated on Plato’s Apology, which was not included in the training set.
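
The full OPT fine-tuning script is linked below; the snippet here is only a sketch of how the LoRA setup changes for causal language modeling (the value of r shown is a placeholder, and by default PEFT injects the \(A\) and \(B\) matrices into OPT's query and value projections):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

opt_model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")

# causal language modeling this time, rather than sequence classification
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=8, lora_dropout=0.1
)
opt_model = get_peft_model(opt_model, lora_config)
opt_model.print_trainable_parameters()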

The evaluation loss at various values of \(r\) is shown in the table below:

           r=1    r=2    r=4    r=8    r=16   r=32   r=64   r=128
Eval Loss  2.524  2.496  2.481  2.472  2.463  2.456  2.455  2.456

The fine-tuning script for OPT is available here: https://github.com/gnovack/fine-tuning-with-lora

Here, the evaluation loss decreases slightly from \(r=1\) to \(r=64\), which may indicate that this task has higher intrinsic dimensionality (i.e. that it requires a larger number of parameters to effectively model) compared to the sentiment analysis task. Alternatively, it may be that the optimal value of \(r\) depends in some way on the size of the model. Some further experiments would be required to reach a conclusion here.

Questions and Future Work

This look into language model fine-tuning with LoRA has left me with a few open questions and future ideas to explore:

  1. How does the optimal value of \(r\) vary with the size of the model (i.e. given the exact same evaluation criteria)?

  2. Given a fixed number of LoRA weights to add to a pre-trained model, what is the optimal placement of these weights? Do they need to be added evenly to each encoder/decoder layer, or could they be added to only the last \(N\) layers, thus improving memory efficiency during fine-tuning? To my knowledge, the original LoRA paper does not explore this question.

  3. What is the effect of adding LoRA weights to the feed-forward layers of the encoder/decoder layers?

  4. How much additional latency would be introduced by keeping the LoRA weights separate from the pre-trained model weights during inference? This could help determine the practicality of an inference system which uses a single pre-trained model with the ability to quickly “context switch” between different fine-tuned LoRA weights.

But for now, I will conclude with some wise words from Platonic OPT, contrasted with the rather cynical take of the base OPT-1.3B model:

[Sample generation from Platonic OPT]

[Sample generation from the base OPT-1.3B model]

References

[1] Hu, Edward J., et al. “Lora: Low-rank adaptation of large language models.” arXiv preprint arXiv:2106.09685 (2021).

[2] Radford, Alec, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever. “Improving Language Understanding by Generative Pre-Training.” (2018).

[3] Gao, Leo, et al. “The pile: An 800gb dataset of diverse text for language modeling.” arXiv preprint arXiv:2101.00027 (2020).

[4] Conneau, Alexis, et al. “Unsupervised cross-lingual representation learning at scale.” arXiv preprint arXiv:1911.02116 (2019).

[5] Zhang, Susan, et al. “Opt: Open pre-trained transformer language models.” arXiv preprint arXiv:2205.01068 (2022).

[6] Touvron, Hugo, et al. “Llama: Open and efficient foundation language models.” arXiv preprint arXiv:2302.13971 (2023).

[7] Chowdhery, Aakanksha, et al. “Palm: Scaling language modeling with pathways.” arXiv preprint arXiv:2204.02311 (2022).

[8] Aghajanyan, Armen, Luke Zettlemoyer, and Sonal Gupta. “Intrinsic dimensionality explains the effectiveness of language model fine-tuning.” arXiv preprint arXiv:2012.13255 (2020).