Mixture of Experts Pattern for Transformer Models
Traditional Transformer^{[1]} models are densely activated, meaning that all of the model's parameters are used during every forward pass. Any increase in the number of parameters thus comes at the cost of increased computational complexity and memory consumption during training and inference.
The number of parameters, \(N\), in a given transformer layer is determined by the model's hidden size, \(h\). Assuming that the attention head dimension times the number of heads equals the hidden size, and that the feed-forward dimension is 4 times the hidden size, it can be computed as^{[2]}:
\[ N = 12h^2 + 13h \]
The figure below shows how these parameters are distributed across the different submodules of a typical transformer layer:
Computational complexity can be measured by the number of floating-point operations (FLOPs) performed during a forward pass. We can estimate the number of FLOPs for a given transformer block as^{[3]}:
\[ C = 24sh^2 + 4s^2h \]
where \(s\) is the sequence length (the number of tokens in the input), and batch size is assumed to be 1 for simplicity.
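Computational complexity, then, increases quadratically with hidden size and, for a fixed sequence length, approximately linearly with the number of parameters (since \( N \approx 12h^2 \), the \( 24sh^2 \) term is about \( 2sN \)):
\[ C \approx 2sN + 4s^2h \]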
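As a quick sanity check, the parameter and FLOPs estimates can be written as two small Python functions. This is a sketch under the assumptions stated above plus a few more of my own: a feed-forward width of \(4h\), a bias on every linear layer, two layer norms that each have a weight and a bias, batch size 1, and a multiply-add counted as 2 FLOPs. The function names are hypothetical.

```python
def dense_layer_parameters(h: int) -> dict:
    """Per-submodule parameter counts for one dense transformer layer
    (hypothetical helper; assumes d_ff = 4h, biases on every linear
    layer, and two layer norms with weight and bias)."""
    return {
        "attention": 4 * h * h + 4 * h,     # Q, K, V, O projections + biases
        "feed_forward": 8 * h * h + 5 * h,  # h -> 4h and 4h -> h + biases
        "layer_norms": 4 * h,               # 2 layer norms x (weight + bias)
    }


def dense_layer_flops(h: int, s: int) -> int:
    """Estimated forward-pass FLOPs at batch size 1: 24*s*h^2 for the
    linear layers plus 4*s^2*h for the attention scores and the
    attention-weighted sum of the values."""
    return 24 * s * h * h + 4 * s * s * h


for h in (256, 768, 1024):
    parts = dense_layer_parameters(h)
    total = sum(parts.values())  # equals 12h^2 + 13h
    print(h, parts, total, dense_layer_flops(h, s=1024))
```

For \(h = 256\) the total is 789,760 parameters, and with \(s = 1024\) the FLOPs estimate is about 2.68B; both agree with the estimated columns of the appendix table.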
It has been observed that increasing the number of parameters of transformer-based language models improves their capacity to learn^{[4] [5]}. These improvements, however, come at the cost of increased computational complexity, meaning slower and more expensive pre-training, fine-tuning, and inference.
This relationship poses a question that remains an important area of research in machine learning: How can we realize the benefits of increased model size (namely, the ability to perform more complex tasks), without incurring the concomitant increase in complexity and cost?
In this page, I look at the Mixture of Experts (MoE) pattern, which has been applied to transformer models in an effort to address this question.
Mixture of Experts
The idea of a Mixture of Experts (MoE) in machine learning predates the transformer architecture by decades. Jacobs et al. applied an MoE approach to the problem of vowel recognition in 1991^{[5]}.
Although the implementations of the Mixture of Experts pattern are varied, the basic idea is consistent: instead of having a single set of parameters that are used to compute the output for every input, we route each input through a specific subset of the model parameters. Each possible subset of parameters, then, can be thought of as an expert that is responsible for handling a specific type of input.
The picture below provides a crude illustration of this idea:
An illustration of a basic Mixture of Experts pattern, where English-to-Spanish translation is performed by expert 1, and English-to-Japanese translation is performed by expert 2.
Routing
One important detail of any Mixture-of-Experts implementation is the Routing Algorithm, which decides which expert (or set of experts) to use for a given input.
Soft Selection
Some early work applied a soft selection routing algorithm (also referred to as a continuous mixture of experts) to deep learning models^{[6]}. With this approach, every expert is used for every input, but the contribution of each expert is weighted by a gating function, a learned function that computes a weight, or importance, for each expert such that the weights of all experts sum to 1.
Because every expert is used for every input, this approach still results in a densely activated model, and thus does not address the problem of increased computational complexity. This work, however, provided an important stepping stone for more advanced implementations of MoE.
The code below demonstrates a Continuous Mixture of Experts Layer like that described in [6]. The experts in this case are individual linear layers, each with its own set of parameters, and the gating function is itself a two-layer feed-forward network. For each input, the gating function computes a weight for each expert, and the output of the Continuous Mixture of Experts Layer is the weighted sum of the individual expert outputs:
import torch


class ContinuousMixtureOfExpertsLayer(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int,
                 num_experts: int, gating_network_hidden_units: int):
        super().__init__()
        # each expert is an independent linear layer
        self.experts = torch.nn.ModuleList([
            torch.nn.Linear(in_features, out_features)
            for _ in range(num_experts)
        ])
        # two-layer gating network: in_features -> hidden -> num_experts
        self.gating_function_input = torch.nn.Linear(
            in_features,
            gating_network_hidden_units
        )
        self.gating_function_output = torch.nn.Linear(
            gating_network_hidden_units,
            num_experts
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # compute the output from each expert
        expert_outputs = []
        for expert in self.experts:
            expert_output = expert(x)
            expert_output = torch.nn.functional.relu(expert_output)
            expert_outputs.append(expert_output)
        # shape: [batch_size, num_experts, out_features]
        expert_outputs = torch.stack(expert_outputs, dim=1)

        # execute the gating function to compute weights for each expert
        gating_function_hidden_state = self.gating_function_input(x)
        gating_function_hidden_state = torch.nn.functional.relu(gating_function_hidden_state)
        gating_function_output = self.gating_function_output(gating_function_hidden_state)
        # shape: [batch_size, num_experts]; each row sums to 1
        expert_weights = torch.nn.functional.softmax(gating_function_output, dim=1)

        # compute the output of the layer as the weighted sum of the expert outputs;
        # unsqueeze(-1) gives [batch_size, num_experts, 1] so the weights broadcast
        # over out_features
        return torch.sum(
            expert_outputs * expert_weights.unsqueeze(-1),
            dim=1
        )
Ideally, during training each expert becomes specialized at handling a specific type of input, and the gating function learns to route each input to the expert that is best suited to handle it.
To get a better idea of how different experts specialize on certain types of input, I trained a Continuous Mixture-of-Experts model on the MNIST dataset. For each of the four experts in the trained model's hidden layer, the image below shows the test examples that were given the 9 highest expert weights:
Highest weighted examples for each expert in a continuous mixture-of-experts model.
Hard Selection
Later work expanded on the MoE pattern by applying hard selection routing algorithms, where only a subset of the experts are used for any given input, marking a shift from densely activated to sparse models. Shazeer, Noam, et al. introduced the Sparsely-Gated Mixture-of-Experts Layer and applied hard selection to natural language processing on a large scale, using a gating function similar to that used by the continuous mixture of experts layer described above.
The gating function consists of a single linear layer that computes a weight for each expert, with tunable random noise added to the weights to help with load balancing (the load balancing problem is discussed further below). Before the \(Softmax\) function is applied to the expert weights, a \(KeepTopK\) operation is applied, which sets the weights of all but the top \(k\) experts to \( -\infty \). This ensures that only the top \(k\) experts will have a weight greater than 0 after \(Softmax\) is applied.
An illustration of expert weight computation for a mixture of 4 experts. Expert weights are computed separately for each row in the input. In a language modeling context, each row would represent a token in the input sequence. Thus, each token may be routed to different experts.
The code below includes a basic implementation of the gating function for the Sparsely-Gated Mixture-of-Experts Layer:
import torch


class SparselyGatedMixtureOfExpertsLayer(torch.nn.Module):
    def __init__(self, in_features: int, num_experts: int):
        super().__init__()
        # both start at zero, so the initial raw weights come from the noise term alone
        self.w_g = torch.nn.Parameter(torch.zeros(in_features, num_experts))
        self.w_noise = torch.nn.Parameter(torch.zeros(in_features, num_experts))

    def tunable_noise(self, x):
        """Compute tunable noise to add to the gating function output."""
        weighted_noise = torch.matmul(x, self.w_noise)
        # standard normal noise, scaled per expert by softplus(x @ w_noise)
        return torch.randn_like(weighted_noise) * torch.nn.functional.softplus(weighted_noise)

    def compute_expert_weights(self, x, k):
        noise = self.tunable_noise(x)
        expert_weights = torch.matmul(x, self.w_g) + noise
        print("Raw Expert Weights: \n", expert_weights)
        # get indices of the (num_experts - k) lowest weights, i.e. every expert outside the top k
        num_experts = expert_weights.size(1)
        _, bottom_k_indices = torch.topk(expert_weights, k=num_experts - k, dim=1, largest=False)
        print("Bottom K Expert Indices: \n", bottom_k_indices)
        # KeepTopK: set the weights outside the top k to -inf
        expert_weights = expert_weights.scatter(dim=1, index=bottom_k_indices, value=float("-inf"))
        # apply softmax to the expert weights; the -inf entries become exactly 0
        expert_weights = torch.nn.functional.softmax(expert_weights, dim=1)
        print("Final Expert Weights: \n", expert_weights)
        return expert_weights
To see this in action, we can execute the gating function using the same input shape and number of experts as in the image above:
layer = SparselyGatedMixtureOfExpertsLayer(in_features=6, num_experts=4)
x = torch.randn(3, 6)
layer.compute_expert_weights(x, k=2)
Raw Expert Weights:
tensor([[ 0.3931, 0.8921, -0.9925, -1.1449],
[ 0.3835, 0.3427, 0.0513, -0.2176],
[-0.3423, 0.4838, 0.0443, 1.7873]], grad_fn=<AddBackward0>)
Bottom K Expert Indices:
tensor([[3, 2],
[3, 2],
[0, 2]])
Final Expert Weights:
tensor([[0.3778, 0.6222, 0.0000, 0.0000],
[0.5102, 0.4898, 0.0000, 0.0000],
[0.0000, 0.2136, 0.0000, 0.7864]], grad_fn=<SoftmaxBackward0>)
The Sparsely-Gated Mixture-of-Experts Layer was originally applied to language modeling with LSTM models [8], but this same idea of hard selection routing based on the top \(k\) expert weights was later used successfully in transformer models with the introduction of Switch Transformers (Fedus, W., Zoph, B. and Shazeer, N., 2022).
Switch Transformer
The Switch Transformer is a modification of the traditional densely activated transformer model. Specifically, Switch Transformers replace the fully-connected feed-forward network in the transformer block with a sparsely activated mixture-of-experts layer like the one discussed above. In this case, however, each expert is itself a two-layer feed-forward network.
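To make the routing concrete, here is a minimal PyTorch sketch of such a layer. It is not the HuggingFace or original Switch Transformer implementation: the class name SwitchFeedForward is hypothetical, the experts use ReLU and no biases (my assumption), each token's expert output is scaled by its router probability, and expert capacity (discussed below) is ignored.

```python
import torch


class SwitchFeedForward(torch.nn.Module):
    """Sketch of a sparse feed-forward layer with top-1 routing
    (hypothetical class; no expert capacity handling)."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        # router: one logit per expert for every token
        self.router = torch.nn.Linear(hidden_size, num_experts, bias=False)
        # each expert is a two-layer feed-forward network without biases
        self.experts = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Linear(hidden_size, 4 * hidden_size, bias=False),
                torch.nn.ReLU(),
                torch.nn.Linear(4 * hidden_size, hidden_size, bias=False),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [num_tokens, hidden_size]
        router_probs = torch.softmax(self.router(x), dim=-1)
        gate, expert_index = router_probs.max(dim=-1)  # top-1 expert per token
        output = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            selected = expert_index == i
            if selected.any():
                # only the tokens routed to expert i pass through it
                output[selected] = gate[selected].unsqueeze(-1) * expert(x[selected])
        return output
```

For example, SwitchFeedForward(hidden_size=8, num_experts=4) holds 4 × 8h² + 4h = 2,080 parameters, yet each token is processed by only one of the four experts.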
The figure below illustrates the difference between a typical “vanilla” transformer block, and a Switch Transformer block:
For a Switch Transformer block with hidden size \( h \) and number of experts \( E \), the total number of parameters is given by:
\[ N = 4h^2 + 8Eh^2 + Eh + 2h = (4 + 8E)h^2 + (E + 2)h \]
Thus, when there is only 1 expert, we have
\[ N = 12h^2 + 3h \]
which is virtually the same as the number of parameters in a densely activated transformer (\( 12h^2 + 13h \)). The slight difference is because the Switch Transformer does not include a bias term in any of its linear layers or layer norms, and because the weights of the router introduce an additional \( Eh \) parameters.
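The closed form can be checked with a few lines of plain Python before turning to the library code. The helper below is a hypothetical sketch that assumes an encoder block with self-attention only, \(d_{ff} = 4h\), no biases, and two weight-only layer norms:

```python
def switch_block_parameters(h: int, num_experts: int) -> int:
    """Parameters in one encoder Switch Transformer block (hypothetical helper).

    Assumes d_ff = 4h, no bias terms anywhere, and two weight-only layer norms.
    """
    attention = 4 * h * h              # Q, K, V and output projections
    experts = num_experts * 8 * h * h  # each expert: h -> 4h and 4h -> h
    router = num_experts * h           # one logit per expert
    layer_norms = 2 * h                # two layer norms, weight only
    return attention + experts + router + layer_norms


for h in (768, 1024, 2048):
    print(h, [switch_block_parameters(h, e) for e in (1, 2, 4)])
```

For example, switch_block_parameters(768, 1) gives 7,080,192 and switch_block_parameters(2048, 4) gives 151,007,232, in agreement with the 7.08M and 151.01M reported in the table below.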
To illustrate this, we can use the PyTorch implementation of the Switch Transformer included in the open-source HuggingFace transformers library [10]. The code below shows how to instantiate a Switch Transformer layer using the SwitchTransformersBlock class, then counts the number of parameters for different values of \( h \) and \( E \):
import itertools

from tabulate import tabulate
from transformers.models.switch_transformers.modeling_switch_transformers import (
    SwitchTransformersBlock,
    SwitchTransformersConfig,
)


def human_readable(n: int) -> str:
    # Stand-in for the human_readable helper from the author's utils module
    # (not shown); formats counts such as 7080192 as "7.08M".
    for suffix, scale in (("B", 1e9), ("M", 1e6), ("K", 1e3)):
        if n >= scale:
            return f"{n / scale:.2f}{suffix}"
    return str(n)


if __name__ == "__main__":
    hidden_sizes = [768, 1024, 2048]
    number_of_experts = [1, 2, 4]
    table_data = []
    for hidden_size, num_experts in itertools.product(hidden_sizes, number_of_experts):
        config = SwitchTransformersConfig(
            d_model=hidden_size,
            num_heads=hidden_size // 64,
            num_layers=1,
            d_ff=4 * hidden_size,
            num_experts=num_experts
        )
        # a single block whose feed-forward sublayer is a sparse MoE layer
        layer = SwitchTransformersBlock(config, is_sparse=True)
        parameter_count = sum(p.numel() for p in layer.parameters())
        table_data.append((hidden_size, num_experts, human_readable(parameter_count)))
    print(tabulate(
        table_data,
        headers=["Hidden Size", "Number of Experts", "Parameter Count"],
        tablefmt="fancy_grid"
    ))
╒═══════════════╤═════════════════════╤═══════════════════╕
│ Hidden Size │ Number of Experts │ Parameter Count │
╞═══════════════╪═════════════════════╪═══════════════════╡
│ 768 │ 1 │ 7.08M │
├───────────────┼─────────────────────┼───────────────────┤
│ 768 │ 2 │ 11.80M │
├───────────────┼─────────────────────┼───────────────────┤
│ 768 │ 4 │ 21.24M │
├───────────────┼─────────────────────┼───────────────────┤
│ 1024 │ 1 │ 12.59M │
├───────────────┼─────────────────────┼───────────────────┤
│ 1024 │ 2 │ 20.98M │
├───────────────┼─────────────────────┼───────────────────┤
│ 1024 │ 4 │ 37.75M │
├───────────────┼─────────────────────┼───────────────────┤
│ 2048 │ 1 │ 50.34M │
├───────────────┼─────────────────────┼───────────────────┤
│ 2048 │ 2 │ 83.89M │
├───────────────┼─────────────────────┼───────────────────┤
│ 2048 │ 4 │ 151.01M │
╘═══════════════╧═════════════════════╧═══════════════════╛
The computational complexity of a Switch Transformer block, as measured by the number of forward pass FLOPs, can be estimated as:
\[ C = 24sh^2 + 4s^2h + 2Esh \]
Notice that the only difference between vanilla and Switch Transformer computational complexity is the additional \( 2Esh \) FLOPs introduced by the router: with top-1 routing, each token passes through only one expert, so the expert feed-forward cost matches that of a single dense feed-forward network. This means that by increasing the number of experts, \(E\), we can significantly increase the number of model parameters while incurring a negligible increase in computational complexity.
The chart below plots model parameters against computational complexity for different values of \(E\).
Given the same number of parameters, a Switch Transformer with more experts is more computationally efficient than a densely activated transformer. For example, a Switch Transformer with 30M parameters and 8 experts performs ~13.62B FLOPs during a forward pass, while a densely activated transformer of the same size performs ~68.06B FLOPs.
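These figures can be roughly reproduced by solving each parameter-count formula for \(h\) and plugging the result into the FLOPs estimates. The sketch below assumes a sequence length of \(s = 1024\) (the value that reproduces the estimated FLOPs in the appendix table) and treats \(h\) as a continuous quantity, so it lands close to, but not exactly on, the quoted numbers (about 68.0B and 13.6B FLOPs). The helper names are hypothetical.

```python
import math


def dense_block_flops(h: float, s: int) -> float:
    """24*s*h^2 + 4*s^2*h forward-pass FLOPs (batch size 1)."""
    return 24 * s * h * h + 4 * s * s * h


def switch_block_flops(h: float, s: int, num_experts: int) -> float:
    """Dense estimate plus 2*E*s*h FLOPs for the router (top-1 routing)."""
    return dense_block_flops(h, s) + 2 * num_experts * s * h


def solve_hidden_size(params: float, a: float, b: float) -> float:
    """Positive root h of a*h^2 + b*h = params (h treated as continuous)."""
    return (-b + math.sqrt(b * b + 4 * a * params)) / (2 * a)


s, params, num_experts = 1024, 30e6, 8
# dense: N = 12h^2 + 13h; Switch: N = (4 + 8E)h^2 + (E + 2)h
h_dense = solve_hidden_size(params, 12, 13)
h_switch = solve_hidden_size(params, 4 + 8 * num_experts, num_experts + 2)
dense_flops = dense_block_flops(h_dense, s)
switch_flops = switch_block_flops(h_switch, s, num_experts)
print(f"dense:  h = {h_dense:.0f}, {dense_flops / 1e9:.2f}B FLOPs")
print(f"switch: h = {h_switch:.0f}, {switch_flops / 1e9:.2f}B FLOPs")
```

The Switch model reaches the same parameter budget with a much smaller hidden size (about 664 vs. about 1581), which is where the FLOPs savings come from.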
Load Balancing
One of the challenges of training mixture-of-experts models is the need to balance responsibility between experts. If all of the inputs are routed to only one or two of the available experts, then the remaining experts are essentially useless, as their parameters are never used when predicting the output.
Without some additional safeguards, this expert imbalance can occur naturally during training. Eigen, Ranzato, and Sutskever describe this well when they write:
The experts at each layer that perform best for the first few examples end up overpowering the remaining experts. This happens because the first examples increase the gating weights of these experts, which in turn causes them to be selected with high gating weights more frequently. This causes them to train more, and their gating weights to increase again, ad infinitum.
When training the Continuous Mixture of Experts model for MNIST digit classification, I encountered this same phenomenon. The figure below shows the average expert weights observed during training:
To combat this imbalance, various load balancing techniques have been developed with the aim of ensuring that all experts are given roughly equal responsibility during training. Eigen, Ranzato, and Sutskever address this issue by “dropping out” any experts that begin to overpower the others.
A running total of the gating weights for each expert is maintained and updated on each training step. If the sum of gating weights for any one expert increases too far above the average sum of gating weights for all experts, then that expert is given a gating weight of 0 for the training step.
The code below shows a modified implementation of the forward method of ContinuousMixtureOfExpertsLayer, which includes this load balancing technique:
# Assumes __init__ additionally defines the attributes used below, e.g.:
#     self.balance = True   # whether to apply the dropout-style balancing
#     self.threshold = ...  # how far above the mean an expert's total may drift
#     self.register_buffer("running_total_assignment", torch.zeros(num_experts))
def forward(self, x: torch.Tensor) -> torch.Tensor:
    # compute the output from each expert
    expert_outputs = []
    for expert in self.experts:
        expert_output = expert(x)
        expert_output = torch.nn.functional.relu(expert_output)
        expert_outputs.append(expert_output)
    expert_outputs = torch.stack(expert_outputs, dim=1)

    # execute the gating function to compute weights for each expert
    gating_function_hidden_state = self.gating_function_input(x)
    gating_function_hidden_state = torch.nn.functional.relu(gating_function_hidden_state)
    gating_function_output = self.gating_function_output(gating_function_hidden_state)
    expert_weights = torch.nn.functional.softmax(gating_function_output, dim=1)

    # update the running total of gating weights per expert (no gradient needed)
    step_total_assignment = torch.sum(expert_weights, dim=0).detach()
    self.running_total_assignment += step_total_assignment

    if self.balance:
        mean_total_assignment = torch.mean(self.running_total_assignment)
        # 1 for experts within the threshold, 0 for experts that have pulled too far ahead
        expert_mask = (
            self.running_total_assignment - mean_total_assignment <= self.threshold
        ).to(expert_weights.dtype)
        # zero out the overpowering experts and renormalize so each row sums to 1
        # (a second softmax would give masked experts a weight of exp(0) > 0)
        expert_weights = expert_weights * expert_mask
        expert_weights = expert_weights / expert_weights.sum(dim=1, keepdim=True)

    # compute the output of the layer as the weighted sum of the expert outputs
    return torch.sum(
        expert_outputs * expert_weights.unsqueeze(-1),
        dim=1
    )
After adding the load balancing constraint, the gating weights were more evenly distributed throughout the training process:
Load Balancing Loss
Others have addressed the load balancing problem by introducing additional loss terms to the training objective. The Switch Transformer models were trained with an auxiliary load balancing loss, which is computed as the dot product of two vectors, \(f\) and \(P\), both of length \(E\), where \(f\) is the proportion of tokens routed to each expert, and \(P\) is the proportion of gating weights corresponding to each expert.
For example, consider a batch of three input tokens with the following gating weights corresponding to 4 experts:
# gating weights of shape [sequence_length, num_experts]
gating_weights = torch.tensor([
[0.25, 0.50, 0.00, 0.25],
[0.70, 0.10, 0.10, 0.10],
[0.30, 0.40, 0.20, 0.10],
])
Using the top-1 routing algorithm of the Switch Transformer, we know that the first token will be routed to the expert at index 1, the second token will be routed to the expert at index 0, and the third token will be routed to the expert at index 1. We can represent these routing decisions as a one-hot tensor:
# one-hot routing decisions of shape [sequence_length, num_experts]
# if index i,j is 1, then token i was routed to expert j
routing_decisions = torch.tensor([
[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
])
We can then compute the vectors \(f\) and \(P\) as follows:
# The value of f is [0.33, 0.67, 0.00, 0.00],
# because 1/3 of the tokens were routed to expert 0,
# 2/3 of the tokens were routed to expert 1,
# and 0 tokens were routed to experts 2 and 3.
f = routing_decisions.mean(axis=0, dtype=torch.float32)
# The value of P is [0.42, 0.33, 0.10, 0.15]. The sum of the
# gating weights for all tokens and experts is 3.0, so to get the
# proportion of gating weights for expert 0, we compute
# (0.25 + 0.70 + 0.30) / 3.0 = 0.42, and so on for each other expert.
P = gating_weights.mean(axis=0, dtype=torch.float32)
To calculate the auxiliary load balancing loss, \(L_{aux}\), we compute the dot product of \(f\) and \(P\), then multiply the result by \(E\) so that the loss does not depend on the absolute number of experts:
# auxiliary load balancing loss is 1.44
num_experts = gating_weights.shape[1]
auxiliary_loss = torch.dot(f, P) * num_experts
The full loss function used to train the model combines the auxiliary load balancing loss, \(L_{aux}\), with the standard cross-entropy loss, \(L_{CE}\). An additional hyperparameter, \(\alpha\), is used to control the relative importance of the auxiliary loss:
\[ L = L_{CE} + \alpha L_{aux} \]
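Putting the worked example and the combined objective together, a compact version of the loss might look like the sketch below. The function names are hypothetical, the routing decisions are taken as the argmax of the router probabilities (top-1), and the default alpha = 0.01 is only a placeholder value.

```python
import torch


def switch_load_balancing_loss(router_probs: torch.Tensor) -> torch.Tensor:
    """L_aux = E * dot(f, P) for router probabilities of shape
    [num_tokens, num_experts], with top-1 routing decisions."""
    num_experts = router_probs.shape[-1]
    # f: fraction of tokens routed to each expert (argmax is not differentiable)
    top1 = router_probs.argmax(dim=-1)
    f = torch.nn.functional.one_hot(top1, num_experts).to(router_probs.dtype).mean(dim=0)
    # P: mean router probability given to each expert (carries the gradient)
    P = router_probs.mean(dim=0)
    return num_experts * torch.dot(f, P)


def total_loss(logits: torch.Tensor, targets: torch.Tensor,
               router_probs: torch.Tensor, alpha: float = 0.01) -> torch.Tensor:
    """L = L_CE + alpha * L_aux."""
    ce = torch.nn.functional.cross_entropy(logits, targets)
    return ce + alpha * switch_load_balancing_loss(router_probs)


gating_weights = torch.tensor([
    [0.25, 0.50, 0.00, 0.25],
    [0.70, 0.10, 0.10, 0.10],
    [0.30, 0.40, 0.20, 0.10],
])
print(switch_load_balancing_loss(gating_weights))  # 13/9, about 1.44
```

On the three-token example this returns 13/9 ≈ 1.44, matching the step-by-step calculation. When the router probabilities are uniform, \(P_i = 1/E\) for every expert and the loss is exactly 1 regardless of \(f\).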
Expert Capacity
Some MoE implementations^{[9] [11]} use a fixed Expert Capacity to limit the maximum number of tokens that can be routed to any one expert. Any tokens routed to an expert that is already at capacity are considered overflow tokens and are not processed by any expert. The embedding representation of an overflow token is passed on to the next layer unchanged.
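A simple way to see which tokens overflow is to fill each expert up to its capacity in token order. This sketch assumes the capacity is computed Switch-style as tokens per expert times a capacity factor (rounded up), and that earlier tokens claim slots first; route_with_capacity is a hypothetical helper.

```python
import math
from typing import List, Optional


def route_with_capacity(expert_choice: List[int], num_experts: int,
                        capacity_factor: float = 1.0) -> List[Optional[int]]:
    """Return each token's expert, or None if that expert was already full.

    expert_choice[t] is the expert picked for token t by a top-1 router.
    """
    capacity = math.ceil(capacity_factor * len(expert_choice) / num_experts)
    load = [0] * num_experts
    assignment: List[Optional[int]] = []
    for expert in expert_choice:  # earlier tokens claim slots first
        if load[expert] < capacity:
            load[expert] += 1
            assignment.append(expert)
        else:
            # overflow token: skips the experts and is passed on unchanged
            assignment.append(None)
    return assignment


print(route_with_capacity([1, 0, 1, 1], num_experts=2))
print(route_with_capacity([1, 0, 1, 1], num_experts=2, capacity_factor=1.5))
```

With two experts and four tokens the capacity is 2, so the third token routed to expert 1 overflows; raising the capacity factor to 1.5 gives each expert 3 slots and nothing overflows, at the cost of reserving expert slots that may go unused.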
Balancing by Design
By redesigning the routing algorithm, other implementations of MoE have been able to achieve balance between experts without introducing auxiliary loss terms or setting a limit on expert capacity. Lewis, Mike, et al. introduced the Balanced Assignment of Experts (BASE) Layer, which routes to each expert an equal number of input tokens, rather than simply choosing the expert with the highest gating weight for each individual token.
The BASE Layer performs routing using an auction algorithm^{[13]}, in which each expert bids for its preferred tokens (an expert's preferred token being the one with the highest gating weight for that expert). At the end of the auction, each expert is left with an equal number of tokens. For example, if there are 16 experts and 1024 tokens in the input sequence, each expert will be assigned \(\frac{1024}{16} = 64\) tokens.
To understand how the auction algorithm works, consider a scenario with 3 experts, 3 input tokens, and a 3x3 matrix of gating weights. In the simplest case, each expert will bid on a different token and win that token:
An easy assignment problem where each expert bids on a different token. In the gating weights matrix shown above, rows represent tokens and columns represent experts (i.e. the first row contains the gating weights for the first token corresponding to each expert, the second row contains the gating weights for the second token corresponding to each expert, etc.)
Given a large enough number of experts and tokens, however, it is almost certain that multiple experts will bid on the same token, in which case the auction algorithm must decide which expert wins that token:
Much like in a real auction, conflicts like this are resolved by giving the token to the highest bidder. The amount bid by each expert is determined by the difference between the highest and second-highest gating weights for that expert. In the illustration above, Expert 0 will bid \(1.24 - 0.71 = 0.53\) for Token 0 and Expert 1 will bid \(0.98 - 0.48 = 0.50\), so Token 0 will be assigned to Expert 0.
If there are any tokens that have not been assigned after the first round of bidding (as is the case with Token 1 in the example), then subsequent rounds of bidding are performed until all tokens have been assigned.
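The bidding rounds described above can be sketched with a small Bertsekas-style auction in plain Python. To enforce the equal split, each expert is expanded into capacity identical bidding "slots", which turns the balanced problem into a one-to-one assignment between slots and tokens; each bid raises a token's price by the bidder's margin between its best and second-best token plus a small eps. This is an illustrative CPU sketch for small inputs under these assumptions (function names are hypothetical), not the BASE Layer's actual implementation.

```python
from typing import List


def auction(scores: List[List[float]], eps: float = 1e-3) -> List[int]:
    """Bertsekas-style auction on a square score matrix: bidder i ends up
    with item result[i], within n * eps of the best one-to-one assignment."""
    n = len(scores)
    prices = [0.0] * n
    owner = [-1] * n   # owner[item] -> bidder currently holding that item
    result = [-1] * n  # result[bidder] -> item currently held by that bidder
    unassigned = list(range(n))
    while unassigned:
        bidder = unassigned.pop()
        # value of each item to this bidder after subtracting its current price
        net = [scores[bidder][j] - prices[j] for j in range(n)]
        best = max(range(n), key=lambda j: net[j])
        second = max((net[j] for j in range(n) if j != best), default=net[best])
        # bid the margin between the best and second-best item, plus eps
        prices[best] += net[best] - second + eps
        if owner[best] != -1:  # the previous owner is outbid and must bid again
            result[owner[best]] = -1
            unassigned.append(owner[best])
        owner[best] = bidder
        result[bidder] = best
    return result


def balanced_assignment(gating_weights: List[List[float]],
                        eps: float = 1e-3) -> List[int]:
    """Assign tokens to experts so every expert gets the same number.
    gating_weights[t][e] is token t's gating weight for expert e."""
    num_tokens, num_experts = len(gating_weights), len(gating_weights[0])
    assert num_tokens % num_experts == 0, "tokens must divide evenly among experts"
    capacity = num_tokens // num_experts
    # slot k belongs to expert k // capacity and values token t at that expert's weight
    slot_scores = [[gating_weights[t][slot // capacity] for t in range(num_tokens)]
                   for slot in range(num_tokens)]
    slot_to_token = auction(slot_scores, eps)
    token_to_expert = [-1] * num_tokens
    for slot, token in enumerate(slot_to_token):
        token_to_expert[token] = slot // capacity
    return token_to_expert


# 6 tokens, 3 experts: each expert receives exactly 2 tokens
weights = [[0.9, 0.1, 0.0], [0.8, 0.2, 0.0], [0.7, 0.3, 0.0],
           [0.1, 0.1, 0.8], [0.2, 0.7, 0.1], [0.3, 0.3, 0.4]]
print(balanced_assignment(weights))
```

In this toy example Expert 0 prefers Tokens 0, 1, and 2, but it can keep only two of them, so Token 2 goes to the expert with the next-best use for it. Because every bid raises a price by at least eps, the auction always terminates.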
Hash Layers
Roller, Stephen, Sainbayar Sukhbaatar, and Jason Weston explore the use of Hash Layers to achieve balanced routing without the need for any trainable gating weights or assignment algorithms. For each input token, a hash function is used to map that token’s ID (i.e. the index of the token in the model’s vocabulary) to one of the available experts.
For example, one simple hashing function might be \(h(x) = x \bmod E\), where \(x\) is the token ID and \(E\) is the number of experts. This function maps each token ID to one of the experts, and the mapping is deterministic (i.e. the same token ID will always be mapped to the same expert).
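import torch
from transformers import AutoTokenizer

# Assumption: the token IDs shown below are consistent with the GPT-2 tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")

num_experts = 8
input_sequence = "This is a hash layer example."
# Tokenize input sequence
input_ids = tokenizer.encode(input_sequence, return_tensors="pt")
print("Token ids:", input_ids)
# Map each token to an expert using `x % num_experts`
expert_assignment = torch.fmod(input_ids, num_experts)
print("Expert assignment:", expert_assignment)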
Token ids: tensor([[ 1212, 318, 257, 12234, 7679, 1672, 13]])
Expert assignment: tensor([[4, 6, 1, 2, 7, 0, 5]])
Other, more complex hash functions have been explored as well: for example, applying k-means clustering to the token embeddings produced by a separate pre-trained transformer model, or precomputing a hash table that maps token IDs to experts based on token frequency in the training data, which yields a more balanced assignment of tokens to experts.
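One simple way to precompute such a frequency-based table is a greedy assignment: visit token IDs from most to least frequent and give each one to the expert with the smallest total frequency so far. This is an illustrative sketch (the function name is hypothetical), not necessarily the procedure used by Roller et al.; token IDs that never appear in the training data would need a fallback such as the modulo hash above.

```python
import heapq
from collections import Counter
from typing import Dict, Iterable


def frequency_balanced_hash(token_ids: Iterable[int], num_experts: int) -> Dict[int, int]:
    """Map each token ID to an expert so that the total token frequency per
    expert is roughly equal (most frequent token -> least-loaded expert)."""
    counts = Counter(token_ids)
    heap = [(0, expert) for expert in range(num_experts)]  # (total frequency, expert)
    heapq.heapify(heap)
    table = {}
    for token_id, count in counts.most_common():
        load, expert = heapq.heappop(heap)  # currently least-loaded expert
        table[token_id] = expert
        heapq.heappush(heap, (load + count, expert))
    return table


# toy "training data": token 1 is very frequent, token 5 is rare
training_ids = [1] * 10 + [2] * 6 + [3] * 5 + [4] * 4 + [5]
print(frequency_balanced_hash(training_ids, num_experts=2))
```

Here token 1 alone accounts for 10 of the 26 occurrences, so the greedy pass pairs it with token 4 on one expert (14 occurrences) and sends tokens 2, 3, and 5 to the other (12 occurrences); a plain modulo hash would instead split the tokens by ID parity, without regard to frequency.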
Wrapping Up
This look at the Mixture-of-Experts pattern and its applications to language modeling with transformers has helped me understand many of the details and considerations of implementing an MoE model. It is worth noting a few important topics that were not covered here (among many others of which I am likely unaware):

Performance of MoE models vs. their dense counterparts: We saw how it is possible to increase the number of parameters in an MoE model without incurring significant additional computational cost; however, it does not necessarily follow that an MoE model will perform as well as a dense model with the same parameter count. Clark, Aidan, et al. explore this topic in more detail, and find that the introduction of MoE routing generally improves performance on the language modeling task. How well MoE models adapt to downstream tasks via fine-tuning, however, is less clear.

Sparsity Beyond MoE: The Mixture-of-Experts pattern is not the only way to introduce sparsity into transformer models. Many sparse attention mechanisms, sparse feed-forward networks, and model pruning techniques have also been developed. I leave the exploration of these topics for another day.
References
[10] https://huggingface.co/docs/transformers/v4.32.0/en/model_doc/switch_transformers
Appendix
Parameter Count and FLOPs
For a densely activated transformer layer, the following formulas were given above for parameter count, \(N\), and computational complexity, \(C\):
\[ N = 12h^2 + 13h \]
\[ C = 24sh^2 + 4s^2h \]
To validate these formulas, I wrote a script to compute the parameter count of an instance of OPTDecoderLayer from the transformers library, then used the deepspeed library's FLOPs Profiler to measure the number of floating-point operations in a single forward pass.
The full script is available here: parameters_vs_flops.py, and the output is shown below for several different values of the hidden size, \(h\):
╒═══════════════╤══════════════════════════╤═══════════════════╤═════════════════════════════╤══════════════════════╕
│ Hidden Size │ Parameter Count (est.) │ Parameter Count │ Forward Pass FLOPs (est.) │ Forward Pass FLOPs │
╞═══════════════╪══════════════════════════╪═══════════════════╪═════════════════════════════╪══════════════════════╡
│ 256 │ 789.76K │ 789.76K │ 2.68B │ 2.69B │
├───────────────┼──────────────────────────┼───────────────────┼─────────────────────────────┼──────────────────────┤
│ 512 │ 3.15M │ 3.15M │ 8.59B │ 8.61B │
├───────────────┼──────────────────────────┼───────────────────┼─────────────────────────────┼──────────────────────┤
│ 768 │ 7.09M │ 7.09M │ 17.72B │ 17.74B │
├───────────────┼──────────────────────────┼───────────────────┼─────────────────────────────┼──────────────────────┤
│ 1024 │ 12.60M │ 12.60M │ 30.06B │ 30.10B │
├───────────────┼──────────────────────────┼───────────────────┼─────────────────────────────┼──────────────────────┤
│ 2048 │ 50.36M │ 50.36M │ 111.67B │ 111.73B │
├───────────────┼──────────────────────────┼───────────────────┼─────────────────────────────┼──────────────────────┤
│ 4096 │ 201.38M │ 201.38M │ 429.50B │ 429.62B │
├───────────────┼──────────────────────────┼───────────────────┼─────────────────────────────┼──────────────────────┤
│ 8192 │ 805.41M │ 805.41M │ 1.68T │ 1.68T │
├───────────────┼──────────────────────────┼───────────────────┼─────────────────────────────┼──────────────────────┤
│ 16384 │ 3.22B │ 3.22B │ 6.67T │ 6.67T │
╘═══════════════╧══════════════════════════╧═══════════════════╧═════════════════════════════╧══════════════════════╛
The estimated parameter count and FLOPs (i.e. those computed using the formulas) closely match the actual parameter count and FLOPs measured using the transformers and deepspeed libraries.