Understanding State Space Models

Introduction

This page is about State Space Models (SSMs) and their applications in the world of deep learning. Sequence-to-sequence models based on SSMs like S4[6], H3[7], Mamba[8], and others offer a promising alternative to the widely used Transformer[1] architecture.

The first thing that can be confusing when diving into the literature around SSMs is that, unlike the Transformer, which refers to a (mostly) specific neural network architecture, the concept of a State Space Model is old and broadly used in many domains outside of machine learning.

State Space Models are not neural networks; instead, we can use the concept of a State Space Model to build neural networks, particularly those trained to operate on sequential data like text, audio, video, etc.

This page does not include a detailed overview of the Transformer architecture; however, because much of the research into SSMs has been motivated by the need to overcome some of the limitations of Transformers, some concepts and terminology are introduced here.

Transformers

Transformers are sequence-to-sequence models: they accept a sequence of inputs and produce another sequence of the same length.

In the text generation use case, we treat the last element of the output sequence as the continuation of the input sequence, and are thus able to perform autoregressive, one-token-at-a-time text generation by repeatedly executing the model.

A typical transformer model consists of several identically structured transformer blocks stacked sequentially, such that the output of one block becomes the input of the next. Each transformer block consists of a self-attention operation, a token-wise multi-layer perceptron (MLP), residual connections, and normalization:

The input to each block is a sequence of input embeddings of shape sequence_length × hidden_size. Throughout this page, we will refer to these two dimensions respectively as the Time Dimension and the Channel Dimension, borrowing the terminology used elsewhere in the literature on transformers [2] [3].

When we take a step back from the nuance of the transformer block’s anatomy, and think from the perspective of the time and channel dimensions, we can broadly categorize any operation into one of the following categories:

  1. Operations that mix information along the channel dimension.

  2. Operations that mix information along the time dimension.

  3. Operations that do both 1 and 2.

The MLPs in transformers are an example of category 1, as each individual token embedding is processed independently.

Self-Attention is an example of category 3, because information is mixed along both dimensions. For example, the query projection in self-attention mixes along the channel dimension by projecting each embedding into a query vector, while the actual computation of attention scores mixes along the time dimension.

Thus, very abstractly, we can consider a transformer block as consisting of a time-mixing operation followed by a channel-mixing operation:

This abstract view can be helpful when reading about alternative language model architectures (i.e. those that aim to improve on the transformer), as it forms a basic template for any language model architecture. The actual implementations of these time-mixing and channel-mixing operations are what set each architecture apart.
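
To make this concrete, here is a rough sketch (a toy example of my own, not taken from any particular architecture) of a purely channel-mixing operation next to a simple time-mixing one:

import torch

sequence_length, hidden_size = 6, 8
x = torch.randn(sequence_length, hidden_size)

# Channel mixing (category 1): the same linear layer is applied to every position
# independently, so no information flows along the time dimension
channel_mix = torch.nn.Linear(hidden_size, hidden_size)
y_channel = channel_mix(x)

# Time mixing: each output position is a weighted combination of all positions,
# here using a bare-bones attention-like weighting
weights = torch.softmax(x @ x.T, dim=-1)   # (sequence_length, sequence_length)
y_time = weights @ x

print(y_channel.shape, y_time.shape)       # both torch.Size([6, 8])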

State Space Models (SSMs)

Before we discuss the use of State Space Models in machine learning and language modeling, it is helpful to understand State Space Models more generally, as the notation and terminology will come up frequently in literature on the topic.

Very generally, a State Space Model is a mathematical representation used to describe a system that accepts inputs, produces outputs, and maintains some internal state. The output of the system is dependent on both the input and the current state.

Importantly, the state of the system changes, or evolves, as inputs are received. The rate at which the state of the system changes is typically described by the equation:

$$\begin{aligned} x'(t) &= \mathring{A}x(t) + \mathring{B}u(t) \end{aligned}$$

where \( t \) represents time, and the two functions of \( t \) represent the following:

$$\begin{aligned} u(t) &: \text{the input to the system at time } t \\ x(t) &: \text{the internal state of the system at time } t \end{aligned}$$

\( x^\prime(t) \) in this equation is the derivative of \( x(t) \), the rate of change of the system’s state.

Put another way, the rate at which the state of the system changes at any given time \( t \) is dependent on both the current state, \( x(t) \), and the current input, \( u(t) \).

The output of the system is likewise defined as a function of \(t\):

$$\begin{aligned} y(t) &= \mathring{C}x(t) + \mathring{D}u(t) \end{aligned}$$

Thus, the output of the system at any time \( t \) depends on both the current state, \( x(t) \), and the current input, \( u(t) \).

The four uppercase variables, \(\mathring{A}\), \(\mathring{B}\), \(\mathring{C}\), and \(\mathring{D}\) are often called the State Space Model Parameters. The Wikipedia page on State Space Models uses the following names, which provide some hint at the function of each parameter in the model:

$$\begin{aligned} \mathring{A} &: \text{the state transition matrix} \\ \mathring{B} &: \text{the input matrix} \\ \mathring{C} &: \text{the output matrix} \\ \mathring{D} &: \text{the feedthrough matrix} \end{aligned}$$

More often, however, these will simply be referred to as \(A\), \(B\), \(C\), and \(D\) in the literature (my use of the ring accent over each parameter is explained in the next section).

Continuous and Discrete Time

In most physical systems that can be described using a state space representation (i.e. using the two equations above), time is continuous, meaning that \(t\) can take on any real value.

For example, if we were to plot the output function of a continuous time State Space Model, we would see a single continuous line like this:

When applying state space models to the realm of natural language, however, we operate on discrete units of time. So the output of the model plotted against time will look more like this:

In the discrete time domain, we can describe the state of a state space model using the following equation. Here, instead of computing the derivative, or instantaneous rate of change of the state, we compute the state at each time step \( t\) based on the state at the previous time step \(t-1\):

$$\begin{aligned} x_{t} &= Ax_{t-1} + Bu_{t} \end{aligned}$$

At each discrete time step \(t\), the hidden state \(x_{t}\) contains information about all inputs up to and including the input received at time \(t\). An additional parameter \( \Delta \) is used to control the interval between each time step.

Notice that when describing a discrete time SSM, we use \( A \) and \( B \) in place of \( \mathring{A} \) and \( \mathring{B} \).

A quick note on this notation: elsewhere in the literature on SSMs you will likely see the continuous time parameters referred to as \(A\) and \(B\), while \(\bar{A}\) and \(\bar{B}\) are often used to indicate the discrete time parameters.

On this page, however, I will follow the convention from the Mamba 2 paper[4], where \(A\) and \(B\) refer to discrete time parameters, as these are what we are most interested in when building deep learning models based on SSMs.

Discretization

The conversion from the continuous time representation to the discrete time representation (i.e. from (\( \mathring{A} \), \( \mathring{B} \)) to (\(A\), \(B\))) is termed Discretization. Discretization is a well-studied problem in the world of control systems, and a number of different methods exist to compute discrete time approximations of continuous time systems.

Two popular discretization functions used for SSMs in machine learning are the Bilinear Transform and Zero-order Hold.

The formulas for both of these methods are shown below, but try not to get bogged down in the details here; the important thing to remember is that these are simply ways to approximate the original parameters \(\mathring{A}\) and \(\mathring{B}\) in the discrete time domain:

Bilinear Transform

$$\begin{aligned} A &= (I - \frac{\Delta}{2} \mathring{A})^{-1}(I + \frac{\Delta}{2} \mathring{A}) \\ \\ B &= (I - \frac{\Delta}{2} \mathring{A})^{-1}(\Delta \cdot \mathring{B}) \end{aligned}$$

Zero-order Hold

$$\begin{aligned} A &= e^{\Delta \cdot \mathring{A} } \\ \\ B &= (\Delta \cdot \mathring{A})^{-1}(A - I) \cdot \Delta \mathring{B} \end{aligned}$$
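
To make the formulas a bit more concrete, here is a rough sketch of both methods in code for a toy system with a scalar time step \(\Delta\) (the matrices, step size, and function names below are my own, chosen only for illustration):

import torch

def bilinear_discretize(A, B, dt):
    # Bilinear transform: A_d = (I - dt/2 A)^{-1}(I + dt/2 A), B_d = (I - dt/2 A)^{-1}(dt B)
    I = torch.eye(A.shape[0])
    left = torch.linalg.inv(I - (dt / 2) * A)
    return left @ (I + (dt / 2) * A), left @ (dt * B)

def zoh_discretize(A, B, dt):
    # Zero-order hold: A_d = exp(dt A), B_d = (dt A)^{-1}(A_d - I)(dt B) = A^{-1}(A_d - I) B
    A_d = torch.matrix_exp(dt * A)
    return A_d, torch.linalg.inv(A) @ (A_d - torch.eye(A.shape[0])) @ B

# Arbitrary continuous-time parameters for a 2-dimensional state
A = torch.tensor([[-1.0, 0.5], [0.0, -2.0]])
B = torch.tensor([[1.0], [0.5]])

print(bilinear_discretize(A, B, dt=0.1))
print(zoh_discretize(A, B, dt=0.1))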

Anyone who is curious to understand more about the details and derivations of these methods should check out this great video series on discrete control: Discrete Control Playlist by Brian Douglas.

SSMs as Recurrent Neural Networks

Discrete-Time SSMs share a lot of similarities with Recurrent Neural Networks (RNNs), a popular model architecture for NLP that preceded the transformer.

At each time step, the following occurs:

  1. The model accepts an input along with the state computed in the previous step
  2. The state is updated according to the equation \( x_{t} = Ax_{t-1} + Bu_{t} \)
  3. The output for that time step is computed using the equation \( y_{t} = Cx_{t} + Du_{t} \)
  4. The updated state is passed on to the next time step.

When the input to our SSM is a fixed-length sequence, as is often the case with many modalities like text and image, then we can think of the SSM as a simple for loop over the elements in the input sequence:

x = init_state()

for u in input:
    x = (A @ x) + (B @ u)
    y = (C @ x) + (D @ u)
    yield y

SSMs as Convolutional Neural Networks

Perhaps less intuitively, SSMs can also be seen as Convolutional Neural Networks (CNNs). To illustrate this, let’s consider again the equation we use to compute the state \(x_{t}\) at each timestep:

$$\begin{aligned} x_{t} &= Ax_{t-1} + Bu_{t} \end{aligned}$$

If we assume that the initial state is set to 0 before any input is received (in other words, \(x_{-1} = 0\)), then we can compute what the state will be at the first few timesteps like so:

$$\begin{aligned} x_{0} = Ax_{-1} + Bu_{0} &= Bu_{0} \\ x_{1} = Ax_{0} + Bu_{1} &= ABu_{0} + Bu_{1} \\ x_{2} = Ax_{1} + Bu_{2} &= A^{2}Bu_{0} + ABu_{1} + Bu_{2} \\ x_{3} = Ax_{2} + Bu_{3} &= A^{3}Bu_{0} + A^{2}Bu_{1} + ABu_{2} + Bu_{3} \end{aligned}$$

From here we see a pattern beginning to emerge. Namely, that at any given timestep \(t\), we can compute the state \(x_{t}\) like so:

$$x_{t} = A^{t}Bu_{0} + A^{t-1}Bu_{1} + A^{t-2}Bu_{2} + ... + Bu_{t}$$

This, as we will see, can be restated as a convolution. Let’s define \(K\) to be the sequence of coefficients in the equation above (i.e. everything but the \(u_{k}\) terms). \(K\) will thus look like:

$$K_{t} = (A^{t}B, A^{t-1}B, A^{t-2}B,...,B)$$

If we now treat \(K\) as the kernel of a convolution, the hidden state at every timestep can be computed as:

$$x = K * u$$

For those more familiar with the use of convolutions in the context of image processing, this may seem a bit confusing. When working with images, we usually visualize a convolution operation as “sliding” a 2-dimensional filter over an input image and, at each step, computing an element in the output as the weighted sum of the input values and the filter weights:

Convolutions in 1-dimension operate the same way: we just have to imagine the inputs, filter, and, consequently, the outputs flattened into 1-dimensional lists.

From here, we can imagine subbing in \(K\) as the filter and \(u\) as the input to see that this convolution gets us back to the formula for \(x_{t}\) defined earlier:

Recall that the input to a convolution can be padded with 0s such that some elements in the output are computed when the filter only partially overlaps with the input. In the case of our convolution \( x = K * u \), this is how all the values of \(x\) before \(x_{t}\) are computed:

It is worth noting that, for this convolutional approach to work, the SSM parameters \(A\) and \(B\) must not depend on the current timestep \(t\). Or, to use the jargon of SSMs, we must be using a Time-Invariant State Space Model.

One Note on the Convolution Kernel

The definition of the convolution \(K\) shown above allows us to compute the state at every time step all at once using \(x = K * u\); however, we are ultimately interested in computing the SSM output at each time step rather than just the state. The output of the SSM is computed as:

$$\begin{aligned} y &= Cx + Du \\ &= C(K * u) + Du \\ &= (CK * u) + Du \end{aligned}$$

To simplify things going forward, we include \(C\) as part of the convolution kernel \(K\) like so:

$$K_{t} = (CA^{t}B, CA^{t-1}B, CA^{t-2}B,...,CB)$$
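
As a quick sanity check (a toy, single-SSM example of my own, using already-discretized parameters and omitting \(D\)), we can verify that convolving the input with this kernel reproduces the recurrent computation:

import torch

torch.manual_seed(0)
N, L = 4, 6                                  # state size, sequence length
A = torch.randn(N, N) * 0.3                  # toy (already-discretized) parameters
B = torch.randn(N, 1)
C = torch.randn(1, N)
u = torch.randn(L)

# Recurrent computation of y_t = C x_t
x = torch.zeros(N, 1)
y_recurrent = []
for t in range(L):
    x = A @ x + B * u[t]
    y_recurrent.append((C @ x).item())
y_recurrent = torch.tensor(y_recurrent)

# Convolutional computation with kernel K = (CA^{L-1}B, ..., CAB, CB)
K = torch.stack([(C @ torch.matrix_power(A, p) @ B).squeeze() for p in range(L)]).flip(0)
u_padded = torch.nn.functional.pad(u.view(1, 1, L), (L - 1, 0))
y_convolutional = torch.conv1d(u_padded, K.view(1, 1, L)).squeeze()

print(torch.allclose(y_recurrent, y_convolutional, atol=1e-5))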

HiPPO - Selecting the \(A\) Matrix

The structure of the matrix \(A\) is of particular interest when building SSMs for sequence modeling, as this matrix is ultimately what determines which parts of the previous state are passed on to the current state.

The state of the model is a fixed-length vector as shown above, whose length we will refer to as the State Size, or \(N\). If the \(A\) matrix is ineffective, important information can be lost as the state transitions between time steps.

As an extreme example, imagine that the \( A \) matrix is filled entirely with zeros. In this case, no information from the previous state vector \(x_{t-1}\) would be preserved in the next state vector \(x_{t} \).

A special class of matrices that are effective at preserving state between discrete time steps was derived by Gu et al. [5] using a framework they termed High-Order Polynomial Projection Operators (HiPPO).

Many language models based on SSMs initialize the \(A\) matrix to one of the special matrices derived using the HiPPO framework. For example, the following HiPPO Matrix (termed “HiPPO-LegS”) is used by the LSSL and S4 models:

$$A_{nk} = - \begin{cases} (2n + 1)^{1/2}(2k + 1)^{1/2} & \text{if } n > k \\ n + 1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}$$

Converting the above into simple PyTorch code, we get the following function to create the HiPPO Matrix:

import torch

def hippo_matrix(state_size):
    
    # Generate a square matrix with the values 1..state_size on the diagonal
    A = torch.diag(torch.arange(1, state_size+1, dtype=torch.float))

    # Function to generate the elements below the diagonal based on 
    # the row and column index
    def lower_triangle_elements(n, k):
        return (2*n + 1)**0.5 * (2*k + 1)**0.5

    for i in range(1, state_size):
        for j in range(i):
            A[i, j] = lower_triangle_elements(i, j)
    
    return -1*A

hippo_matrix(state_size=5)
tensor([[-1.0000, -0.0000, -0.0000, -0.0000, -0.0000],
        [-1.7321, -2.0000, -0.0000, -0.0000, -0.0000],
        [-2.2361, -3.8730, -3.0000, -0.0000, -0.0000],
        [-2.6458, -4.5826, -5.9161, -4.0000, -0.0000],
        [-3.0000, -5.1962, -6.7082, -7.9373, -5.0000]])

A Brief History of Language Modeling SSMs

With the background material on State Space Models out of the way, we can now start to look at some of the major milestones in the application of SSMs to deep learning.

Linear State Space Layers (LSSL)

The work on the Linear State Space Layer[9] introduced much of the terminology and foundations that would be built upon in later work on SSMs as sequence-to-sequence models.

While it is possible to trace the origin of several of the concepts applied by the LSSL back to prior sources, I chose to begin here as there seems to be a direct line of descent from the LSSL to later more popular SSM-based models like H3 and Mamba.

The LSSL is perhaps the most “vanilla” version of the SSM-based sequence model, in that it does not vary much from the description of a State Space Model introduced in the previous section. i.e.

$$\begin{aligned} x_{t} &= Ax_{t-1} + Bu_{t} \\ y_{t} &= Cx_{t} + Du_{t} \\ \end{aligned}$$

It relies on both the convolutional and recurrent view of the SSM to achieve parallelizable training and memory-efficient inference, respectively.

The number of parameters in an LSSL is determined by the following dimensions:

  • \(N\) - The State Size, or the number of dimensions of the state vector.
  • \(H\) - The Hidden Size, or the number of dimensions of each input vector (i.e. the length along the Channel Dimension). When \(H\) is greater than 1, we essentially have \(H\) independent State Space Models, with each one computing a single element in the output vector for a given time step.
  • \(M\) - The number of Output Channels.

The weights of the LSSL consist of the familiar four variables from the definition of a State Space Model. Matching the implementation below, they take on the following shapes: \(A\) is \(N \times N\), \(B\) is \(N\), \(C\) is \(H \times M \times N\), and \(D\) is \(H \times M\).

In addition to these four parameters, LSSL also makes the timestep \(\Delta\) a trainable parameter of size \(H\); and finally, because the \(C\) and \(D\) parameters include the additional Output Channels, or \(M\), dimension, a feedforward layer is added after the State Space Model to project the output back to the appropriate shape.

Implementing an LSSL

Let’s see the LSSL in action with a simple PyTorch implementation. We will first implement the constructor which initializes the LSSL’s parameters:

class LSSL(torch.nn.Module):
    def __init__(self, hidden_size, num_channels, state_size, initial_dt=0.01):
        super(LSSL, self).__init__()

        self.hidden_size = hidden_size
        self.num_channels = num_channels
        self.state_size = state_size
        
        # Initialize A from the HiPPO-LegS matrix defined earlier. B is initialized
        # to the corresponding HiPPO-LegS input matrix: B_n = (2n + 1)^{1/2}
        A = hippo_matrix(self.state_size)
        B = torch.sqrt(2 * torch.arange(self.state_size, dtype=torch.float) + 1)

        # Continuous-time parameters
        self.A = torch.nn.Parameter(A)
        self.B = torch.nn.Parameter(B)
        self.dt = torch.nn.Parameter(torch.ones(self.hidden_size) * initial_dt)

        self.C = torch.nn.Parameter(torch.randn(self.hidden_size, self.num_channels, self.state_size))
        self.D = torch.nn.Parameter(torch.randn(self.hidden_size, self.num_channels))
        self.output_proj = torch.nn.Parameter(torch.randn(self.hidden_size * self.num_channels, self.hidden_size))

The forward pass will contain two separate implementations: the recurrent implementation and the convolutional implementation. We will start with the recurrent implementation, as this is a little more intuitive.

Recurrent View

First, we convert the \(A\) and \(B\) parameters from their continuous to their discrete representation using the generalized bilinear transform (GBT), which reduces to the bilinear transform shown earlier when its parameter \(\alpha\) is set to \(1/2\):

def forward(self, input, mode='recurrent'):

    # Discretize the continuous-time model using the generalized bilinear transform
    A_discrete, B_discrete = gbt(self.A, self.B, self.dt)
    ...
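
The gbt helper used above is not shown on this page (the full version is presumably defined with the rest of the code in the linked repo); a minimal sketch of what it might look like, assuming a shared \(N \times N\) \(A\), a \(B\) of size \(N\), and a per-feature \(\Delta\) of size \(H\), is:

def gbt(A, B, dt, alpha=0.5):
    # Generalized bilinear transform with a per-feature time step.
    # A: (N, N), B: (N,), dt: (H,)  ->  A_discrete: (H, N, N), B_discrete: (H, N)
    I = torch.eye(A.shape[0])
    dt = dt.view(-1, 1, 1)
    left = torch.linalg.inv(I - dt * alpha * A)
    A_discrete = left @ (I + dt * (1 - alpha) * A)
    B_discrete = (left @ (dt * B.view(1, -1, 1))).squeeze(-1)
    return A_discrete, B_discrete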

Next we define some variables based on the shape of the input, and then initialize the state vectors to 0:

    ...
    batch_size, sequence_length = input.shape[0], input.shape[1]

    # Initialize hidden state to zeros
    x = torch.zeros(batch_size, self.hidden_size, self.state_size)
    ...

And now for the actual recurrent loop:

    ...
    if mode == 'recurrent':
        # List to store the output at each time step
        outputs = []

        for i in range(sequence_length):

            current_input = input[:, i, :]
            
            # Update the current state
            Ax = (A_discrete @ x.unsqueeze(-1)).reshape(*x.shape)
            Bu = torch.einsum('hn, bh -> bhn', B_discrete, current_input)
            x = Ax + Bu

            # Compute the output at the current time step
            Cx = torch.einsum('hmn, bhn -> bhm', self.C, x)
            Du = torch.einsum('hm, bh -> bhm', self.D, current_input)
            y = Cx + Du
            
            outputs.append(y)
        
        y = torch.stack(outputs, dim=1)

        # Output projection
        output = y.view(batch_size, sequence_length, -1) @ self.output_proj
        return output

Convolutional View

To perform the forward pass as a convolution, we first implement a function to compute the kernel \(K\):

def compute_convolution_kernel(self, sequence_length, A_discrete, B_discrete, C):
    K = [B_discrete.unsqueeze(-1)]
    for _ in range(1, sequence_length):
        K.append(A_discrete @ K[-1])

    K.reverse()
    K = torch.stack(K, dim=-1).squeeze(-2)
    return (C @ K).view(self.hidden_size * self.num_channels, 1, -1)

Next we pass \(K\) and the entire input sequence together into a 1-dimensional convolution operation:

...
if mode == 'convolutional':
    K = self.compute_convolution_kernel(sequence_length, A_discrete, B_discrete, self.C)

    conv_input = torch.nn.functional.pad(input.transpose(1, -1), (sequence_length-1, 0))

    y = torch.conv1d(
        input=conv_input,   # [B, H, L]
        weight=K,           # [ H * M, 1, L]
        groups=self.hidden_size
    )
    y = y.transpose(1, 2).reshape(batch_size, sequence_length, self.hidden_size, self.num_channels)
    
    Du = torch.einsum('hm, blh -> blhm', self.D, input)
    y += Du

    # Output projection
    output = y.view(batch_size, sequence_length, -1) @ self.output_proj
    return output
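
Putting the pieces together (and assuming the gbt sketch and HiPPO initialization above), we can sanity-check that the two modes agree up to floating point error:

torch.manual_seed(0)

lssl = LSSL(hidden_size=8, num_channels=2, state_size=16)
u = torch.randn(4, 32, 8)    # (batch, sequence_length, hidden_size)

y_recurrent = lssl(u, mode='recurrent')
y_convolutional = lssl(u, mode='convolutional')

print(y_recurrent.shape)                               # torch.Size([4, 32, 8])
print((y_recurrent - y_convolutional).abs().max())     # should be on the order of float precision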

Stacking LSSLs

The LSSL defined above is analogous to a transformer block in a transformer model. That is, we can construct a larger model by stacking multiple LSSLs one after another.

The original implementation of deep LSSL models included residual connections and layer normalization. Something like this:

class LSSLModel(torch.nn.Module):

    def __init__(self, config: LSSLConfig):
        super(LSSLModel, self).__init__()
        self.blocks = torch.nn.ModuleList([
            LSSL(config.hidden_size, config.num_channels, config.state_size)
            for _ in range(config.num_layers)
        ])
        self.norm = torch.nn.LayerNorm(config.hidden_size)
        self.dropout = torch.nn.Dropout(config.dropout)

    def forward(self, u, mode='recurrent'):
        for block in self.blocks:
            residual = u
            u = self.norm(u)
            u = block(u, mode)
            u = self.dropout(u) + residual
        return u
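
The LSSLConfig used above is just a container for the model dimensions; it is not spelled out on this page, but a minimal version (my own assumption) might look like:

from dataclasses import dataclass

@dataclass
class LSSLConfig:
    hidden_size: int
    num_channels: int
    state_size: int
    num_layers: int
    dropout: float = 0.0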

The Drawbacks of the LSSL

Two variants of the LSSL were evaluated: one in which all of the parameters discussed above were trainable, and another, called LSSL-f, in which \(A\) and the time step \(\Delta\) were frozen.

It was observed that training both \(A\) and \(\Delta\) resulted in better quality, but greatly increased the computational complexity of the model, as the convolution kernel \(K\) had to be recomputed for each input.

The LSSL-f variant, on the other hand, was much more efficient, as much of the kernel computation could be performed once and cached; however, this variant produced lower-quality results.

Structured State Spaces (S4 | DSS | S4D)

Later work aimed to improve on the LSSL by simplifying the computation of the convolutional kernel \(K\). This led to the introduction of Structured State Spaces for Sequence Modeling, or S4[6].

If we take another look at the code to compute the kernel \(K\), we can see that the main computational bottleneck in this function is the for loop over the sequence length where we repeatedly multiply \(A\) by itself sequence_length times:

def compute_convolution_kernel(self, sequence_length, A_discrete, B_discrete, C):
    K = [B_discrete.unsqueeze(-1)]
    for _ in range(1, sequence_length):
        K.append(A_discrete @ K[-1])

    K.reverse()
    K = torch.stack(K, dim=-1).squeeze(-2)
    return (C @ K).view(self.hidden_size * self.num_channels, 1, -1)

Computing \(K\) with Diagonal Matrices

One way we could simplify this would be to force \(A\) to be diagonal (i.e. a matrix with 0s everywhere except along the main diagonal). In this case, raising \(A\) to the \(\ell^{th}\) power is equivalent to simply raising each element along \(A\)’s diagonal to the \(\ell^{th}\) power.

For example,

$$\begin{align*} \Lambda &= \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 2 & 0 & 0 \\ 0 & 0 & 3 & 0 \\ 0 & 0 & 0 & 4 \\ \end{bmatrix} &&&& \Lambda^{2} &= \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 2 & 0 & 0 \\ 0 & 0 & 3 & 0 \\ 0 & 0 & 0 & 4 \\ \end{bmatrix} \times \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 2 & 0 & 0 \\ 0 & 0 & 3 & 0 \\ 0 & 0 & 0 & 4 \\ \end{bmatrix} &= \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 4 & 0 & 0 \\ 0 & 0 & 9 & 0 \\ 0 & 0 & 0 & 16 \\ \end{bmatrix} \end{align*} $$

Let’s imagine more generally what would happen if we plugged a diagonal \(A\) matrix into the definition of \(K\):

$$K = (CA^{t}B, CA^{t-1}B, CA^{t-2}B,...,CB)$$

We will refer to the elements of \(A\), \(B\), and \(C\) as follows:

$$\begin{align*} A &= \begin{bmatrix} a_1 & 0 & 0 & \cdots & 0 \\ 0 & a_2 & 0 & \cdots & 0 \\ 0 & 0 & a_3 & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \cdots & a_N \\ \end{bmatrix} & B &= \begin{bmatrix} b_1 \\ b_2 \\ b_3 \\ \vdots \\ b_N \\ \end{bmatrix} & C &= \begin{bmatrix} c_1 & c_2 & c_3 & \cdots & c_N \\ \end{bmatrix} \end{align*} $$

Given this, if we evaluate \(CA^{\ell}B\) for some power \(\ell\), then we see this comes out to:

$$\begin{align*} CA^{\ell}B &= \begin{bmatrix} c_1 & c_2 & c_3 & \cdots & c_N \end{bmatrix} \times \begin{bmatrix} a_1^{\ell} & 0 & 0 & \cdots & 0 \\ 0 & a_2^{\ell} & 0 & \cdots & 0 \\ 0 & 0 & a_3^{\ell} & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \cdots & a_N^{\ell} \\ \end{bmatrix} \times \begin{bmatrix} b_1 \\ b_2 \\ b_3 \\ \vdots \\ b_N \\ \end{bmatrix} \\ CA^{\ell}B &= \begin{bmatrix} c_1a_1^{\ell} & c_2a_2^{\ell} & c_3a_3^{\ell} & \cdots & c_Na_N^{\ell} \end{bmatrix} \times \begin{bmatrix} b_1 \\ b_2 \\ b_3 \\ \vdots \\ b_N \\ \end{bmatrix} \\ CA^{\ell}B &= \sum_{i=1}^{N} c_{i}a_{i}^{\ell}b_{i} \end{align*} $$

Thus, we can write the computation of the entire kernel \(K\) as the following vector-matrix multiplication:

$$\begin{align*} K &= \begin{bmatrix} c_1b_1 & c_2b_2 & c_3b_3 & \cdots & c_Nb_N \end{bmatrix} \times \begin{bmatrix} 1 & a_1 & a_1^{2} & \cdots & a_1^{\ell} \\ 1 & a_2 & a_2^{2} & \cdots & a_2^{\ell} \\ 1 & a_3 & a_3^{2} & \cdots & a_3^{\ell} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & a_N & a_N^{2} & \cdots & a_N^{\ell} \\ \end{bmatrix} \end{align*} $$

The matrix on the right is known as a Vandermonde matrix: each of its columns contains the diagonal elements of \(A\) raised to the same power. i.e.

$$\begin{align*} A &= \begin{bmatrix} a_1 & 0 & 0 & \cdots & 0 \\ 0 & a_2 & 0 & \cdots & 0 \\ 0 & 0 & a_3 & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \cdots & a_N \\ \end{bmatrix} & & & & \text{Vandermonde}(A) = \begin{bmatrix} 1 & a_1 & a_1^{2} & \cdots & a_1^{\ell} \\ 1 & a_2 & a_2^{2} & \cdots & a_2^{\ell} \\ 1 & a_3 & a_3^{2} & \cdots & a_3^{\ell} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & a_N & a_N^{2} & \cdots & a_N^{\ell} \\ \end{bmatrix} \end{align*}$$

If you’re more partial to Python code, below is the original function used to compute \(K\) (when \(A\) was not a diagonal matrix), compared to the simplified version of the computation made possible when \(A\) is forced to be diagonal:

hidden_size = 1 # For simplicity, assume that we are only running one copy of the SSM
sequence_length = 50
state_size = 16

# Restrict A to be diagonal
A = torch.diag_embed(torch.randn(1, state_size), dim1=-2, dim2=-1)
B = torch.randn(hidden_size, state_size)
C = torch.randn(hidden_size, state_size)

def compute_convolution_kernel(sequence_length, A, B, C):
    K = [B.unsqueeze(-1)]
    for _ in range(1, sequence_length):
        K.append(A @ K[-1])

    K.reverse()
    K = torch.stack(K, dim=-1).squeeze(-2)
    return (C @ K).view(hidden_size, 1, -1)


def compute_convolution_kernel_diag(sequence_length, A, B, C):
    # Construct the Vandermonde matrix from the diagonal of A using torch.vander
    # (decreasing powers, to match the reversed kernel above)
    VA = torch.vander(A.squeeze(0).diagonal(), N=sequence_length)
    return (B * C) @ VA

If we execute these functions on the same inputs, we see that both return the same result:

K_original = compute_convolution_kernel(sequence_length, A, B, C)
K_vandermonde = compute_convolution_kernel_diag(sequence_length, A, B, C)
print(f"Results Match: {torch.allclose(K_original, K_vandermonde)}")
Results Match: True

The new implementation, which takes advantage of \(A\) being diagonal, is much more computationally efficient, especially as the length of the input sequence grows:

from timeit import timeit

def measure_kernel_compute_time(sequence_length, A, B, C):
    a = timeit(lambda: compute_convolution_kernel(sequence_length, A, B, C), number=1000)
    b = timeit(lambda: compute_convolution_kernel_diag(sequence_length, A, B, C), number=1000)
    print(f"Sequence Length {sequence_length} – Original: {a:.3f}s; Vandermonde: {b:.3f}s")

measure_kernel_compute_time(sequence_length=50, A=A, B=B, C=C)
print("-"*80)
measure_kernel_compute_time(sequence_length=500, A=A, B=B, C=C)
print("-"*80)
measure_kernel_compute_time(sequence_length=2000, A=A, B=B, C=C)
Sequence Length 50 – Original: 0.254s; Vandermonde: 0.013s
--------------------------------------------------------------------------------
Sequence Length 500 – Original: 2.381s; Vandermonde: 0.024s
--------------------------------------------------------------------------------
Sequence Length 2000 – Original: 9.508s; Vandermonde: 0.063s

Finding a Diagonal \(A\) Matrix

In the previous section we saw that by forcing \(A\) to be a diagonal matrix, we can compute the convolution kernel \(K\) much more efficiently. But this begs the question: How can we replace \(A\) with a diagonal matrix without degrading the quality of the model?

The search for an \(A\) matrix that is diagonal yet still effective at preserving information across many time steps is essentially what S4[6], its descendant S4D[10], and its cousin(?) DSS[11] all attempt to do.

There is a clever trick we can use to convert one SSM into an equivalent SSM that forces \(A\) to be diagonal. But first, we need to introduce a few concepts:

  1. A matrix is said to be Diagonalizable if it is possible to reconstruct it from a diagonal matrix and an invertible matrix. That is, we can say that \(A\) is diagonalizable if for some diagonal matrix \(\Lambda\), and some invertible matrix \(P\), the following are true:
$$\begin{aligned} A &= P \Lambda P^{-1} \\ \Lambda &= P^{-1}AP \end{aligned}$$
  2. A State Space Transformation is an operation by which we can take one State Space Model and map it to a different but equivalent State Space Model. We perform a State Space Transformation by defining a new state vector \(\tilde{x}\) based on the state vector \(x\) of our existing SSM like so:
$$\tilde{x} = P^{-1}x$$

where \(P\) is any invertible matrix that we choose. Because \(P\) is invertible, we can flip this equation around and rewrite it in terms of \(x\):

$$x = P\tilde{x}$$

Now, to construct our new (but equivalent) state space model, we simply plug this back into the original definition of an SSM so that all of the \(x\) terms are replaced with \(\tilde{x}\):

$$\begin{align*} x_{t} &= Ax_{t-1} + Bu_{t} & y_{t} &= Cx_{t} + Du_{t} && (\Leftarrow \text{Original definition}) \\ P^{-1}x_{t} &= P^{-1}Ax_{t-1} + P^{-1}Bu_{t} & y_{t} &= Cx_{t} + Du_{t} && (\Leftarrow \text{Multiply state equation by } P^{-1}) \\ \tilde{x}_{t} &= P^{-1}AP\tilde{x}_{t-1} + P^{-1}Bu_{t} & y_{t} &= CP\tilde{x}_{t} + Du_{t} && (\Leftarrow \text{Substitute } P\tilde{x} \text{ for } x) \\ \end{align*}$$

This leaves us with a new equivalent SSM definition whose parameters have been redefined like so:

$$\text{SSM}( A, B, C, D ) \sim \text{SSM}( P^{-1}AP, P^{-1}B, CP, D )$$

The video “Transforming State-space Coordinates” offers a good introduction to this kind of State Space Transformation.

Now, putting these two ideas together, if we assume that the \(A\) matrix in an SSM is Diagonalizable (which, as it turns out, is the case in most SSMs), then we can perform a State Space Transformation as defined above to construct a new SSM whose \(A\) matrix is diagonal.

Specifically, if we perform the transformation using the same invertible matrix \(P\) that satisfies the diagonalization equation \(\Lambda = P^{-1}AP\), then we end up with:

$$\text{SSM}( A, B, C, D ) \sim \text{SSM}( \Lambda, P^{-1}B, CP, D )$$

This is all a very long-winded way of stating that almost all SSMs are equivalent to an SSM with a diagonal \(A\) matrix.
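
To make this concrete, here is a small numerical check (a toy example of my own) that diagonalizing \(A\) and transforming \(B\) and \(C\) accordingly leaves the kernel terms \(CA^{\ell}B\) unchanged:

import torch

torch.manual_seed(0)
N = 4
A = torch.randn(N, N)
B = torch.randn(N, 1)
C = torch.randn(1, N)

# Diagonalize A via its eigendecomposition: A = P Lambda P^{-1} (generally complex)
Lambda, P = torch.linalg.eig(A)
P_inv = torch.linalg.inv(P)

# The transformed (but equivalent) SSM: (Lambda, P^{-1}B, CP)
B_tilde = P_inv @ B.to(P.dtype)
C_tilde = C.to(P.dtype) @ P

for l in range(4):
    original = (C @ torch.matrix_power(A, l) @ B).item()
    diagonal = (C_tilde @ torch.diag(Lambda ** l) @ B_tilde).real.item()
    print(f"l={l}: {original:+.4f} vs {diagonal:+.4f}")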

Finding the Right Diagonal \(A\) Matrix

Now that we know how to convert one SSM into an equivalent SSM with a diagonal matrix for \(A\), the tempting thing to do is to diagonalize our trusty HiPPO matrix and enjoy the newfound computational efficiency of working with diagonal matrices.

Sadly, the diagonalized version of the HiPPO matrix does not play nicely due to the numerical instability of the \(P\) matrix involved in the computation of \( P^{-1}AP \).

We can see this first hand by diagonalizing the HiPPO matrix using SymPy’s diagonalize() function:

from sympy import Matrix

# Get the HiPPO A Matrix using the helper function defined before
A = hippo_matrix(state_size=16)

# Convert A to a sympy matrix
A = Matrix(A.tolist())

# Diagonalize A to obtain P and the diagonal matrix Lambda
P, Lambda = A.diagonalize()

# Compute the inverse of P
Pinv = P.inv()

# Get the maximum value found in P and Pinv
P_real, Pinv_real = P.as_real_imag()[0], Pinv.as_real_imag()[0]
max_value = max([
    max(P_real), max(Pinv_real)
])
print(f"Max Value: {max_value:.3f}")
Max Value: 4390273247.525

As we can see, even with a very small state size of 16, the max value in \(P\) and \(P^{-1}\) is already quite large, and will only continue to grow exponentially as the state size increases. For this reason, it is not practical to use the HiPPO matrix as-is; a more stable diagonalizable matrix is required.

Working with Complex Numbers

You may have noticed in the previous code example that before checking the max values, we first called the as_real_imag() function to separate the real and imaginary parts of any complex numbers in P and Pinv.

This is necessary because the \(P\) matrix and its inverse (as well as the diagonal \(\Lambda\) matrix) will often contain complex numbers of the form \(\alpha + \beta i\), where \( i \) is the imaginary unit and \( \alpha \) and \( \beta \) are real floating point values referred to respectively as the Real Part and the Imaginary Part of the complex number.

This is why, when you look at reference implementations of the S4 or S4D model architecture, you will see the SSM parameters stored in torch.cfloat (complex float) type. It’s worth keeping in mind that each value of cfloat type consists of two floating point numbers (the real part and the imaginary part); thus each value is 64 bits in total:

import torch
t = torch.randn(3, dtype=torch.cfloat)
print(f"Element Size: {t.element_size()*8} bits")
Element Size: 64 bits

S4D-Lin Initialization

With the S4D model, a specific diagonal \(A\) matrix, termed S4D-Lin, is derived, where each element along the main diagonal has a real part of \(-\frac{1}{2}\) and imaginary parts that increase linearly as we progress along the diagonal:

$$A_n = -\frac{1}{2} + i \pi n$$

In PyTorch, the creation of \(A\) according to S4D-Lin would look like this:

import math, torch
state_size = 4
s4d_lin_A = torch.tensor([-0.5 + (math.pi * n * 1j) for n in range(state_size)])
print(torch.diag_embed(s4d_lin_A))
tensor([[-0.5000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j],
        [ 0.0000+0.0000j, -0.5000+3.1416j,  0.0000+0.0000j,  0.0000+0.0000j],
        [ 0.0000+0.0000j,  0.0000+0.0000j, -0.5000+6.2832j,  0.0000+0.0000j],
        [ 0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j, -0.5000+9.4248j]])

While \(A\) is technically still a two-dimensional matrix according to the definition of a State Space Model, in practice, we will usually represent the diagonal \(A\) as a one-dimensional vector of the elements along the diagonal, since we know that all other elements are 0.

In addition to initializing \(A\) to a stable diagonal matrix, it was observed by Goel et al. [12] that the model can still become unstable during training if the values of \(A\) are not constrained appropriately. Specifically, if the real parts of \(A\) become positive during training, then the model can become unstable.

Thus, an activation function is applied to the real part of \(A\) prior to computing the convolution kernel \(K\) to ensure that the values in \(A\) have a negative real part. The parameterization used in the S4D paper is A.real = -1 * torch.exp(log_A_real), where log_A_real is the trainable parameter; we correspondingly initialize it as log_A_real = torch.log(-A.real), which is well-defined because the initial real parts of \(A\) are negative.

In other words:

def __init__(self, state_size):
    ...
    # Initialize A using the S4D-Lin Method
    A = s4d_lin_initialization(state_size)
    self.log_A_real = torch.nn.Parameter(torch.log(-A.real))
    self.A_imag = torch.nn.Parameter(A.imag)
    ...

def forward(self, input):
    A = -torch.exp(self.log_A_real) + 1j * self.A_imag
    ...
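
The s4d_lin_initialization helper referenced above is not spelled out here; assuming it is just the S4D-Lin construction from the previous snippet, returning the complex diagonal of \(A\) as a vector, it might look like:

import math, torch

def s4d_lin_initialization(state_size):
    # S4D-Lin: A_n = -1/2 + i*pi*n, returned as the diagonal of A
    return torch.tensor([-0.5 + (math.pi * n * 1j) for n in range(state_size)])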

Implementing S4D

Aside from the differences in the initialization of \(A\) and the computation of \(K\), much of the S4D implementation is the same as the LSSL implementation. The S4D code is omitted here for brevity. Those who are interested can find the full implementation at: s4d.py

Hungry Hungry Hippos (H3)

The Hungry Hungry Hippos (H3) model marks a departure from the sort of deep neural network architecture we saw in LSSL and S4, where each layer in the model is an SSM.

Instead, in the H3 architecture, multiple SSMs are incorporated into a broader, more attention-like operation that aims to recover the expressivity of self-attention while maintaining the memory-friendly fixed state size of SSMs.

Each H3 Layer contains two distinct SSMs that are characterized by the structure of their respective \(A\) matrices: SSM-Shift and SSM-Diag.

SSM-Shift uses a fixed \(A\) matrix of the following form:

$$A_{nk} = \begin{cases} 1 & \text{if } n - 1 = k \\ 0 & \text{otherwise} \end{cases}$$

This matrix, when multiplied by the SSM state vector \(x_{t}\), has the effect of shifting each element of the state vector to the next position: a zero enters at the first position and the last element is discarded. We can see this in action by creating the SSM-Shift \(A\) matrix and executing it a few times on an arbitrary state vector:

# Construct the A matrix of SSM-Shift
def shift_matrix(state_size):
    A = torch.ones(state_size-1)
    return torch.diag(A, diagonal=-1)

shift_A = shift_matrix(5)
print("Shift matrix:")
print(shift_A)

state = torch.arange(1, 6, dtype=torch.float32)
print("\nInitial state vector:", state)

for _ in range(5):
    state = shift_A @ state
    print("state = shift_A @ state -", state)
Shift matrix:
tensor([[0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.]])

Initial state vector: tensor([1., 2., 3., 4., 5.])
state = shift_A @ state - tensor([0., 1., 2., 3., 4.])
state = shift_A @ state - tensor([0., 0., 1., 2., 3.])
state = shift_A @ state - tensor([0., 0., 0., 1., 2.])
state = shift_A @ state - tensor([0., 0., 0., 0., 1.])
state = shift_A @ state - tensor([0., 0., 0., 0., 0.])

SSM-Diag, the second type of SSM, is essentially the same as the S4D block discussed earlier: its \(A\) matrix is initialized to a specific stable diagonal matrix (e.g. using the S4D-Lin method).

The H3 Layer incorporates SSM-Shift and SSM-Diag into an attention-like operation that projects the input sequence into separate Query, Key, and Value signals:

A minimal implementation of the H3 Layer is shown below:

class H3(torch.nn.Module):

    def __init__(self, config: H3Config):
        super(H3, self).__init__()

        self.state_size = config.state_size
        self.hidden_size = config.hidden_size

        self.ssm_shift = SSM(config, init_A="shift")
        self.ssm_diag = SSM(config, init_A="s4d-lin")

        self.q_proj = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.out_proj = torch.nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, input, mode='recurrent'):
        q = self.q_proj(input)
        k = self.k_proj(input)
        v = self.v_proj(input)

        k = self.ssm_shift(k, mode)
        kv = k * v
        kv = self.ssm_diag(kv, mode)

        qkv = q * kv
        return self.out_proj(qkv)

The implementation of the SSM class referenced above is mostly the same as that of the S4D block from the previous section. The only real difference is in the alternative initialization logic used to support SSM-Shift:

class SSM(torch.nn.Module):
    """Generic SSM to support both Shift and Diag models."""

    def __init__(self, config: H3Config, init_A="s4d-lin"):
        ...
        self.init_A = init_A
        if self.init_A == "s4d-lin":
            # Initialize A using the S4D-Lin Method. The real part is written as
            # +1/2 here because it is stored as a log and negated (via -exp) in the
            # forward pass, as in the S4D parameterization above.
            A = torch.tensor([
                0.5 + (math.pi * n * 1j) for n in range(self.state_size // 2)
            ]).repeat(self.hidden_size, 1)
            self.A_imag = torch.nn.Parameter(A.imag)
            self.log_A_real = torch.nn.Parameter(torch.log(A.real))
            
            B = torch.randn(self.hidden_size, self.state_size // 2, dtype=torch.cfloat)
            B = torch.view_as_real(B)

            C = torch.randn(self.hidden_size, self.state_size // 2, dtype=torch.cfloat)
            C = torch.view_as_real(C)

        elif self.init_A == "shift":
            # Initialize A to a Shift matrix with 1s just below the main diagonal
            A = torch.diag(torch.ones(self.state_size-1), diagonal=-1)
            B = torch.randn(self.hidden_size, self.state_size)
            C = torch.randn(self.hidden_size, self.state_size)
        ...
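
The H3Config referenced above is, like LSSLConfig, assumed to be just a small container for the model dimensions used in the code:

from dataclasses import dataclass

@dataclass
class H3Config:
    hidden_size: int
    state_size: int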

H3 Layers like this are interleaved with MLPs to form deep H3 models like the ones available on HuggingFace: H3-125M.

These models are very similar to transformer models in terms of their architecture: the only real difference is that the Time Mixing operation in Transformers is Self-Attention, whereas the Time Mixing operation in H3 Models is the H3 Layer.

Mamba

The Mamba series of models received a lot of attention when they first launched, as they demonstrated that language models based on SSMs can match or outperform similar-size transformer models. The main distinction between Mamba and its predecessors is the shift from Time-Invariant SSMs to Time-Variant SSMs.

Recall the equations we use to describe a state space model:

$$\begin{aligned} x_{t} &= Ax_{t-1} + Bu_{t} \\ y_{t} &= Cx_{t} + Du_{t} \\ \end{aligned}$$

This is considered a Time-Invariant State Space Model, because the parameters \( (A, B, C, D) \) are the same at every time step. In contrast, Time-Variant State Space Models are described by:

$$\begin{aligned} x_{t} &= A(t)x_{t-1} + B(t)u_{t} \\ y_{t} &= C(t)x_{t} + D(t)u_{t} \\ \end{aligned}$$

Here, the SSM parameters are functions of time \(t\), and can thus vary from one time step to the next. The Mamba architecture implements a slight variation on this notion of the time-variant SSM.

Specifically, in Mamba, the \(B\), \(C\), and \(\Delta\) parameters are computed at each time step based on the current input \(u\). This means that, instead of directly learning the parameters \(B\), \(C\), and \(\Delta\), we instead learn the parameters of the linear projections \(s_B(u)\), \(s_C(u)\), \(s_{\Delta}(u)\), which, when provided an input vector \(u\), produce the SSM parameters \(B\), \(C\), and \(\Delta\), respectively.

A notable consequence of this change is that now the convolutional view of the SSM, which the LSSL, S4, and H3 models all depended upon to parallelize the training process, is no longer applicable. To make up for this, the authors introduce several efficiency optimizations for the recurrent view of computing an SSM’s output.

The Mamba Block

The Mamba block – aside from its use of time-dependent \(B\), \(C\), and \( \Delta \) parameters – is essentially a simplified version of the H3 block:

  • H3 relied on two separate SSMs: SSM-Shift and SSM-Diag. SSM-Shift had a very particular \(A\) matrix that was not learned during training; \(B\) and \(C\) were, therefore, the only learnable parameters involved in the computation of the convolution kernel \(K\). Mamba replaces the Shift SSM with a simple 1D convolution.

  • The Query, Key, and Value projections from H3 are replaced by two projections in Mamba (these projections can be thought of as Query and Key, although this language is not used generally).

  • In multi-layered Mamba models, the Mamba blocks are stacked one after another, without any MLP interleaved between them.

A visual representation of the Mamba Block is shown below:

Implementing Mamba

A working implementation of Mamba is available in the HuggingFace Transformers repository (modeling_mamba.py); however, I found this implementation to be a bit hard to follow at first. Here we will create a minimal Mamba implementation that will (hopefully) be easier to understand.

First we define a config class which will store the various model dimensions:

from dataclasses import dataclass

@dataclass
class MambaConfig:
    # 𝑁 – The SSM state vector size
    state_size: int
    
    # 𝐷 - The input embedding dimension
    hidden_size: int
    
    # 𝐸 – Model dimension expansion factor 
    expansion_factor: int

    # 𝑅 - The time step rank
    dt_rank: int

    # Kernel size used in the 1D convolution
    conv_kernel_size: int

There are a few things here that have not been discussed yet:

  • First, the Expansion Factor \(E\). Similar to the MLP in most transformer models, the Mamba block projects the input embeddings into a higher dimension \(E \cdot D\), where \(D\) is the input embedding dimension. The Mamba block output is projected back down to size \(D\) by the output projection layer.
  • The time step \( \Delta \) is first projected to a lower rank \(R\) (i.e. \( R \ll ED \)) before getting projected into the internal dimension \(ED\).
  • The conv_kernel_size parameter controls the kernel size of the 1D convolution which is used by Mamba in place of the Shift SSM that was used in H3.

Next we will implement the main Mamba class, which includes the input/output projections, activation functions, and the 1D convolution:

class Mamba(torch.nn.Module):

    def __init__(self, config: MambaConfig):
        super(Mamba, self).__init__()

        self.state_size = config.state_size
        self.hidden_size = config.hidden_size * config.expansion_factor
        self.conv_kernel_size = config.conv_kernel_size

        self.conv = torch.nn.Conv1d(
            in_channels=self.hidden_size,
            out_channels=self.hidden_size,
            kernel_size=self.conv_kernel_size,
            padding=self.conv_kernel_size - 1,
            groups=self.hidden_size
        )
        self.ssm = MambaSSM(config)

        self.q_proj = torch.nn.Linear(config.hidden_size, self.hidden_size, bias=False)
        self.k_proj = torch.nn.Linear(config.hidden_size, self.hidden_size, bias=False)
        self.out_proj = torch.nn.Linear(self.hidden_size, config.hidden_size, bias=False)

    def forward(self, input):
        sequence_length = input.shape[1]

        # These are referred to as 'hidden_states' and 'gate' in the HuggingFace
        # implementation. I use 'q' and 'k' here to keep things consistent with H3 from before.
        q = self.q_proj(input)
        k = self.k_proj(input)

        conv_out = self.conv(k.transpose(1, 2))[..., :sequence_length]
        conv_out = torch.nn.functional.silu(conv_out)
        ssm_out = self.ssm(conv_out.transpose(1, 2))

        output = ssm_out * torch.nn.functional.silu(q)
        return self.out_proj(output)

The implementation of the MambaSSM will be similar to S4D and the diagonal SSM in H3. With Mamba, however, we use real (as opposed to complex) values for the SSM parameters. Here is the __init__ method:

class MambaSSM(torch.nn.Module):

    def __init__(self, config: MambaConfig):
        super(MambaSSM, self).__init__()

        self.hidden_size = config.hidden_size * config.expansion_factor
        self.state_size = config.state_size

        self.dt_proj_in = torch.nn.Linear(self.hidden_size, config.dt_rank, bias=False)
        self.dt_proj_out = torch.nn.Linear(config.dt_rank, self.hidden_size, bias=True)

        self.B_proj = torch.nn.Linear(self.hidden_size, self.state_size, bias=False)
        self.C_proj = torch.nn.Linear(self.hidden_size, self.state_size, bias=False)
        self.D = torch.nn.Parameter(torch.randn(self.hidden_size))

        A = torch.arange(1, self.state_size + 1).repeat(self.hidden_size, 1).float()
        self.A = torch.nn.Parameter(torch.log(A))

The Discretization method used by Mamba is slightly different from what we have seen before. Instead of the Bilinear Transform method, here we use something more akin to Zero-Order Hold:

$$\begin{aligned} A &= e^{\Delta \mathring{A} } & & & B = \Delta\mathring{B} \end{aligned}$$

Note: If you scroll all the way back to where Zero-Order Hold was introduced earlier, you will see that the computation of \(B\) is a bit different from what we have here. The authors of the Mamba paper chose a simpler approximation of \(B\) which they found to produce similar results (see https://github.com/state-spaces/mamba/issues/19).

Here is the implementation for computing the \(A\) and \(B\) parameters in the MambaSSM class:

def forward(self, input):
    batch_size, sequence_length = input.shape[0], input.shape[1]

    A = -torch.exp(self.A)
    B = self.B_proj(input)
    C = self.C_proj(input)
    
    dt = self.dt_proj_in(input)
    dt = self.dt_proj_out(dt)
    dt = torch.nn.functional.softplus(dt)

    A_discrete = torch.exp(
        A.view(1, 1, self.hidden_size, self.state_size) * 
        dt.view(batch_size, sequence_length, self.hidden_size, 1)
    )
    B_discrete = (
        B.view(batch_size, sequence_length, 1, self.state_size) * 
        dt.view(batch_size, sequence_length, self.hidden_size, 1)
    )
    ...

Finally, because the convolution view of the SSM is not applicable to Mamba, we compute the outputs via recurrence:

def forward(self, input):
    ...
    outputs = []
    x = torch.zeros(batch_size, self.hidden_size, self.state_size)
    for i in range(sequence_length):

        # Select the input at index i in the sequence
        current_input = input[:, i, :]
        
        # Because A, B, and C are time-variant, we need to select the 
        # values for the current time step
        current_A_discrete = A_discrete[:, i, :, :]
        current_B_discrete = B_discrete[:, i, :, :]
        current_C = C[:, i, :]

        # Update the current state
        Ax = torch.einsum('bhn, bhn -> bhn', current_A_discrete, x)
        Bu = torch.einsum('bhn, bh -> bhn', current_B_discrete, current_input)
        x = Ax + Bu

        # Compute the output at the current time step
        Cx = torch.einsum('bnx, bhn -> bh', current_C.unsqueeze(-1), x)
        Du = torch.einsum('h, bh -> bh', self.D, current_input)
        
        y = Cx + Du
        outputs.append(y)
    return torch.stack(outputs, dim=1)
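
To check that the shapes line up end to end, here is a quick forward pass through a single Mamba block using the minimal implementation above (the specific dimensions are arbitrary):

config = MambaConfig(
    state_size=16,
    hidden_size=64,
    expansion_factor=2,
    dt_rank=4,
    conv_kernel_size=4,
)
mamba = Mamba(config)

u = torch.randn(2, 10, 64)   # (batch, sequence_length, hidden_size)
print(mamba(u).shape)        # torch.Size([2, 10, 64])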

Wrapping Up

The research into sequence-to-sequence models based on SSMs is moving quickly, so, in all likelihood, by the time you are reading this, there is already something newer and greater than all of the model architectures and techniques described here. Even so, I hope this look at the fundamentals of state space models and their application to deep learning helps to demystify any future architectures that incorporate or adapt the concepts explored on this page.

All of the minimal implementations of the different model architectures can be found in the following repo: https://github.com/gnovack/easy-ssm

Thanks for reading 🫡

References

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł. and Polosukhin, I., 2017. Attention is all you need. Advances in neural information processing systems, 30.

  2. Tolstikhin, I.O., Houlsby, N., Kolesnikov, A., Beyer, L., Zhai, X., Unterthiner, T., Yung, J., Steiner, A., Keysers, D., Uszkoreit, J. and Lucic, M., 2021. MLP-Mixer: An all-MLP architecture for vision. Advances in neural information processing systems, 34, pp.24261-24272.

  3. Peng, B., Alcaide, E., Anthony, Q., Albalak, A., Arcadinho, S., Biderman, S., Cao, H., Cheng, X., Chung, M., Grella, M. and GV, K.K., 2023. RWKV: Reinventing RNNs for the transformer era. arXiv preprint arXiv:2305.13048.

  4. Dao, T. and Gu, A., 2024. Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality. arXiv preprint arXiv:2405.21060.

  5. Gu, A., Dao, T., Ermon, S., Rudra, A. and Ré, C., 2020. HiPPO: Recurrent memory with optimal polynomial projections. Advances in neural information processing systems, 33, pp.1474-1487.

  6. Gu, A., Goel, K. and Ré, C., 2021. Efficiently modeling long sequences with structured state spaces. arXiv preprint arXiv:2111.00396.

  7. Fu, D.Y., Dao, T., Saab, K.K., Thomas, A.W., Rudra, A. and Ré, C., 2022. Hungry hungry hippos: Towards language modeling with state space models. arXiv preprint arXiv:2212.14052.

  8. Gu, A. and Dao, T., 2023. Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752.

  9. Gu, A., Johnson, I., Goel, K., Saab, K., Dao, T., Rudra, A. and Ré, C., 2021. Combining recurrent, convolutional, and continuous-time models with linear state space layers. Advances in neural information processing systems, 34, pp.572-585.

  10. Gu, A., Goel, K., Gupta, A. and Ré, C., 2022. On the parameterization and initialization of diagonal state space models. Advances in Neural Information Processing Systems, 35, pp.35971-35983.

  11. Gupta, A., Gu, A. and Berant, J., 2022. Diagonal state spaces are as effective as structured state spaces. Advances in Neural Information Processing Systems, 35, pp.22982-22994.

  12. Goel, K., Gu, A., Donahue, C. and Ré, C., 2022, June. It’s raw! audio generation with state-space models. In International Conference on Machine Learning (pp. 7616-7633). PMLR.