This post is mainly based on the S4 paper, Efficiently Modeling Long Sequences with Structured State Spaces (Gu, Goel, and Ré, 2022).

There are 3 paradigms for modelling long time series: Continuous-Time Model (CTM), RNN and CNN

[Figure: paradigms]

Existing architectures cannot handle extremely long sequences (10000+ steps), due to:

  • CTM: hard to optimize
  • CNN: receptive field is fixed; high inference memory; a large receptive field requires large filters / a deep architecture
  • RNN: forgetting problem
  • Transformer: $O(n^2)$ memory complexity for attention matrix

The Structured State Space Model (S4) unifies CTM, CNN and RNN.

  • S4 has 3 different types of representations, corresponding to CTM, CNN and RNN
  • S4 can switch between the 3 representations

Background

Linear State Space Model (Linear SSM)

\[\begin{align} x'(t) &= Ax(t) + Bu(t) \\ y(t) &= Cx(t) + Du(t) \end{align}\]

where,

  • Input $u \in \mathbb{R}$
  • State $x \in \mathbb{R}^N$
  • $ A \in \mathbb{R}^{N \times N}, B \in \mathbb{R}^{N \times 1}, C \in \mathbb{R}^{1 \times N}, D\in \mathbb{R}^{1 \times 1} $

The above form is known as the continuous-time representation of the SSM.

For details on the linear SSM, please refer to the discussion of the Linear State Space Layer.
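To make the shapes concrete, here is a minimal NumPy sketch of the continuous-time SSM; the state size `N` and all matrix values are arbitrary placeholders, not a trained model.

```python
import numpy as np

# Shapes of the continuous-time SSM (all values are arbitrary placeholders).
N = 4                          # state size
A = np.random.randn(N, N)      # A in R^{N x N}
B = np.random.randn(N, 1)      # B in R^{N x 1}
C = np.random.randn(1, N)      # C in R^{1 x N}
D = np.random.randn(1, 1)      # D in R^{1 x 1}

def state_derivative(x, u):
    """x'(t) = A x(t) + B u(t), for state x of shape (N, 1) and scalar input u."""
    return A @ x + B * u

def output(x, u):
    """y(t) = C x(t) + D u(t), returned as a (1, 1) array."""
    return C @ x + D * u
```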

HiPPO

  • A class of matrices $A \in \mathbb{R}^{N \times N}$
  • $A$ allows the state $x(t)$ to memorize the history of input $u(t)$
  • LSSL adopts such a matrix $A$ and outperforms various SOTA models on long-range time series benchmarks

3 Representations of SSM

The SSM has 3 different types of representations: Continuous-time, Recurrent and Convolutional.

[Figure: ssm-3-form]

Continuous-Time Representation

[Figure: ssm-ctm-form]

The continuous-time representation corresponds to the function-to-function map $u(t) \rightarrow y(t)$, which cannot be directly estimated from the discrete data $(u_0, y_0), \dots, (u_k, y_k)$.

Recurrent Representation
A continuous SSM($A,B,C,D$) can be discretized into SSM($\bar{A},\bar{B},\bar{C},\bar{D}$), where

\[\begin{align} x_k &= \bar{A}x_{k-1} + \bar{B}u_k \\ y_k &= \bar{C}x_k + \bar{D}u_k \end{align}\]

The above transformation between continuous and discrete form is well studied. Refer to discussions around Generalized Bilinear Transform (GBT).

To simplify the discussion, assume $D = 0$: the term $Du$ is just a skip connection from the input $u$ and is easy to estimate separately.

Let the step size between $u_{k-1}$ and $u_k$ be $\Delta$; then we have the discrete SSM:

\[\begin{align} \bar{A} &= (I - \Delta/2 \cdot A)^{-1}(I + \Delta/2 \cdot A) \\ \bar{B} &= (I - \Delta/2 \cdot A)^{-1} \Delta B \\ \bar{C} &= C \end{align}\]

The Recurrent Representation of the SSM can be viewed as an RNN, as discussed in the Expressivity of LSSL.
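As a concrete illustration, here is a minimal NumPy sketch of the bilinear discretization above and the resulting recurrence; the SSM matrices, the step size `delta`, and the input are arbitrary placeholders.

```python
import numpy as np

def discretize(A, B, C, step):
    """Bilinear discretization of a continuous SSM (D assumed to be 0)."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - (step / 2.0) * A)   # (I - Δ/2 · A)^{-1}
    Ab = inv @ (I + (step / 2.0) * A)           # A_bar
    Bb = inv @ (step * B)                       # B_bar
    return Ab, Bb, C                            # C_bar = C

def run_recurrence(Ab, Bb, Cb, u):
    """x_k = A_bar x_{k-1} + B_bar u_k ;  y_k = C_bar x_k."""
    x = np.zeros((Ab.shape[0], 1))
    ys = []
    for u_k in u:
        x = Ab @ x + Bb * u_k
        ys.append((Cb @ x).item())
    return np.array(ys)

# Toy usage with a random (stable-ish) SSM; all values are purely illustrative.
N, L, delta = 4, 16, 1.0 / 16
A = -np.eye(N) + 0.1 * np.random.randn(N, N)
B, C = np.random.randn(N, 1), np.random.randn(1, N)
Ab, Bb, Cb = discretize(A, B, C, delta)
y = run_recurrence(Ab, Bb, Cb, np.random.randn(L))
```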

Convolutional Representation
The recurrent representation can be unrolled into convolutional form: to produce the output $y_j$, each input $u_i$ with $i \leq j$ is multiplied by $\bar{B}$ and $\bar{C}$ once and by $\bar{A}$ a total of $j-i$ times

\[y_k = \bar{C}\bar{A}^k\bar{B} u_0 + \bar{C}\bar{A}^{k-1}\bar{B} u_1 + ... + \bar{C}\bar{A}\bar{B} u_{k-1} + \bar{C}\bar{B}u_k\]

The above computation is equivalent to a CNN with a length-$L$ 1D convolutional filter $\bar{K}$:

\[\bar{K} \in \mathbb{R}^L := (\bar{C}\bar{B}, \bar{C}\bar{A}\bar{B}, ..., \bar{C}\bar{A}^{L-1}\bar{B})\] \[y = \bar{K} * u\]
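The sketch below (again with arbitrary placeholder matrices) materializes $\bar{K}$ and checks that causal convolution with it reproduces the recurrent computation.

```python
import numpy as np

def ssm_kernel(Ab, Bb, Cb, L):
    """Materialize K_bar = (C_bar B_bar, C_bar A_bar B_bar, ..., C_bar A_bar^{L-1} B_bar)."""
    K, x = [], Bb
    for _ in range(L):
        K.append((Cb @ x).item())
        x = Ab @ x
    return np.array(K)

def causal_conv(K, u):
    """y_k = sum_{j <= k} K_j u_{k-j} (non-circular convolution, truncated to len(u))."""
    return np.convolve(K, u)[: len(u)]

# Toy check that the convolutional and recurrent views agree (placeholder matrices).
N, L = 4, 32
Ab = 0.9 * np.eye(N) + 0.01 * np.random.randn(N, N)
Bb, Cb = np.random.randn(N, 1), np.random.randn(1, N)
u = np.random.randn(L)

y_conv = causal_conv(ssm_kernel(Ab, Bb, Cb, L), u)

x, y_rec = np.zeros((N, 1)), []
for u_k in u:
    x = Ab @ x + Bb * u_k
    y_rec.append((Cb @ x).item())

assert np.allclose(y_conv, np.array(y_rec))
```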

Advantage of Each Representation

  • Continuous-Time: theoretically grounded and easy to analyze (e.g., long-term memory, changes in sampling resolution)
  • Recurrent: constant memory requirement in autoregressive sequence generation
  • Convolutional: parallel training

Technical Challenges of SSM

Problem 1: $A$ needs to be carefully initialized to enable the SSM to learn long-range dependencies

As discussed in the LSSL post, if $A$ is randomly initialized, the SSM performs poorly even on very simple sequence modelling tasks.

If $A$ is initialized as a HiPPO Matrix:

\[A_{nk} = -\begin{cases} (2n+1)^{\frac{1}{2}}(2k+1)^{\frac{1}{2}} & \text{if } n>k \\ n+1 & \text{if } n=k \\ 0 & \text{if } n<k \end{cases}\] \[B_n = (2n+1)^{\frac{1}{2}}\]

The state sequence $x_0, …, x_k$ will try to memorize $u_0, …, u_k$.
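A minimal sketch of this initialization in NumPy, using 0-indexed $n, k$ to match the formula above:

```python
import numpy as np

def make_hippo(N):
    """HiPPO matrix A (N x N) and vector B (N,) as defined above, with 0-indexed n, k."""
    q = np.sqrt(2.0 * np.arange(N) + 1.0)       # q_n = (2n+1)^{1/2}
    A = np.tril(np.outer(q, q))                 # (2n+1)^{1/2} (2k+1)^{1/2} for n >= k
    np.fill_diagonal(A, np.arange(N) + 1.0)     # diagonal entries n + 1
    return -A, q                                # overall minus sign; B_n = (2n+1)^{1/2}

A, B = make_hippo(4)
```

With this convention, `A[0, 0] = -1` and `B[0] = 1`, matching the $n = k = 0$ case of the formula.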

Refer to HiPPO & Sequential Data Compression post.

On sequential MNIST problem,

  • $A \sim \mathcal{N}(0,1)$: 60% accuracy
  • $A \sim \operatorname{HiPPO}$: 98% accuracy

Problem 2: direct computation of the convolution kernel $\bar{K}$ is costly and numerically unstable

  • Computing $\bar{K}$ involves the repeated matrix power $\bar{A}^{L-1}$
    • This can be numerically unstable if the largest eigenvalue of $\bar{A}$ is not close to 1 in magnitude: the powers either explode or vanish (see the toy illustration below)
  • Automatic differentiation through $\bar{A}^{L-1}$ also leads to the vanishing gradient problem
  • Computing $\bar{A}^{L-1}$ leads to $O(N^2L)$ complexity. Ideally, we want complexity close to $O(L)$
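A toy illustration of the stability issue; the matrix below is an arbitrary scaled rotation, not a discretized HiPPO matrix.

```python
import numpy as np

# ||A_bar^L|| explodes or vanishes unless the spectral radius is very close to 1.
L = 1000
for rho in (0.95, 1.0, 1.05):                   # spectral radius of a scaled 2x2 rotation
    Ab = rho * np.array([[np.cos(0.1), -np.sin(0.1)],
                         [np.sin(0.1),  np.cos(0.1)]])
    print(rho, np.linalg.norm(np.linalg.matrix_power(Ab, L)))
# rho = 0.95 -> roughly 1e-22 (vanishes), rho = 1.0 -> ~1.4, rho = 1.05 -> roughly 1e21 (explodes)
```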

Structured SSM (S4)

The LSSL paper attempted to address Problem 2; however, the proposed method was unsuccessful:

  • Its Theorem 2 proposes a way to efficiently compute the convolutional kernel $\bar{K}$
  • However, the authors found that the algorithm implied by Theorem 2 is numerically unstable
  • For details, see discussions on Limitations of LSSL.

S4 proposes a new parameterization that makes the kernel $\bar{K}$ efficient to compute.

The authors discussed:

  • Conjugation and Diagonalization of $A$
  • Normal Plus Low-Rank Decomposition improves naive Diagonalization
  • Discussions on computation and space complexity

Conjugation / Diagonalization of HiPPO Matrix $A$

The problem with the CNN kernel $ \bar{K} := (\bar{C}\bar{B}, \bar{C}\bar{A}\bar{B}, …, \bar{C}\bar{A}^{L-1}\bar{B}) $ is the repeated matrix multiplication by $\bar{A}$.

The first attempt is to solve this by Conjugation.

Diagonalization

\[A=VDV^{-1}\]

where,

  • $D$ is diagonal
  • $V$ is composed of the eigenvectors of $A$

Jordan Canonical Form

\[A=VJV^{-1}\]

where,

  • $J$ is block diagonal
  • $V$ is composed of the generalized eigenvectors of $A$

Generalized Eigenvectors

  • A generalized eigenvector of order $k$ is a nonzero vector $v$ that satisfies
    • $(A-\lambda I)^kv=0$ for some eigenvalue $\lambda$
    • $(A-\lambda I)^{k-1}v \not=0$
    • The case where $k=1$ corresponds to regular eigenvectors
  • Why Generalized Eigenvectors?
    • Generalized eigenvectors are useful when $A$ does not have enough linearly independent eigenvectors to diagonalize it (see the small example after this list)
    • This is often due to some eigenvalues having a higher algebraic multiplicity than geometric multiplicity
  • The generalized eigenvectors form what is known as the generalized eigenspace for the eigenvalue $\lambda$. They provide a basis for the Jordan blocks in the Jordan canonical form of $A$.
  • Jordan Canonical Form
    • Each Jordan block corresponds to an eigenvalue and is associated with a chain of generalized eigenvectors
    • The first vector in the chain is a regular eigenvector (generalized eigenvector of order 1)
    • The subsequent vectors in the chain are higher-order generalized eigenvectors
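A small NumPy example of a defective matrix, a single $2 \times 2$ Jordan block, whose eigenvectors do not span the space, which is exactly the situation where a generalized eigenvector is needed:

```python
import numpy as np

# A 2x2 Jordan block: eigenvalue 3 with algebraic multiplicity 2 but geometric multiplicity 1.
J = np.array([[3.0, 1.0],
              [0.0, 3.0]])

eigvals, eigvecs = np.linalg.eig(J)
print(eigvals)                   # [3. 3.]
print(np.linalg.det(eigvecs))    # ~0: the returned eigenvectors are (numerically) parallel,
                                 # so J cannot be diagonalized

# v = e_2 is a generalized eigenvector of order 2: (J - 3I) v != 0 but (J - 3I)^2 v = 0.
v = np.array([0.0, 1.0])
M = J - 3.0 * np.eye(2)
print(M @ v)        # [1. 0.]  (nonzero)
print(M @ M @ v)    # [0. 0.]
```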

Conjugation

Let

  • $A$ be a square matrix
  • $V$ be an invertible matrix

The conjugate of $A$ by $V$, denoted $A’$, is defined as the matrix

\[A'=V^{-1}AV\]

where,

  • $A’$ is not necessarily diagonal / block diagonal
  • $V$ is not necessarily composed of the eigenvectors of $A$

This process maps $A$ to a new matrix $A’$ using $V$:

  • $V^{-1}$ transforms the coordinate system, moving from the standard basis to the new basis defined by $V$
  • $A$ is then applied in this new basis, performing the transformation that $A$ represents
  • Finally, $V$ transforms the coordinate system back to the original basis

Properties

  • Conjugation is a change of basis operation in linear algebra that doesn’t alter the linear transformation’s action on the space
  • Hence, conjugation maintains the eigenvalues of $A$ and, more generally, its spectral properties

Relationships

  • Jordan Canonical Form is a special case of Conjugation
  • Diagonalization is a special case of Jordan Canonical Form

Lemma 3.1

Conjugation is an equivalence relation on SSMs $(A,B,C) \sim (V^{-1}AV, V^{-1}B, CV)$.

SSM$(A,B,C)$:

\[\begin{align} x'(t) &= Ax(t) + Bu(t) \\ y(t) &= Cx(t) \end{align}\]

SSM$(V^{-1}AV, V^{-1}B, CV)$:

\[\begin{align} \tilde{x}'(t) &= V^{-1}AV \tilde{x}(t) + V^{-1}Bu(t) \\ y(t) &= CV \tilde{x}(t) \end{align}\]

The above 2 systems are equivalent: substituting $ x(t) = V\tilde{x}(t) $ (a change of basis of the state $x$) into SSM$(A,B,C)$ yields SSM$(V^{-1}AV, V^{-1}B, CV)$.
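A quick numerical check of Lemma 3.1 with random placeholder matrices: the terms $CA^kB$, which determine the input-output map of the system, agree for the two parameterizations.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 5
A = rng.standard_normal((N, N))
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
V = rng.standard_normal((N, N))          # invertible with probability 1
Vinv = np.linalg.inv(V)

A2, B2, C2 = Vinv @ A @ V, Vinv @ B, C @ V

# C A^k B = (C V) (V^{-1} A V)^k (V^{-1} B) for every k, so the two SSMs are equivalent.
for k in range(6):
    lhs = C @ np.linalg.matrix_power(A, k) @ B
    rhs = C2 @ np.linalg.matrix_power(A2, k) @ B2
    assert np.allclose(lhs, rhs)
```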

Unfortunately, the naive application of diagonalization does not work: the matrix $V$ that diagonalizes the HiPPO matrix is exponentially ill-conditioned (its entries grow exponentially with the state size $N$), so conjugating by it is numerically unstable.

Normal Plus Low-Rank (NPLR) Decomposition

The numerical stability issues discussed above imply that we should only conjugate by well-conditioned matrices $V$.

  • The ideal scenario is that $A$ is diagonalizable by a perfectly conditioned matrix
    • i.e., a unitary matrix (norm-preserving; all eigenvalues have absolute value 1)
  • By the Spectral Theorem of linear algebra, this requires $A$ to be a normal matrix
  • However, the HiPPO matrix is not normal

Fortunately, the HiPPO matrix can be decomposed as the sum of a normal matrix and a low-rank matrix.

This decomposition is still not useful by itself:

  • Unlike a diagonal matrix, powering up this sum of 2 matrices is still slow and not easily optimized
  • 3 techniques are required to make it useful
    • Truncated generating function
    • Woodbury identity
    • Cauchy kernel

We will skip the details of the NPLR computation since the authors later discovered that the “easy to implement” diagonal state space model (S4D) is about as effective as S4.

Theorem 1

All HiPPO matrices have a Normal Plus Low-Rank (NPLR) representation:

\[A = V \Lambda V^* - P Q^\top = V (\Lambda - ( V^* P ) (V^* Q)^*) V^*\]

where

  • $V \in \mathbb{C}^{N \times N}$ is unitary
  • $\Lambda$ is diagonal
  • $P, Q \in \mathbb{R}^{N \times r}$ form a low-rank factorization
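As a sanity check of this structure for the HiPPO matrix defined earlier, here is a sketch following, as far as I recall, the rank-1 construction used in the public S4 code, with $P_n = (n + \tfrac{1}{2})^{1/2}$ and $Q = P$: $A + PP^\top$ is a normal matrix (in fact skew-symmetric plus a constant diagonal), so $A$ is normal minus low-rank.

```python
import numpy as np

def make_hippo(N):
    """HiPPO matrix A as defined in the HiPPO section (0-indexed n, k)."""
    q = np.sqrt(2.0 * np.arange(N) + 1.0)
    A = np.tril(np.outer(q, q))
    np.fill_diagonal(A, np.arange(N) + 1.0)
    return -A

N = 8
A = make_hippo(N)
P = np.sqrt(np.arange(N) + 0.5)          # rank-1 factor, P_n = (n + 1/2)^{1/2}

normal_part = A + np.outer(P, P)         # so A = normal_part - P P^T
S = normal_part + 0.5 * np.eye(N)        # strip the constant -1/2 diagonal

assert np.allclose(S, -S.T)              # S is skew-symmetric ...
assert np.allclose(normal_part @ normal_part.T,
                   normal_part.T @ normal_part)   # ... hence normal_part is normal
```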

S4 Computational Complexity

Theorem 2 (S4 Recurrence)

Given any step size $\Delta$, computing one step of the SSM recurrence can be done in $O(N)$ operations where $N$ is the state size

Theorem 3 (S4 Convolution / S4’s core technical contribution)

Given any step size $\Delta$, computing the SSM convolution filter $\bar{K}$ can be reduced to 4 Cauchy multiplies, requiring only $\tilde{O}(N + L)$ operations and $O(N + L)$ space

S4 Layer Architecture

Following the NPLR Decomposition, an S4 layer is parameterized as an SSM $ (\Lambda - PQ^*, B, C) $, which contains 5 trainable parameters: $\Lambda, P, Q, B, C$.

Follow-up work found that this version of S4 can sometimes suffer from numerical instabilities when the A matrix has eigenvalues on the right half-plane. It introduced a slight change to the NPLR parameterization for S4 from $ \Lambda - PQ^* $ to $ \Lambda - P P^* $ that corrects this potential problem.

[Figure: complexity]

Complexity of various sequence models in terms of sequence length (L), batch size (B), and hidden dimension (H).

Experiments

Efficiency Benchmarks

[Figure: efficiency]

Left: The S4 parameterization with NPLR is asymptotically more efficient than the LSSL.

Right: S4 vs efficient Transformers. S4’s speed and memory use is competitive with the most efficient Transformer in a parameter-matched setting.

Long Range Dependency (LRD) Benchmarks

Long Range Arena (LRA)

[Figure: result-lra]

Performance on Long Range Arena. Tasks are extremely challenging (e.g., Path-X requires reasoning about LRDs over sequences of length $128 \times 128 = 16384$). S4 significantly outperforms SOTA models. For reproducibility details, see the paper’s Appendix D.5.

Speech Classification

  • SC10 subset of the Speech Commands dataset (length-16000)
  • S4 achieves 98.3% accuracy, higher than all baselines that use the 100x shorter MFCC features
  • S4 outperforms WaveGAN-D, a CNN specifically designed for raw speech, even though WaveGAN-D has 90x more parameters and incorporates many architectural heuristics

[Figure: result-speech]

General Sequence Model

These experiments test model performance on general sequential data, including:

  • Discriminative modelling on
    • Flattened image
    • Randomly permuted image (permutation is fixed)
  • Generative modelling on
    • Image
    • Text

Discriminative modelling / Pixel-level 1-D image classification

[Figure: pmnist]

Generative modelling / CIFAR-10 density estimation

[Figure: cifar-den-est]

CIFAR images are $32 \times 32$ pixel images with 3 color channels. Each image consists of $32 \times 32 \times 3 = 3072$ RGB values.

To compute bits per dimension, each image is flattened and an autoregressive model predicts a conditional probability for the true value at each of the 3072 steps:

\[p_1, p_2, \dots, p_{3072}\]

Then we can compute the total log likelihood:

\[\text{Total Log Likelihood} = \sum_{i=1}^{3072} \log (p_i)\]

A higher total log likelihood indicates that the model is doing a better job of predicting each subpixel, meaning it has learned a more accurate representation of the data distribution.

The total log likelihood is then negated, normalized by the number of dimensions, and converted to base 2 (lower is better):

\[\text{Bits per Dimension} = \frac{ -\text{Total Log Likelihood} }{ 3072 \cdot \log(2)}\]
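A minimal sketch of this computation; the per-subpixel probabilities are random stand-ins for a model's predictions of the true subpixel values.

```python
import numpy as np

# Random stand-ins for the probabilities an autoregressive model assigns to the 3072 true subpixels.
p = np.random.uniform(0.05, 0.95, size=3072)

total_log_likelihood = np.sum(np.log(p))                      # natural log; always <= 0
bits_per_dim = -total_log_likelihood / (3072 * np.log(2.0))   # negate, normalize, convert to base 2

print(bits_per_dim)   # lower is better
```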

Generative modelling / WikiText-103 language modeling

[Figure: wikitext]

Ablations

To answer the questions:

  • How important is the HiPPO initialization?
  • How important is training the SSM on top of HiPPO?

[Figure: hippo-init]

CIFAR-10 classification with unconstrained, real-valued SSMs with various initializations. Top: Train accuracy. Bottom: Validation accuracy. For all methods, training the SSM led to improvements in both training and validation curves. However, different initializations result in largely different generalization gaps.