This post is mainly based on [Simplified State Space Layers for Sequence Modeling, 2023], the S5 paper.

Parallel Scan

Scan Operation

Given a binary associative operator $\bullet$ and a sequence of $L$ elements $[a_1, a_2, …, a_L]$, the scan operation (or all-prefix-sum operation) returns the sequence:

\[\left[ a_1, \; (a_1 \bullet a_2), \; \ldots, \; (a_1 \bullet a_2 \bullet \ldots \bullet a_L) \right]\]
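For concreteness, here is a minimal sequential reference implementation of the scan in JAX (using ordinary addition as the operator); the parallel versions below can be checked against it:

```python
import jax.numpy as jnp

def sequential_scan(op, a):
    """Reference scan: out[k] = a[0] • a[1] • ... • a[k], computed left to right."""
    out = [a[0]]
    for x in a[1:]:
        out.append(op(out[-1], x))
    return jnp.stack(out)

a = jnp.array([1.0, 2.0, 3.0, 4.0])
print(sequential_scan(jnp.add, a))  # [ 1.  3.  6. 10.], i.e. jnp.cumsum(a)
```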

Parallel Scan Operation

The parallelization of scan operations has been well studied in

  • [Prefix sums and their applications, 1990]
  • [Parallel computing using the prefix problem, 1994]

Many standard scientific computing libraries contain efficient implementations.

Naive Parallel Scan

ps-naive

A naive parallel algorithm takes advantage of parallel processing and has time complexity $O(\log_2 n)$.

However, this algorithm is not compute-efficient: while a sequential scan requires $O(n)$ operations, the naive algorithm requires $O(n \log_2 n)$ operations. At level $d=1$, the algorithm assumes that $n-1$ parallel processors are available, which is not realistic for large arrays.

Observe that there are many redundant operations. For example, to compute $\sum (x_0, \ldots, x_4)$, the naive algorithm performs 3 operations:

  • $x_3 + x_4$
  • $(x_1 + x_2) + (x_3 + x_4)$
  • $x_0 + (x_1 + x_2) + (x_3 + x_4)$

A more efficient way is to wait until $\sum (x_0, \ldots, x_3)$ is computed, then perform a single operation $\sum (x_0, \ldots, x_3) + x_4$. To achieve this, we need an algorithm that determines which sums to execute first and which to hold.
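A minimal sketch of the naive scan (a Hillis–Steele style algorithm), assuming the input length is a power of two and using addition; each level performs $O(n)$ additions that are independent of each other and could run in parallel:

```python
import math
import jax.numpy as jnp

def naive_scan(a):
    """Naive parallel scan sketch: O(log2 n) levels, O(n log2 n) total additions."""
    x = jnp.asarray(a, dtype=jnp.float32)
    n = x.shape[0]
    for d in range(int(math.log2(n))):
        offset = 2 ** d
        # Every element i >= offset adds the element `offset` positions to its left;
        # all additions within one level are independent (parallelizable).
        shifted = jnp.concatenate([jnp.zeros(offset), x[:-offset]])
        x = x + shifted
    return x

print(naive_scan([1, 2, 3, 4, 5, 6, 7, 8]))  # [ 1.  3.  6. 10. 15. 21. 28. 36.]
```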

Compute-Efficient Parallel Scan

To determine the optimal sequence of operations, we build a balanced binary tree:

ps-up-sweep

In the Up-Sweep (Reduce) Phase, we perform pairwise operations at each level to compute partial sums, holding the intermediate results in place.

The root node therefore holds $\sum(x_0, \ldots, x_7)$, the sum of all elements in the array.

ps-down-sweep

In the Down-Sweep Phase, we perform a sequence of shuffle and addition operations, where:

  • At level $d=0$, initialize the root node to 0
  • At each subsequent level $d$
    • Identify $2^d$ nodes at this level
    • For each node, denote the intermediate values held by the node and its children as:
      • Node: $x_p$
      • Left child: $x_l$
      • Right child: $x_r$
    • Perform the following updates (using the children’s values before the update)
      • Left child = $x_p$
      • Right child = $x_p + x_l$

The resulting sequence is the exclusive scan: a leading identity element followed by the all-prefix-sum excluding its last element. That last element, the total, was already computed in the Up-Sweep Phase.
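A minimal sketch of the two phases, again assuming addition and a power-of-two length (the inner loops are written sequentially for clarity, but all updates within one level are independent):

```python
import math
import jax.numpy as jnp

def blelloch_scan(a):
    """Work-efficient scan sketch (up-sweep + down-sweep); assumes the length n
    is a power of two. Returns the exclusive scan and the total."""
    x = jnp.asarray(a, dtype=jnp.float32)
    n = x.shape[0]
    levels = int(math.log2(n))

    # Up-Sweep (Reduce): pairwise partial sums; the last slot ends up with the total.
    for d in range(levels):
        step = 2 ** (d + 1)
        for i in range(step - 1, n, step):          # independent within a level
            x = x.at[i].add(x[i - step // 2])
    total = x[-1]

    # Down-Sweep: root := identity, then left child := parent,
    # right child := parent + (left child's old value).
    x = x.at[-1].set(0.0)
    for d in reversed(range(levels)):
        step = 2 ** (d + 1)
        for i in range(step - 1, n, step):
            old_left = x[i - step // 2]
            x = x.at[i - step // 2].set(x[i])       # left child = parent
            x = x.at[i].add(old_left)               # right child = parent + old left
    return x, total

exclusive, total = blelloch_scan([1, 2, 3, 4, 5, 6, 7, 8])
print(exclusive)  # [ 0.  1.  3.  6. 10. 15. 21. 28.]
print(total)      # 36.0
```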

This algorithm

  • Has time complexity $2 \log_2 n$, if we can compute $n/2$ additions in parallel
  • Requires $2(n - 1)$ additions, $n - 1$ swaps

Note that

  • In this example, we use addition $+$, which is an associative operator
  • Parallel scan works on any associative operator
  • An associative operator is not necessarily commutative, therefore the right child must be computed as $x_p + x_l$, not $x_l + x_p$

Parallel Scan on SSM

Parallel scan requires a binary associative operator. Therefore, we need to propose an associative operator for SSM updates.

Let $\bullet$ be:

\[q_i \bullet q_j := ( q_{j,a} \odot q_{i,a}, \quad q_{j,a} \otimes q_{i,b} + q_{j,b} )\]

where

  • $\odot$: matrix-matrix multiplication
  • $\otimes$: matrix-vector multiplication
  • $+$: element-wise addition

It can be shown that the $\bullet$ operator is associative.
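As a quick sanity check, here is a sketch of this operator in JAX together with a numerical test of associativity on random elements (sizes and names are illustrative):

```python
import jax
import jax.numpy as jnp

def op(q_i, q_j):
    """The proposed operator on tuples q = (q_a, q_b): q_a is a matrix, q_b a vector.
    q_i is the earlier element in the sequence, q_j the later one."""
    a_i, b_i = q_i
    a_j, b_j = q_j
    return (a_j @ a_i, a_j @ b_i + b_j)

# Numerical check of associativity on random elements (illustrative size P).
P = 3
keys = jax.random.split(jax.random.PRNGKey(0), 6)
q1 = (jax.random.normal(keys[0], (P, P)), jax.random.normal(keys[1], (P,)))
q2 = (jax.random.normal(keys[2], (P, P)), jax.random.normal(keys[3], (P,)))
q3 = (jax.random.normal(keys[4], (P, P)), jax.random.normal(keys[5], (P,)))

lhs = op(op(q1, q2), q3)   # (q1 • q2) • q3
rhs = op(q1, op(q2, q3))   # q1 • (q2 • q3)
print(jnp.allclose(lhs[0], rhs[0], atol=1e-5),
      jnp.allclose(lhs[1], rhs[1], atol=1e-5))   # True True
```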

To show that this $\bullet$ operator is equivalent to a state update in an SSM, let’s consider the following example.

An Example

Consider a linear state update $x_k = \bar{A} x_{k-1} + \bar{B} u_k$ with inputs $u_{1:4}$ and initial state $x_0=0$. We have:

\[\begin{align} x_1 &= \bar{B} u_1\\ x_2 &= \bar{A}\bar{B} u_1 + \bar{B} u_2\\ x_3 &= \bar{A}^2\bar{B} u_1 + \bar{A}\bar{B} u_2 + \bar{B} u_3\\ x_4 &= \bar{A}^3\bar{B} u_1 + \bar{A}^2\bar{B} u_2 + \bar{A}\bar{B} u_3 + \bar{B} u_4\\ \end{align}\]

Sequential Compute

To show why the $\bullet$ operator implements this state update, define $c_{1:L}$, where each element $c_k$ is the tuple:

\[c_k = (c_{k,a}, \: c_{k, b}) := (\bar{A}, \: \bar{B} u_k )\]

Note that $c_{1:L}$ can be easily initialized prior to the parallel scan.

To compute the recurrence sequentially, initialize $c_{1:4}$ and $s_0$ as:

\[\begin{align} c_{1:4} &:= ( (\bar{A}, \: \bar{B} u_1 ), (\bar{A}, \: \bar{B} u_2 ), (\bar{A}, \: \bar{B} u_3 ), (\bar{A}, \: \bar{B} u_4 ) ) \\ s_0 &:= (I, 0) \end{align}\]

where $I$ is identity matrix.

Then we have:

\[\begin{align} s_1 &= s_0 \bullet c_1 = (I, 0) \bullet (\bar{A}, \bar{B} u_1) = (\bar{A}, \bar{B} u_1) \\ s_2 &= s_1 \bullet c_2 = (\bar{A}^2, \bar{A}\bar{B} u_1 + \bar{B} u_2)\\ s_3 &= s_2 \bullet c_3 = (\bar{A}^3, \bar{A}^2\bar{B} u_1 + \bar{A}\bar{B} u_2 + \bar{B} u_3)\\ s_4 &= s_3 \bullet c_4 = (\bar{A}^4, \bar{A}^3\bar{B} u_1 + \bar{A}^2\bar{B} u_2 + \bar{A}\bar{B} u_3 + \bar{B} u_4)\\ \end{align}\]

Note that the second element in the tuple $s_i$ is the Linear SSM state update result.
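The following sketch applies the operator sequentially to a small random system and checks that the second tuple element of $s_k$ matches the recurrence $x_k = \bar{A}x_{k-1} + \bar{B}u_k$ (all sizes and values are illustrative):

```python
import jax
import jax.numpy as jnp

def op(q_i, q_j):
    a_i, b_i = q_i
    a_j, b_j = q_j
    return (a_j @ a_i, a_j @ b_i + b_j)

# Illustrative random discretized SSM and inputs.
P, L = 3, 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
A_bar = 0.5 * jax.random.normal(k1, (P, P))
B_bar = jax.random.normal(k2, (P, 1))
u = jax.random.normal(k3, (L,))

# Reference: the sequential recurrence x_k = A_bar x_{k-1} + B_bar u_k, x_0 = 0.
x, xs_ref = jnp.zeros(P), []
for k in range(L):
    x = A_bar @ x + B_bar[:, 0] * u[k]
    xs_ref.append(x)

# Same recurrence via the operator: s_k = s_{k-1} • c_k with c_k = (A_bar, B_bar u_k).
s = (jnp.eye(P), jnp.zeros(P))                       # s_0 = (I, 0)
for k in range(L):
    s = op(s, (A_bar, B_bar[:, 0] * u[k]))
    print(jnp.allclose(s[1], xs_ref[k], atol=1e-5))  # True at every step
```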

Parallel Scan

Let’s compute this scan operation using parallel scan.

First execute the Up-Sweep Phase:

d = 1:

\[\begin{align} r_{2, d=1} &:= c_1 \bullet c_2 \\ r_{4, d=1} &:= c_3 \bullet c_4 \\ \end{align}\]

d = 2:

\[r_{4, d=2} := r_{2, d=1} \bullet r_{4, d=1} = (c_1 \bullet c_2) \bullet (c_3 \bullet c_4)\]

This is the final element of the parallel scan.

Then execute the Down-Sweep Phase:

Set $r_{4, d=0} = (I, 0)$

d = 1:

For node $r_{4, d=0} = (I, 0)$,

  • Left child $r_{2, d=1} = (I, 0)$
  • Right child $r_{4, d=1} = (I, 0) \bullet (c_1 \bullet c_2) = c_1 \bullet c_2$, where $c_1 \bullet c_2$ is the up-sweep value previously stored at the left child

d = 2:

For node $r_{4, d=1} = c_1 \bullet c_2$,

  • Left child $r_{3, d=2} = c_1 \bullet c_2$ (the parent’s value)
  • Right child $r_{4, d=2} = (c_1 \bullet c_2) \bullet c_3$ (the parent’s value combined with the up-sweep value $c_3$ at the left child)

For node $r_{2, d=1} = (I, 0)$,

  • Left child $r_{1, d=2} = (I, 0)$
  • Right child $r_{2, d=2} = (I, 0) \bullet c_1 = c_1$

Excluding the first element $(I, 0)$, we end up with the first $n-1$ elements of the inclusive scan (the final element $c_1 \bullet c_2 \bullet c_3 \bullet c_4$ was already computed in the Up-Sweep Phase):

\[[c_1, c_1 \bullet c_2, c_1 \bullet c_2 \bullet c_3]\]
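In practice one would not hand-code the tree; libraries provide the scan directly. A minimal sketch using `jax.lax.associative_scan` with a batched version of the operator, checked against the sequential recurrence (illustrative random system):

```python
import jax
import jax.numpy as jnp

def binary_op(q_i, q_j):
    """Batched • operator: each leaf carries a leading scan (time) axis."""
    a_i, b_i = q_i
    a_j, b_j = q_j
    return (a_j @ a_i, (a_j @ b_i[..., None]).squeeze(-1) + b_j)

P, L = 3, 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
A_bar = 0.5 * jax.random.normal(k1, (P, P))
B_bar = jax.random.normal(k2, (P, 1))
u = jax.random.normal(k3, (L,))

# c_k = (A_bar, B_bar u_k), stacked along the time axis.
c = (jnp.broadcast_to(A_bar, (L, P, P)), u[:, None] * B_bar.T)

# Parallel (associative) scan; the second tuple element of the result
# holds the SSM states x_1, ..., x_L.
s = jax.lax.associative_scan(binary_op, c)

# Check against the sequential recurrence x_k = A_bar x_{k-1} + B_bar u_k.
x, xs = jnp.zeros(P), []
for k in range(L):
    x = A_bar @ x + B_bar[:, 0] * u[k]
    xs.append(x)
print(jnp.allclose(s[1], jnp.stack(xs), atol=1e-5))  # True
```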

S5

  • S4 vs S5
    • S4 layer: many independent SISO (single input single output) SSMs
    • S5 layer: one MIMO (multi input multi output) SSM
  • S5
    • $A$ is diagonalized, similar to S4D
    • Recurrent computation only, no convolution
    • Parallel scan can match the computational efficiency of S4
    • Can handle time-varying SSMs and irregularly sampled observations
  • Results
    • SOTA on several long-range sequence modeling tasks

Background

Linear State Space Model

Continuous-time linear SSMs:

\[\begin{align} x'(t) &= Ax(t) + Bu(t)\\ y(t) &= Cx(t) + Du(t) \end{align}\]

Given a constant step size $\Delta$, the continuous SSM can be discretized:

\[\begin{align} x_k &= \bar{A}x_{k-1} + \bar{B}u_k\\ y_k &= \bar{C}x_k + \bar{D}u_k \end{align}\]

with Euler, bilinear or zero-order hold (ZOH) methods.

Linear SSM with Parallel Scan

A length $L$ linear recurrence of a discretized SSM can be computed by the scan operation.

Complexity

Given $L$ processors, the linear recurrence of the discretized SSM above can be computed in parallel time $O(T_\odot \log L)$, where $T_\odot$ is the cost of a single matrix-matrix multiplication.

For a diagonal matrix $\bar{A} \in \mathbb{R}^{P \times P}$, each binary operation costs $O(P)$, so the parallel time is $O(P \log L)$ and the space cost is $O(PL)$.
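For the diagonal case, the matrix-matrix and matrix-vector products in the operator reduce to elementwise products, which is what gives the $O(P)$ per-element cost. A minimal sketch (illustrative sizes):

```python
import jax
import jax.numpy as jnp

def diag_binary_op(q_i, q_j):
    """• specialized to a diagonal state matrix: the matrix-matrix and
    matrix-vector products become elementwise, so each combine costs O(P)."""
    a_i, b_i = q_i
    a_j, b_j = q_j
    return (a_j * a_i, a_j * b_i + b_j)

P, L = 4, 8
k1, k2 = jax.random.split(jax.random.PRNGKey(1))
lam_bar = 0.9 * jax.random.uniform(k1, (P,))        # diagonal of A_bar
bu = jax.random.normal(k2, (L, P))                  # B_bar u_k stacked over time

# Each leaf has shape (L, P): O(PL) memory, O(log L) parallel depth.
states = jax.lax.associative_scan(
    diag_binary_op, (jnp.broadcast_to(lam_bar, (L, P)), bu))[1]

# Check against the sequential recurrence x_k = lam_bar * x_{k-1} + B_bar u_k.
x, xs = jnp.zeros(P), []
for k in range(L):
    x = lam_bar * x + bu[k]
    xs.append(x)
print(jnp.allclose(states, jnp.stack(xs), atol=1e-5))  # True
```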

S4 Layer

Given a sequence of length $L$ with $H$ channels, the S4 layer defines a nonlinear sequence-to-sequence mapping:

  • From an input sequence: $u_{1:L} \in \mathbb{R}^{L \times H}$
  • To an output sequence: $u’_{1:L} \in \mathbb{R}^{L \times H}$

An S4 layer contains $H$ independent single-input, single-output (SISO) SSMs

  • Each SSM
    • Has $N$-dimensional states
    • Leverages the HiPPO framework for online function approximation
    • The HiPPO-LegS matrix can be reduced to a normal plus low-rank (NPLR) form, and further to a diagonal plus low-rank (DPLR) form
  • Each SSM is applied to a single dimension of the input sequence
    • This results in an independent linear transformation from each input channel to each preactivation channel
  • Then, a nonlinear activation function is applied to each preactivation channel
  • Finally, a position-wise linear mixing layer is applied to combine the activations
    • This results in the output sequence $u’_{1:L}$

s4-arch

S4 layer

S4 Learnable parameters

  • $H$ independent SSMs, each with:
    • Input matrix: $B^{(h)} \in \mathbb{C}^{N \times 1}$
    • DPLR parameterized transition matrix $A^{(h)} \in \mathbb{C}^{N \times N}$
    • Output matrix $C^{(h)} \in \mathbb{C}^{1 \times N}$
    • Timescale parameter $\Delta^{(h)} \in \mathbb{R}_+$
  • Mixing layer: $O(H^2)$ parameters

S5 Layer

The S5 layer replaces $H$ independent SISO SSMs with a multi-input, multi-output (MIMO) SSM.

s5-arch

S5 layer

S5 Parameterization

If $A$ is diagonalizable, $A = V \Lambda V^{-1}$ with diagonal $\Lambda \in \mathbb{C}^{P \times P}$, then the continuous-time latent dynamics can be rewritten in the diagonalized basis:

\[\frac{dV^{-1}x(t)}{dt} = \Lambda V^{-1}x(t) + V^{-1}Bu(t)\]

Let $\tilde{x}(t) = V^{-1}x(t)$, $\tilde{B} = V^{-1}B$, and $\tilde{C} = CV$, then the reparameterized system is:

\[\begin{align} \frac{d\tilde{x}(t)}{dt} &= \Lambda \tilde{x}(t) + \tilde{B}u(t)\\ y(t) &= \tilde{C}\tilde{x}(t) + Du(t) \end{align}\]

This is a linear SSM with a diagonal state matrix. This diagonalized system can be discretized with a timescale parameter $\Delta \in \mathbb{R}_+$ using the ZOH method to give another diagonalized system with parameters

\[\begin{align} \bar{\Lambda} &= e^{\Lambda \Delta} \\ \bar{B} &= \Lambda^{-1}(\bar{\Lambda} - I)\tilde{B} \\ \bar{C} &= \tilde{C} \\ \bar{D} &= D \end{align}\]

In practice,

  • $\Delta \in \mathbb{R}^P$: vector of learnable timescale parameters
  • $D$: restricted to be diagonal
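A minimal sketch of this ZOH discretization for the diagonalized system, using a per-state timescale vector $\Delta \in \mathbb{R}^P$ as noted above (shapes and values are illustrative):

```python
import jax.numpy as jnp

def discretize_zoh(Lambda, B_tilde, Delta):
    """ZOH discretization of the diagonalized system (sketch).
    Lambda: (P,) complex diagonal entries, B_tilde: (P, H), Delta: (P,) timescales."""
    Lambda_bar = jnp.exp(Lambda * Delta)                      # elementwise exp(Lambda * Delta)
    B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B_tilde  # Lambda^{-1} (Lambda_bar - I) B_tilde
    return Lambda_bar, B_bar

# Illustrative shapes: P latent states, H input channels.
P, H = 4, 2
Lambda = jnp.array([-0.5 + 1.0j, -1.0 + 2.0j, -0.1 + 0.5j, -2.0 + 0.1j])
B_tilde = jnp.ones((P, H), dtype=jnp.complex64)
Delta = 0.1 * jnp.ones(P)
Lambda_bar, B_bar = discretize_zoh(Lambda, B_tilde, Delta)
print(Lambda_bar.shape, B_bar.shape)  # (4,) (4, 2)
```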

S5 Learnable parameters

  • $\tilde{B} \in \mathbb{C}^{P \times H}$
  • $\tilde{C} \in \mathbb{C}^{H \times P}$
  • $\text{diag}(D) \in \mathbb{R}^H$
  • $\text{diag}(\Lambda) \in \mathbb{C}^P$
  • $\Delta \in \mathbb{R}^P$

S5 Layer

  • 1 MIMO SSM
    • Latent state size $P$
    • Input and output dimension $H$
    • Map input $u_{1:L} \in \mathbb{R}^{L \times H}$ to output $y_{1:L} \in \mathbb{R}^{L \times H}$
  • Then, a nonlinear activation function is applied to produce $u’_{1:L} \in \mathbb{R}^{L \times H}$
  • No mixing layer required, since these features are already mixed
  • S5 vs S4: S4’s $H$ independent SSM can be viewed as a block-diagonal system

s4-s5-internal

Internal structure of a discretized S4 layer vs S5 layer. S4 layer: single block-diagonal SSM with a latent state of size $HN$, followed by a nonlinearity and mixing layer to mix the independent features; S5 layer: dense linear SSM with latent size $P \ll HN$.
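Putting the pieces together, here is a simplified, real-valued sketch of an S5-style layer forward pass under the assumptions above (diagonal $\bar{\Lambda}$, diagonal $D$, an associative scan for the recurrence, GELU as the nonlinearity; parameter names and initializations are illustrative, not the paper’s):

```python
import jax
import jax.numpy as jnp

def s5_layer_sketch(u, Lambda_bar, B_bar, C_tilde, D):
    """Simplified S5-style layer: one MIMO SSM (diagonal state matrix) computed
    with an associative scan, followed by a nonlinearity; no mixing layer."""
    L, P = u.shape[0], Lambda_bar.shape[0]

    def binary_op(q_i, q_j):
        a_i, b_i = q_i
        a_j, b_j = q_j
        return (a_j * a_i, a_j * b_i + b_j)           # diagonal specialization of •

    bu = u @ B_bar.T                                   # (L, P): B_bar u_k for each step
    a = jnp.broadcast_to(Lambda_bar, (L, P))
    xs = jax.lax.associative_scan(binary_op, (a, bu))[1]   # states x_1, ..., x_L
    ys = xs @ C_tilde.T + u * D                        # (L, H); D is a diagonal feedthrough
    return jax.nn.gelu(ys)                             # activation; features already mixed

# Illustrative sizes: sequence length L, H channels, P latent states.
L, H, P = 16, 2, 4
k0, k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 4)
u = jax.random.normal(k0, (L, H))
Lambda_bar = 0.9 * jnp.ones(P)
B_bar = jax.random.normal(k1, (P, H)) / jnp.sqrt(H)
C_tilde = jax.random.normal(k2, (H, P)) / jnp.sqrt(P)
D = jax.random.normal(k3, (H,))
print(s5_layer_sketch(u, Lambda_bar, B_bar, C_tilde, D).shape)  # (16, 2)
```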

Initialization

Initializing S5 with HiPPO-like matrices may also work well in the MIMO setting.

The empirical results show that diagonalizing the HiPPO-N matrix leads to good performance.

[TBD: Appendix E]

Irregularly Spaced Time Series

Parallel scans and the continuous-time parameterization also allow for efficient handling of irregularly sampled time series. This can be achieved by supplying a different $\bar{A}_k$ matrix at each step.

In contrast, the convolution of the S4 layer requires a time invariant system and regularly spaced observations.

Relationship Between S4 and S5

[TBD]

Experiments

Long Range Arena

lra

LRA benchmark. The best Mega model retains the transformer’s $O(L^2)$ complexity.

Speech Classification

sc

Test accuracy on 35-way Speech Commands classification task. Training examples are one-second 16kHz audio waveforms.

Irregularly Spaced Time Series

Pendulum regression requires the model to handle observations received at irregular intervals.

Input

  • A sequence of $L=50$ images,
  • Each 24 $\times$ 24 pixels in size
  • Each has been corrupted with a correlated noise process
  • Images are sampled at irregular intervals from a continuous trajectory of duration $T = 100$
  • The velocity is unobserved

Target

  • Sine and cosine of the angle of the pendulum

pr-demo

Illustration of the pendulum regression example.

pr-result

Regression MSE $\times 10^{−3}$ (mean +/- std) and relative application speed on the pendulum regression task on a held-out test set. Results for CRU (our run) and S5 are across twenty seeds.