Distributed Training and DeepSpeed

This page explores the techniques used to distribute neural network training across multiple GPUs, and examines the various distributed training optimizations offered by Microsoft’s DeepSpeed library.

The Memory Requirements of Training

Much of the focus of this page will be on techniques to reduce the amount of memory required to train neural networks. So, to better understand and appreciate these techniques, I will start by breaking down all of the ways in which memory is consumed during the training process.

Following Sohoni et al.[1], I will use the terms below to describe the sources of memory consumption during training (Sohoni et al. describe only three sources, but I have separated Gradient Memory and Optimizer Memory, because this same distinction is made in DeepSpeed, as will be shown later):

  • Model Memory: The memory required to store the model’s weights.

  • Gradient Memory: The memory required to store the gradients of each model weight.

  • Optimizer Memory: The memory required to store any additional state required by the optimizer.

  • Activation Memory: The memory required to store the intermediate values computed during the forward pass.

Measuring the Four Sources of Memory Consumption

Let’s look now at how to measure the amount of memory consumed by each of these four sources, starting with a simple densely-connected neural network:

model = torch.nn.Sequential(
    torch.nn.Linear(512, 1024),
    torch.nn.Linear(1024, 1024),
    torch.nn.Linear(1024, 1024),
    torch.nn.Linear(1024, 512)
)

To get the total Model Memory, we count the number of parameters in the model, and multiply by the number of bytes used to represent each parameter (in this case, 4 bytes, because each weight is a 32-bit floating point number):

def get_model_memory(model: torch.nn.Module):
    """
    Returns the memory usage of the given model
    """
    total_memory = 0
    for param in model.parameters():
        total_memory += param.numel() * param.element_size()
    return total_memory

print("Model Memory: {:,} bytes".format(get_model_memory(model)))
Model Memory: 12,597,248 bytes

The Gradient Memory is the same as the Model Memory, as we need to store the gradient for each weight during the backward pass.
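
We can verify this empirically. Below is a small sketch (the helper name is mine, not part of the original script) that runs a backward pass and sums the size of each parameter's .grad tensor; for the model above it should report the same 12,597,248 bytes:

def get_gradient_memory(model: torch.nn.Module):
    """
    Returns the memory consumed by gradients, assuming backward()
    has already been called on the model.
    """
    total_memory = 0
    for param in model.parameters():
        if param.grad is not None:
            total_memory += param.grad.numel() * param.grad.element_size()
    return total_memory

# run a forward and backward pass so that gradients are populated
model(torch.randn(4, 512)).sum().backward()
print("Gradient Memory: {:,} bytes".format(get_gradient_memory(model)))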

The total amount of Optimizer Memory consumed depends on the type of optimizer used during training. The table below shows the Optimizer Memory required by some of the popular optimizers available in PyTorch, where \(M_{model}\) represents the Model Memory[2]:

Optimizer | Description | Memory Requirement
SGD | Stochastic Gradient Descent | $$0$$
SGD | Stochastic Gradient Descent w/ Momentum | $$M_{model}$$
Adam | The Adam optimizer requires us to maintain the first and second moment for each gradient (4 bytes each). | $$2 \times M_{model}$$
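
As a sanity check, we can also measure optimizer memory directly by summing the tensors PyTorch stores in optimizer.state after a step. The helper below is a rough sketch of my own (exact numbers may differ slightly; recent PyTorch versions, for example, also keep a small step counter per parameter):

def get_optimizer_memory(optimizer: torch.optim.Optimizer):
    """
    Returns the memory consumed by the optimizer's state tensors.
    """
    total_memory = 0
    for state in optimizer.state.values():
        for value in state.values():
            if torch.is_tensor(value):
                total_memory += value.numel() * value.element_size()
    return total_memory

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(4, 512)).sum().backward()
optimizer.step()

# for Adam, this should be roughly twice the Model Memory
print("Optimizer Memory: {:,} bytes".format(get_optimizer_memory(optimizer)))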

Finally, to compute the total Activation Memory, we need to count the hidden layer outputs computed during the forward pass. There are a few different ways to do this; here, we will recursively inject a PyTorch Forward Hook into every submodule to capture the size of that submodule’s output.

First, we define the hook, along with a simple helper class, ActivationCounter, used to keep track of the total size of the model’s hidden layer outputs:

class ActivationCounter:

    def __init__(self):
        self.activation_bytes = 0

    def add_activations(self, tensor):
        self.activation_bytes += tensor.numel() * tensor.element_size()

def activation_counter_hook(counter: ActivationCounter):

    def hook(self, input, output):
        counter.add_activations(output.data)

    return hook

Next, we define a function that will recursively inject the hook into every submodule of the model:

def register_hooks_recursive(model, counter: ActivationCounter):
    for module in model.children():
        module.register_forward_hook(activation_counter_hook(counter))
        register_hooks_recursive(module, counter)

activation_counter = ActivationCounter()
register_hooks_recursive(model, activation_counter)

Now to find the total amount of Activation Memory consumed, we execute a forward pass through the model and then print the value of activation_counter.activation_bytes:

inputs = torch.randn(4, 512)
outputs = model(inputs)

# because the hooks only capture layer outputs, we need to add
# the size of the original input tensor separately
activation_counter.add_activations(inputs)

print("Activation Memory: {:,} bytes".format(
  activation_counter.activation_bytes
))
Activation Memory: 65,536 bytes

The complete script to calculate model memory and activation memory for dense neural networks can be found in the GitHub repository: estimate_nn_memory.py

Memory Requirements of Transformers

Because Rajbhandari, Samyam, et al.[8] (in the paper which introduced many of the training optimizations implemented in DeepSpeed) focused their experiments on Transformer models, I wanted to better understand the memory requirements of this class of models.

Using the same approach as above, I computed the Model and Activation Memory for a basic Transformer model (see Appendix A for details) as a way to validate my formula for estimating Transformer model memory, as well as the formula for Transformer activation memory provided by Korthikanti, Vijay Anand, et al.[3] (both are shown below):

$$\begin{aligned} M_{model} &= 4nh \times (13 + 12h) \\ M_{activation} &= nblh \times (67 + \frac{9ml}{h}) \end{aligned}$$

Given the following:

$$\begin{aligned} n &: \text{number of layers} \\ h &: \text{hidden size} \\ m &: \text{number of attention heads} \\ b &: \text{batch size} \\ l &: \text{sequence length} \end{aligned}$$
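
To make these formulas concrete, here is a small helper that simply evaluates them in Python (a sketch; the function and argument names are my own, and the repository's estimate_transformer_memory.py script may be organized differently):

def estimate_transformer_memory(n_layers, hidden_size, num_heads, batch_size, seq_len):
    """
    Evaluates the model and activation memory formulas above, in bytes.
    """
    model_memory = 4 * n_layers * hidden_size * (13 + 12 * hidden_size)
    activation_memory = n_layers * batch_size * seq_len * hidden_size * (
        67 + (9 * num_heads * seq_len) / hidden_size
    )
    return model_memory, activation_memory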


Distributed Training

The optimizations offered by DeepSpeed are primarily applicable to distributed training environments (e.g. when multiple GPUs are used to train a single model). There are three common paradigms for parallelizing training with multiple GPUs: Data Parallelism, Model Parallelism, and Pipeline Parallelism. Each is explained briefly in the following sections.

Data Parallelism

Data Parallel (DP) describes a distributed training process where the entire model is replicated on multiple devices (e.g. GPUs), and each device performs training in parallel on different batches of input data. There are multiple different ways to implement Data Parallel training, but here I will only describe the implementation used by the PyTorch DistributedDataParallel module[4].

[Figure: data parallel training, with batches of training data distributed across GPU 0, GPU 1, and GPU 2]

Each device receives a different batch of input data and performs a forward pass to compute the loss on that batch.

As gradients are computed during the backward pass on each device, these gradients are exchanged with all other devices. The average of the gradients is then used to update the model weights on each device, ensuring that, at the beginning of the next training step, all devices have the same set of model weights.

[Figure: gradients exchanged between GPU 0, GPU 1, and GPU 2]
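
In PyTorch, this is handled by wrapping the model in DistributedDataParallel. The sketch below is a minimal example of my own (not taken from the training script used later), and assumes the process group has already been initialized, for example with the create_process_group helper from the AllReduce script below:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def train_step(rank):
    # each process builds its own replica of the model on its GPU;
    # DDP averages gradients across devices during backward()
    model = torch.nn.Linear(512, 10).to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    inputs = torch.randn(8, 512).to(rank)
    targets = torch.randn(8, 10).to(rank)

    loss = torch.nn.functional.mse_loss(ddp_model(inputs), targets)
    loss.backward()   # gradients are averaged across devices here
    optimizer.step()  # every replica applies the same update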

AllReduce

The exchange of gradients between devices is performed using a collective communication algorithm called AllReduce. The AllReduce algorithm executes a reduction operation on data that is distributed across multiple devices.

The visualization below illustrates the execution of an AllReduce operation to compute the sum of a vector of values distributed across 3 devices:

An illustration of Ring-AllReduce[7], one popular implementation of the AllReduce operation.

PyTorch provides support for several collective communication algorithms, including AllReduce, via its torch.distributed module. The script below demonstrates an AllReduce operation across 3 GPUs:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def create_process_group(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    dist.init_process_group(
        backend='nccl',
        world_size=world_size,
        rank=rank
    )

def all_reduce_example(rank, world_size):
    create_process_group(rank, world_size)

    # create a different tensor on each device
    if rank == 0:
        tensor = torch.tensor([1, 2, 3]).to(rank)
    elif rank == 1:
        tensor = torch.tensor([10, 20, 30]).to(rank)
    elif rank == 2:
        tensor = torch.tensor([4, 5, 6]).to(rank)

    print('Before AllReduce: Rank ', rank, ' has data ', tensor)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print('After AllReduce: Rank ', rank, ' has data ', tensor)

if __name__ == "__main__":
    device_count = 3
    mp.spawn(all_reduce_example, args=(device_count,), nprocs=device_count)
Before AllReduce: Rank  0  has data  tensor([1, 2, 3], device='cuda:0')
Before AllReduce: Rank  1  has data  tensor([10, 20, 30], device='cuda:1')
Before AllReduce: Rank  2  has data  tensor([4, 5, 6], device='cuda:2')
After AllReduce:  Rank  2  has data  tensor([15, 27, 39], device='cuda:2')
After AllReduce:  Rank  0  has data  tensor([15, 27, 39], device='cuda:0')
After AllReduce:  Rank  1  has data  tensor([15, 27, 39], device='cuda:1')

A Few Notes on Data Parallelism

Importantly, Data Parallelism does not reduce the total memory required for training, as each device must have enough memory to perform a forward and backward pass on the entire model.

The main benefit, then, of data parallelism is that it speeds up the training process, as each device only needs to train on a fraction of the total training data. The speedup, however, does not scale linearly with the number of devices, as the exchange of gradients between devices introduces communication overhead. For example, Li, Shen, et al.[4] observe that, when training a BERT model across 256 GPUs, the scaling factor is roughly 128.

The chart below shows the time taken to train a BERT model for 1,000 steps using data parallel training across 1-8 GPUs:

The data parallel training script used to generate these results is available at distributed-training-and-deepspeed/data_parallel_training.py

Model Parallelism

Model Parallel (MP) describes a distributed training process where the model is partitioned across multiple devices, such that each device contains only part of the model’s weights. The forward pass is executed sequentially on each device, with the output of one device becoming the input to the next.

[Figure: model parallel training, with training data batches passing sequentially through GPU 0, GPU 1, and GPU 2]
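
As a toy illustration (my own example, not the modified BERT implementation used later), a two-layer network can be split across two GPUs by placing each layer on a different device and moving the intermediate activations between them:

import torch

class NaiveModelParallel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # the first layer lives on GPU 0, the second on GPU 1
        self.stage_0 = torch.nn.Linear(512, 1024).to("cuda:0")
        self.stage_1 = torch.nn.Linear(1024, 512).to("cuda:1")

    def forward(self, x):
        hidden = self.stage_0(x.to("cuda:0"))
        # the intermediate activations are copied from GPU 0 to GPU 1
        return self.stage_1(hidden.to("cuda:1"))

model = NaiveModelParallel()
outputs = model(torch.randn(4, 512))  # outputs end up on cuda:1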

Model parallelism is essential when training very large models which cannot fit on a single device.

One major downside of model parallel training is that the sequential nature of the forward and backward passes leads to idle time on each device. For example, in the diagram above, after GPU 0 has completed its forward pass on an input batch, it sits idle while GPUs 1 and 2 perform their forward and backward passes.

By using PyTorch forward and backward hooks, we can measure the amount of time that each GPU is idle during training:

def idle_time_hook(self, device, forward=True, entering=True):
    """Creates a PyTorch hook which logs the idle time of a device."""
    
    def hook(*args, **kwargs):
        current_timestamp = time.time()
        last_timestamp = self.previous_timestamp.get(device, None)

        message = "{} {} pass on device {}".format(
            "Entering" if entering else "Finished",
            "forward" if forward else "backward",
            device
        )

        if entering and last_timestamp is not None:
            idle_time_ms = (current_timestamp - last_timestamp) * 1000
            self.device_idle_time[device] = (
                self.device_idle_time[device][0] + idle_time_ms,
                self.device_idle_time[device][1] + 1
            )
            
            message += f". Idle time: {idle_time_ms:.2f}ms"

        self.previous_timestamp[device] = current_timestamp
        self.log(message)
    return hook
2023-06-04 00:20:36.786470 - Entering forward pass on device 0.
2023-06-04 00:20:36.789441 - Finished forward pass on device 0
2023-06-04 00:20:36.790807 - Entering forward pass on device 1.
2023-06-04 00:20:36.793664 - Finished forward pass on device 1
2023-06-04 00:20:36.794982 - Entering forward pass on device 2.
2023-06-04 00:20:36.796839 - Finished forward pass on device 2
2023-06-04 00:20:36.798351 - Entering forward pass on device 3.
2023-06-04 00:20:36.799569 - Finished forward pass on device 3
2023-06-04 00:20:37.911536 - Entering backward pass on device 3. Idle time: 1111.96ms
2023-06-04 00:20:37.913496 - Finished backward pass on device 3
2023-06-04 00:20:37.914812 - Entering backward pass on device 2. Idle time: 1117.97ms
2023-06-04 00:20:37.916677 - Finished backward pass on device 2
2023-06-04 00:20:37.918056 - Entering backward pass on device 1. Idle time: 1124.39ms
2023-06-04 00:20:37.920743 - Finished backward pass on device 1
2023-06-04 00:20:37.922119 - Entering backward pass on device 0. Idle time: 1132.67ms
2023-06-04 00:20:37.923938 - Finished backward pass on device 0

The timeline below illustrates the forward and backward execution on each GPU when performing model parallel training across three GPUs. The gray space in each timeline represents the time during which that GPU is idle.

Pipeline Parallelism

Pipeline Parallelism (PP) is a variation on model parallelism that reduces the amount of device idle time by splitting each batch of input data into a number of smaller “micro-batches”[6]. Model parameters are only updated after every micro-batch has been processed by the entire model, meaning that each device can begin processing the next micro-batch while the other devices are still processing the previous one.

Support for pipeline parallelism is built into PyTorch via the torch.distributed.pipeline.sync.Pipe class. Two major limitations of this class, however, are that (1) it only works when the model is implemented as a torch.nn.Sequential module, and (2) it requires that the inputs and outputs of each module be either a single tensor or a tuple of tensors[7].

Because of these limitations, I had to modify the BERT model implementation from the HuggingFace Transformers library to support pipeline parallelism. The modified model can be found here: distributed-training-and-deepspeed/model/bert_mp.py

Using this custom BERT implementation, we can enable pipeline parallelism by first converting the model into a torch.nn.Sequential module, and then wrapping it in a torch.distributed.pipeline.sync.Pipe object:

def to_pipeline(self, chunks):
    """Convert the model for pipeline parallelism."""
    rpc.init_rpc(
        name="worker",
        rank=0,
        world_size=1,
        rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
            init_method="file://{}".format(tempfile.NamedTemporaryFile().name)
        )
    )

    sequential = torch.nn.Sequential(
        self.embeddings,
        *self.encoders, 
        self.head
    )
    return Pipe(sequential, chunks=chunks)

Comparing Model Parallelism and Pipeline Parallelism

To compare the performance of model parallelism and pipeline parallelism, I executed a few training runs using my custom BERT implementation. The figure below shows the time taken to perform 250 training steps on 2, 4, 6, and 8 GPUs using both model parallelism and pipeline parallelism. For these tests, batch size was set to 16, and, for pipeline parallelism, the number of micro-batches was set to 4.

DeepSpeed

With this understanding of the basic landscape of distributed training techniques, we can now begin to look at DeepSpeed.

DeepSpeed is an open-source library of deep learning optimization tools created by Microsoft, which spans the realms of model training, inference, and model compression. Here, I will focus only on the training-related optimizations and will leave the topics of inference and compression to future exploration.

As discussed earlier, we can perform distributed training by replicating the entire model on multiple devices (Data Parallelism), and/or by splitting the model up and storing different parts of it on different devices (Model Parallelism / Pipeline Parallelism). In general, data parallelism is more computationally efficient compared to model parallelism; however, in cases where the model is too large to fit into the available memory of a single device, model parallelism is necessary.

Central to the training optimizations provided by DeepSpeed is the Zero Redundancy Optimizer (ZeRO), a set of techniques to reduce the amount of memory required for distributed model training.

ZeRO

Data parallelism introduces significant memory redundancy across devices. For example, consider training a 1 billion parameter model across 8 GPUs. The memory required to store the model parameters alone (assuming 32-bit floats) is ~3.7GB. When using data parallelism, each device must store a copy of the entire model, meaning that, across all devices, a total of ~29.8GB of memory is used. In other words, 26.1GB of device memory is occupied by redundant model parameters. Activation and Optimizer memory are similarly redundant across devices.
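
The arithmetic behind these numbers is a quick back-of-the-envelope check:

# 1 billion parameters stored as 32-bit floats
params = 1_000_000_000
per_device_gb = params * 4 / 2**30        # ~3.7GB per device
total_gb = per_device_gb * 8              # ~29.8GB across 8 devices
redundant_gb = total_gb - per_device_gb   # ~26.1GB of redundant copies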

ZeRO-DP addresses this redundancy by:

  1. Partitioning optimizer state
  2. Partitioning gradients
  3. Partitioning model parameters

When using the deepspeed library, we select which of these memory optimizations to apply by specifying the ZeRO Stage. The available stages are:

  • Stage 0: No memory optimizations are applied.

  • Stage 1: Optimizer states are partitioned across devices.

  • Stage 2: Optimizer states and gradients are partitioned across devices.

  • Stage 3: Optimizer states, gradients, and model parameters are partitioned across devices.

In the following sections, we will explore each of these stages in detail. All of the example scripts have been tested on a single machine with eight 16GB NVIDIA V100 GPUs.

Stage 0 - Getting Started

The DeepSpeed library is installed like any other Python package:

pip install deepspeed

When executing data parallel training with DeepSpeed, we do not have to set up process groups or explicitly spawn multiple processes as we do when using PyTorch’s DistributedDataParallel. Instead, we wrap our model in a DeepSpeedEngine by calling deepspeed.initialize, which handles all of the distributed training logic internally.

For these examples, we will use the pre-trained, 560M parameter variant of the BLOOM model[9] from HuggingFace.

# deepspeed_stage_0.py
import os

import deepspeed
from transformers import BloomForCausalLM

# DeepSpeed automatically sets the LOCAL_RANK environment variable
# to the index of the current device.
rank = int(os.getenv("LOCAL_RANK", "0"))

model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

deepspeed_config = {
    "train_micro_batch_size_per_gpu": 1
}

model_engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=deepspeed_config
)

print(f"Device {rank} - ZeRO Stage: {model_engine.zero_optimization_stage()}")

To launch a distributed training job, we use the deepspeed command line utility, which is installed alongside the deepspeed Python package:

deepspeed deepspeed_stage_0.py
Device 1 - ZeRO Stage: 0
Device 4 - ZeRO Stage: 0
Device 2 - ZeRO Stage: 0
Device 3 - ZeRO Stage: 0
Device 6 - ZeRO Stage: 0
Device 7 - ZeRO Stage: 0
Device 0 - ZeRO Stage: 0
Device 5 - ZeRO Stage: 0

Here we see that the model is replicated across all eight devices. To allow DeepSpeed to create and manage the optimizer used during training, we add the appropriate configuration options to the deepspeed_config object:

deepspeed_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 5e-5
        }
    },
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=deepspeed_config
)

# Print optimizer state information
optimizer_state = optimizer.param_groups[0]
print(f"Device {rank} - Optimizer: lr={optimizer_state['lr']}; "
      f"betas={optimizer_state['betas']}; eps={optimizer_state['eps']}; "
      f"parameter count={sum([torch.numel(p) for p in optimizer_state['params']]):,}")
Device 1 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=559,214,592
Device 4 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=559,214,592
Device 2 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=559,214,592
Device 3 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=559,214,592
Device 6 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=559,214,592
Device 7 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=559,214,592
Device 0 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=559,214,592
Device 5 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=559,214,592

Training a DeepSpeedEngine is much like training a PyTorch model with data parallelism. The only difference is that the backward() and step() methods are called directly on the DeepSpeedEngine object rather than on the loss or the optimizer:

from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

from util import load_wikitext

# DeepSpeed automatically sets the WORLD_SIZE environment variable
# to the number of devices participating in the training job.
world_size = int(os.getenv("WORLD_SIZE", "1"))

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# helper function to load the wikitext dataset
# implementation can be found here:
# https://github.com/gnovack/distributed-training-and-deepspeed/blob/main/util.py
train_dataset = load_wikitext(tokenizer, collator).select(range(64))

train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    sampler=DistributedSampler(train_dataset, num_replicas=world_size)
)

for batch in train_dataloader:
    device = torch.device("cuda", rank)
    
    input_ids = batch['input_ids'].to(device)
    labels = batch['labels'].to(device)

    outputs = model_engine(input_ids, labels=labels)
    model_engine.backward(outputs.loss)
    model_engine.step()

Stage 1 - Optimizer State Partitioning

ZeRO Stage 1 partitions the optimizer state (for example, when using Adam, the first and second moment values) across all devices, so that each device contains only a portion of this state.

The figure below illustrates the per-device memory allocation when training the BLOOM model with traditional data parallelism (top), and ZeRO Stage 1 (bottom):

GPU Memory Allocation - Traditional Data Parallelism

[Figure: GPU 0 through GPU 3 each hold a full copy of the model parameters, gradients, optimizer state, and activations]

GPU Memory Allocation - ZeRO Stage 1

[Figure: GPU 0 through GPU 3 each hold the full model parameters, gradients, and activations, but only a partition of the optimizer state]

Based on the formula for transformer model memory discussed above, we can estimate that the memory required to store the model parameters of a 560M parameter model is around 2.24GB. When using Adam in 32-bit mode, the optimizer states take up 8 bytes of memory for each model parameter, so 4.48GB will be required by the optimizer.

Assuming that we train the model with batch size 1 and input sequence length 512, we can use the formula above to estimate the amount of memory required by activations at 1.75GB.
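
These estimates can be reproduced with a few lines of Python (a quick check of my own; the parameter count comes from the optimizer output above, and the BLOOM 560M configuration of 24 layers, hidden size 1024, and 16 attention heads comes from its HuggingFace config):

params = 559_214_592
model_gb = params * 4 / 1e9       # ~2.24GB of fp32 weights
gradient_gb = params * 4 / 1e9    # ~2.24GB of gradients
optimizer_gb = params * 8 / 1e9   # ~4.5GB of Adam states (8 bytes per parameter)

# activation estimate with n=24, h=1024, m=16, b=1, l=512
n, h, m, b, l = 24, 1024, 16, 1, 512
activation_gb = n * b * l * h * (67 + 9 * m * l / h) / 1e9   # ~1.75GB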

Thus, the total amount of memory required to train the BLOOM 560M model using traditional data parallelism is roughly:

$$2.24 + 2.24 + 1.75 + 4.48 = 10.71GB$$

The deepspeed library includes several utility functions we can use to analyze memory utilization during training. For example, to measure the current and peak memory utilization, we call the memory_status() function within our training loop:

from deepspeed.runtime.utils import memory_status

for batch in train_dataloader:
    input_ids = batch['input_ids'].to(device)
    labels = batch['labels'].to(device)

    outputs = model_engine(input_ids, labels=labels)
    
    model_engine.backward(outputs.loss)

    model_engine.step()

    if rank == 0:
        memory_status("Memory stats after training step")

If we execute the training script now using the deepspeed CLI, we will see the following line printed after each step:

RANK=0 MEMSTATS Memory stats after training step device=cuda:0 current alloc=6.8847GB (delta=0.0000GB max=10.8823GB) current cache=13.4453GB (delta=0.0000GB max=13.4453GB)

As shown, the peak GPU memory utilization is 10.88GB, very close to the 10.71GB we estimated. Next, let’s see how much memory we can save by using ZeRO Stage 1. To enable ZeRO Stage 1, we update the deepspeed_config object to include the zero_optimization configuration option:

deepspeed_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 5e-5
        }
    },
    "zero_optimization": {
        "stage": 1,
    }
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=deepspeed_config
)

We will notice a few differences now when executing the training script. First, when printing the optimizer parameter count on each device, the number of parameters shown is now 69,901,824, or \( \frac{1}{8} \) of the total number of parameters in the model:

Device 0 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=69,901,824
Device 6 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=69,901,824
Device 7 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=69,901,824
Device 4 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=69,901,824
Device 5 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=69,901,824
Device 3 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=69,901,824
Device 2 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=69,901,824
Device 1 - Optimizer: lr=5e-05; betas=(0.9, 0.999); eps=1e-08; parameter count=69,901,824

Second, when printing the memory utilization after each training step, the peak GPU memory utilization is now 7.48GB, 3.4GB less than the 10.88GB we measured when training with traditional data parallelism:

RANK=0 MEMSTATS Memory stats after training step: device=cuda:0 current alloc=3.4843GB (delta=0.0000GB max=7.4816GB) current cache=12.7695GB (delta=0.0000GB max=12.7695GB)

So where does this extra 3.4GB come from? Recall that the optimizer states take up around 4.48GB. Previously, when using traditional data parallelism, each device held a copy of all 4.48GB worth of optimizer states. Now, with ZeRO Stage 1, each device only holds \( \frac{1}{8} \) of the optimizer states, or 0.56GB.

In theory, this means that the peak GPU memory when using ZeRO stage 1 should be around:

$$2.24 + 2.24 + 0.56 + 1.75 = 6.79GB$$

The fact that we measured 7.48GB instead of 6.79GB can possibly be attributed to the use of intermediate buffers for communication of updated model weights, but some further investigation will be required to confirm this.

Stage 2 – Gradient Partitioning

ZeRO Stage 2 goes a step further by partitioning both the optimizer states and the gradients across devices. The figure below illustrates the per-device memory allocation when training with ZeRO Stage 2:

GPU Memory Allocation - ZeRO Stage 2

[Figure: GPU 0 through GPU 3 each hold the full model parameters and activations, but only a partition of the gradients and optimizer state]

Stage 2 is similarly enabled by setting the stage option in the deepspeed_config:

deepspeed_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 5e-5
        }
    },
    "zero_optimization": {
        "stage": 2
    }
}

If we execute the training script with ZeRO Stage 2 enabled, we will find that the peak GPU memory utilization has increased slightly to 8.22GB, compared to 7.48GB when training with Stage 1:

RANK=0 MEMSTATS Memory stats after training step: device=cuda:0 current alloc=3.4843GB (delta=0.0000GB max=8.2187GB) current cache=12.7715GB (delta=0.0000GB max=12.7715GB)

To understand the reason for this unexpected increase in peak GPU memory, we need to know a little about the implementation of ZeRO Stage 2. During the backward pass, gradients are averaged and placed on one of the devices through a series of Reduce operations.

While we could perform a Reduce operation after computing each and every gradient, we can achieve better performance by reducing the gradients in chunks.

The size of these chunks is controlled by the reduce_bucket_size option in the deepspeed_config, and defaults to \( 5 \times 10^8 \) elements, or 2GB when each element is a 32-bit float. While this default value is suitable for much larger models, it is too large for our 560M model, the gradients for which take up only about 2.24GB in total.
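
The bucket size can be overridden in the zero_optimization section of the config. A sketch of the change, using the smaller value tested next, might look like this:

deepspeed_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 5e-5
        }
    },
    "zero_optimization": {
        "stage": 2,
        "reduce_bucket_size": 5e6
    }
}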

If we decrease the reduce_bucket_size to \( 5 \times 10^6 \) elements, we can decrease peak memory utilization to 6.37GB, an additional 1.11GB reduction compared to Stage 1:

RANK=0 MEMSTATS Memory stats after training step: device=cuda:0 current alloc=3.4843GB (delta=0.0000GB max=6.3750GB) current cache=10.9277GB (delta=0.0000GB max=10.9277GB)

It is worth considering whether the peak memory of 6.37GB aligns with our expectations. Given that each device holds \( \frac{1}{8} \) of the optimizer states, and \( \frac{1}{8} \) of the gradients, we should expect peak memory to be around:

$$2.24 + \frac{2.24}{8} + \frac{4.48}{8} + 1.75 = 4.83GB$$

This theoretical estimate assumes that the gradients produced by each operation in the backward pass will be less than the bucket size. This is not the case, however, with BLOOM 560M, as its word embedding layer contains \( 250880 \times 1024 \approx 256M \) parameters. Thus, during the backward pass, each device must store the gradients for these parameters before exchanging them via Reduce. The storage of these gradients requires around 1.03GB, which accounts for much of the difference between the theoretical estimate and the measured peak memory.

Stage 3 – Parameter Partitioning

ZeRO Stage 3 builds on Stage 2 by partitioning the model parameters as well as the optimizer states and gradients. When we train the BLOOM 560M model with ZeRO Stage 3, we will notice that the allocated memory after each training step has decreased from 3.48GB to 1.94GB compared to Stage 2. The peak GPU memory utilization, however, increases to 7.41GB.

RANK=0 MEMSTATS Memory stats after training step: device=cuda:0 current alloc=1.9410GB (delta=0.0000GB max=7.4082GB) current cache=10.7129GB (delta=0.0000GB max=10.7129GB)

The reason for this increase in peak GPU memory utilization is unclear at the moment, but I have an open issue on the DeepSpeed GitHub repository to better understand this behavior: DeepSpeed/issues/3734

Communication Overhead

While ZeRO can reduce the total amount of memory required for training, it does introduce additional communication overhead, as optimizer states, gradients, and model parameters must be exchanged frequently between devices.

DeepSpeed allows us to measure the amount of time taken by communication operations through its Communication Logging settings. Communication logging is enabled by adding the comms_logger section to the deepspeed_config:

deepspeed_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 5e-5
        }
    },
    "zero_optimization": {
        "stage": 0
    },
    "comms_logger": {
        "enabled": True,
        "verbose": False,
        "prof_all": True,
        "debug": False
    }
}

Now, when executing the training script, every communication operation (AllReduce, AllGather, Reduce Scatter, etc.) will be captured and measured. To view the results of these measurements, we call the deepspeed.comm.log_summary() function:

## After training loop
import deepspeed.comm as dist
dist.log_summary()
Comm. Op            Message Size        Count               Total Latency(ms)   Avg Latency(ms)     tput_avg (Gbps)     busbw_avg (Gbps)    
broadcast
                    4.0 KB              221                 20.58               0.09                0.36                0.36                
                    8.0 KB              1                   0.17                0.17                0.38                0.38                
                    16.0 KB             24                  2.55                0.10                1.25                1.25                
                    113.27 KB           1                   0.09                0.09                10.70               10.70               
                    2.0 MB              1                   0.44                0.44                38.01               38.01               
                    4.0 MB              97                  71.00               0.72                46.41               46.41               
                    16.0 MB             48                  155.34              3.23                41.67               41.67               
                    113.27 MB           1                   178.26              178.26              5.33                5.33                
all_reduce
                    1.24 GB             50                  22385.67            445.17              48.11               36.08               
log_summary_barrier
                    0B                  1                   69.95               69.95               0.00                0.00  

Using these logs, I measured the amount of communication overhead introduced by each stage of ZeRO. The results are shown below:

Multi-Node Training

So far, all of the examples we have seen have demonstrated distributed training with multiple GPUs on a single node. DeepSpeed can be applied to multi-node training as well. To perform multi-node training with DeepSpeed, we can use the same training script as before, but some additional setup is required to allow multiple nodes to communicate with each other. This section outlines the steps to perform multi-node training with DeepSpeed across multiple AWS EC2 instances.

Assume we have three EC2 instances that are able to reach each other on all ports. An example of how to create these instances using the AWS CDK is included in the GitHub repository: distributed-training-and-deepspeed/aws/multi_node_training_stack.py.

To simplify this example, I will use the names worker-1, worker-2, and worker-3 instead of the actual IP addresses of the three GPU instances, and I will use localhost to refer to my local machine. If you would like to follow along using the worker names, you can add them as aliases to your ~/.ssh/config file like so, replacing the HostNames with the actual IP addresses of your instances, and the IdentityFile with the path to your AWS keypair:

Host worker-1
  HostName ec2-34-222-121-132.us-west-2.compute.amazonaws.com
  IdentityFile /home/me/keypair.pem
  User ubuntu

Host worker-2
  HostName ec2-32-27-113-7.us-west-2.compute.amazonaws.com
  IdentityFile /home/me/keypair.pem
  User ubuntu

Host worker-3
  HostName ec2-52-88-92-190.us-west-2.compute.amazonaws.com
  IdentityFile /home/me/keypair.pem
  User ubuntu

To run a multi-node training job, each node needs to be able to connect to every other node via passwordless SSH. To enable this, we will create a new SSH keypair and distribute it to each of the nodes:

# on localhost

# create a new SSH keypair
mkdir -p multi-node-training-keys
ssh-keygen -t rsa -N "" -f ./multi-node-training-keys/id_rsa

# copy SSH keys to worker-1
scp ./multi-node-training-keys/id_rsa.pub worker-1:/home/ubuntu/.ssh/id_rsa.pub
scp ./multi-node-training-keys/id_rsa worker-1:/home/ubuntu/.ssh/id_rsa

# add SSH keys to authorized_keys on worker node
ssh worker-1 'cat /home/ubuntu/.ssh/id_rsa.pub >> /home/ubuntu/.ssh/authorized_keys'

# repeat above three commands for worker-2 and worker-3

Next, we need to install DeepSpeed and its dependencies on each worker. The GitHub repository for this page includes a requirements.txt file and a worker-prereqs.sh script that can be used to install all of the prerequisites. On each node, we will clone the GitHub repository and install the prerequisites:

# on localhost

# install requirements on worker-1
ssh worker-1 'rm -rf distributed-training-and-deepspeed && git clone https://github.com/gnovack/distributed-training-and-deepspeed.git && pip install -r distributed-training-and-deepspeed/requirements.txt && distributed-training-and-deepspeed/scripts/worker-prereqs.sh'

# repeat above command for worker-2 and worker-3

Finally, though not required, it is nice to have the SSH hostname aliases configured on each worker. We can do this by writing a /home/ubuntu/.ssh/config file on each worker:

# on localhost; assumes that you have already added the SSH hostname aliases to your ~/.ssh/config file

worker_ip_1=$(ssh -G worker-1 | awk '$1 == "hostname" { print $2 }')
worker_ip_2=$(ssh -G worker-2 | awk '$1 == "hostname" { print $2 }')
worker_ip_3=$(ssh -G worker-3 | awk '$1 == "hostname" { print $2 }')

ssh worker-1 'cat > /home/ubuntu/.ssh/config' << EOF
Host worker-1
    HostName $worker_ip_1
    StrictHostKeyChecking no

Host worker-2
    HostName $worker_ip_2
    StrictHostKeyChecking no

Host worker-3
    HostName $worker_ip_3
    StrictHostKeyChecking no
EOF

# repeat above command for worker-2 and worker-3

By setting StrictHostKeyChecking no we bypass the need to manually accept the SSH host key for each worker. This simplifies this example but is not recommended for production environments.

A full script that executes all of these commands for all three worker nodes is available at distributed-training-and-deepspeed/scripts/generate-keys.sh

Now that each node can connect to each other node via passwordless SSH, the next step is to create a hostfile that lists the IP address of each node and specifies the number of available GPUs on each node. Assuming we have aliases for each worker node in our ~/.ssh/config, we can use these instead of IP addresses in our hostfile:

worker-1 slots=1
worker-2 slots=1
worker-3 slots=1

With the hostfile created, we are ready to run a multi-node training job. We can start the job from any one of the worker nodes. For this example, we will start the job from worker-1 by copying the hostfile from localhost to worker-1, and then running the training script with the --hostfile argument:

# on localhost

# copy the hostfile to worker-1
scp hostfile worker-1:/home/ubuntu/distributed-training-and-deepspeed/hostfile

# get the IP address of worker-1 from the SSH config file
MASTER_ADDR=$(ssh -G worker-1 | awk '$1 == "hostname" { print $2 }')

# ssh into first worker node and launch training
ssh worker-1 'cd distributed-training-and-deepspeed && PATH="/home/ubuntu/.local/bin:$PATH" deepspeed --master_addr='"$MASTER_ADDR"' --hostfile=./hostfile zero_dp_training.py --stage=2 --model_name=facebook/opt-125m'

Wrapping Up

This exploration of distributed training and DeepSpeed has taught me a lot about the nuances and challenges of distributed model training, and about how the Zero Redundancy Optimizer can reduce the memory requirement when training across multiple GPUs. There are several other optimizations available in DeepSpeed that were not covered in this post, which leaves me with a few topics for future examination:

  1. Offloading: DeepSpeed offers the ability to offload the optimizer states and model parameters to CPU memory and disk, which can further reduce memory consumption at the expense of training time.

  2. Mixture of Experts (MoE): MoE models are a class of models that use sparsely activated layers (i.e. layers that select a subset of their weights to be used during each forward pass) to scale up the total number of model parameters without increasing computational complexity. The DeepSpeed library includes sparsely activated PyTorch model layers that can be used to implement MoE models.

  3. Progressive Layer Dropping (PLD): PLD is a technique that speeds up transformer model training by allowing specific transformer layers to be switched on and off at different points in the training process.


References

[1] Sohoni, Nimit S., et al. “Low-memory neural network training: A technical report.” arXiv preprint arXiv:1904.10631 (2019).

[2] https://huggingface.co/docs/transformers/v4.20.1/en/perf_train_gpu_one#anatomy-of-models-memory

[3] Korthikanti, Vijay Anand, et al. “Reducing activation recomputation in large transformer models.” Proceedings of Machine Learning and Systems 5 (2023).

[4] Li, Shen, et al. “Pytorch distributed: Experiences on accelerating data parallel training.” arXiv preprint arXiv:2006.15704 (2020).

[5] Sergeev, Alexander, and Mike Del Balso. “Horovod: fast and easy distributed deep learning in TensorFlow.” arXiv preprint arXiv:1802.05799 (2018).

[6] Huang, Yanping, et al. “Gpipe: Efficient training of giant neural networks using pipeline parallelism.” Advances in neural information processing systems 32 (2019).

[7] https://huggingface.co/docs/transformers/v4.15.0/parallelism#naive-model-parallel-vertical-and-pipeline-parallel

[8] Rajbhandari, Samyam, et al. “Zero: Memory optimizations toward training trillion parameter models.” SC20: International Conference for High Performance Computing, Networking, Storage and Analysis. IEEE, 2020.

[9] Scao, Teven Le, et al. “Bloom: A 176b-parameter open-access multilingual language model.” arXiv preprint arXiv:2211.05100 (2022).

Appendix

Appendix A: Computing Transformer Memory Requirements

Unfortunately, the approach used above to compute the Activation Memory (namely, the injection of forward hooks) does not play nicely with the Transformer models from the HuggingFace Transformers library, as many of the operations in these models are not torch.nn.Module instances, and thus will not be captured by hooks.

To overcome this, I wrote a basic TransformerBlock class, based on the OPTDecoderLayer implementation from HuggingFace, but I have used torch.nn.Module instances for all operations to ensure that every intermediate output will be captured by the activation_counter_hook hook.

class MatMul(nn.Module):
    """
    PyTorch Module wrapper for torch.bmm.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.bmm(x, y)


class TransformerBlock(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.ffn_dim = config.ffn_dim
        self.scaling = self.head_dim**-0.5
        
        self.pre_attention_layer_norm = nn.LayerNorm(self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.enable_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.enable_bias)
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.enable_bias)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.enable_bias)

        self.compute_attention_weights = MatMul()
        self.attention_weights_softmax = nn.Softmax(dim=-1)
        self.attention_weights_dropout = nn.Dropout(config.dropout)
        self.compute_attentions = MatMul()

        self.attention_output_dropout = nn.Dropout(config.dropout)
        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

        self.ffn_1 = nn.Linear(self.hidden_size, self.ffn_dim, bias=config.enable_bias)
        self.ffn_2 = nn.Linear(self.ffn_dim, self.hidden_size, bias=config.enable_bias)
        self.activation = nn.GELU()
        self.final_dropout = nn.Dropout(config.dropout)
  
    def forward(self, hidden_states):
      ...

The full implementation of this class is available at distributed-training-and-deepspeed/model/transformer.py. The script for computing transformer model memory requirements can be found here: distributed-training-and-deepspeed/estimate_transformer_memory.py