This post is mainly based on the papers "Online normalizer calculation for softmax" (Milakov & Gimelshein, 2018) and "Self-attention Does Not Need $O(n^2)$ Memory" (Rabe & Staats, 2021).

Online Softmax

  • Modern GPUs are mostly I/O bound, not compute bound
  • Parallel Online Softmax achieves a 1.3x speedup by reducing memory I/O
  • Fusing Softmax + TopK achieves a 5x speedup by further reducing memory I/O

Naive Softmax

Softmax is the default method of converting a score vector into a probability vector

\[y_i = \frac{ e^{x_i} }{ \sum e^{x_j} }\]

where $x,y \in \mathbb{R}^V$

Let $d_V = \sum e^{x_j}$. A naive softmax algorithm runs 2 loops:

  • Compute $d_V$ in the first loop: $d_j \leftarrow d_{j-1} + e^{x_j}$
  • Compute $y_i$ in the second loop: $y_i \leftarrow e^{x_i} / d_V$

This induces 2 $O(V)$ reads and 1 $O(V)$ write.
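As a concrete reference, here is a minimal NumPy sketch of the two-pass naive softmax, written with explicit loops to make the memory traffic visible (the function name is mine):

```python
import numpy as np

def naive_softmax(x: np.ndarray) -> np.ndarray:
    """Two-pass naive softmax: one pass for the normalizer, one for the outputs."""
    V = x.shape[0]
    d = 0.0
    for j in range(V):                      # 1st O(V) read of x
        d += np.exp(x[j])
    y = np.empty_like(x, dtype=np.float64)
    for i in range(V):                      # 2nd O(V) read of x, 1 O(V) write of y
        y[i] = np.exp(x[i]) / d
    return y
```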

Safe Softmax

On real hardware, the exponentials $e^{x_j}$ (and hence the sum $d_V$) can overflow. Therefore, a safe softmax is computed as follows.

Let $m_V = \max x_i$

\[y_i = \frac{ e^{x_i} }{ \sum e^{x_j} } \cdot \frac{e^{-m_V}}{e^{-m_V}}= \frac{ e^{x_i - m_V} }{ \sum e^{x_j - m_V} }\]

This ensures that every term of the running sum satisfies $e^{x_j - m_V} \leq 1$. Therefore, the full-precision FP32 running sum $d_j$ is guaranteed not to overflow for $V < 1.7 \cdot 10^{37}$ elements.

All major DL frameworks use this safe version for the Softmax computation. However, safe softmax requires additional memory accesses:

  • Compute $m_V$ in the first loop: $m_j \leftarrow \max(m_{j-1}, x_j)$
  • Compute $d_V$ in the second loop: $d_j \leftarrow d_{j-1} + e^{x_j - m_V}$
  • Compute $y_i$ in the third loop: $y_i \leftarrow e^{x_i - m_V} / d_V$

This induces 3 $O(V)$ reads and 1 $O(V)$ write.
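The corresponding three-pass NumPy sketch of safe softmax, in the same explicit-loop style:

```python
import numpy as np

def safe_softmax(x: np.ndarray) -> np.ndarray:
    """Three-pass safe softmax: running max, then normalizer, then outputs."""
    V = x.shape[0]
    m = -np.inf
    for j in range(V):                      # 1st O(V) read: running max m_V
        m = max(m, x[j])
    d = 0.0
    for j in range(V):                      # 2nd O(V) read: normalizer with shifted exponents
        d += np.exp(x[j] - m)
    y = np.empty_like(x, dtype=np.float64)
    for i in range(V):                      # 3rd O(V) read, 1 O(V) write
        y[i] = np.exp(x[i] - m) / d
    return y
```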

Online Softmax

Online Softmax seeks to reduce the 3 $O(V)$ reads to 2 $O(V)$ reads.

  • The problem is that computing $d_V$ requires $m_V$, which is only known after a full pass: $d_j \leftarrow d_{j-1} + e^{x_j - m_V}$

The online version computes $d_V$ and $m_V$ in 1 loop:

\[m_j \leftarrow \max(m_{j-1}, x_j)\] \[d_j \leftarrow d_{j-1} \times e^{m_{j-1} - m_j} + e^{x_j - m_j}\]

Essentially, the online normalizer trades memory access for additional compute: whenever the running maximum changes, the $d_{j-1}$ term is rescaled by the factor $e^{m_{j-1} - m_j}$ on the fly.
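A NumPy sketch of the online recurrence, in the same style as the sketches above:

```python
import numpy as np

def online_softmax(x: np.ndarray) -> np.ndarray:
    """Two-pass online softmax: the max and the normalizer share a single pass."""
    V = x.shape[0]
    m, d = -np.inf, 0.0
    for j in range(V):                      # single O(V) read computes m_V and d_V
        m_new = max(m, x[j])
        d = d * np.exp(m - m_new) + np.exp(x[j] - m_new)   # rescale the running sum
        m = m_new
    y = np.empty_like(x, dtype=np.float64)
    for i in range(V):                      # 2nd O(V) read, 1 O(V) write
        y[i] = np.exp(x[i] - m) / d
    return y
```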

Parallel Online Softmax

Define a generalized version of the online normalizer calculation:

\[\begin{bmatrix} m_V \\ d_V \end{bmatrix} = \begin{bmatrix} x_1 \\ 1 \end{bmatrix} \oplus \begin{bmatrix} x_2 \\ 1 \end{bmatrix} \oplus ... \oplus \begin{bmatrix} x_V \\ 1 \end{bmatrix}\]

where the binary operator $\oplus: \mathbb{R}^2 \times \mathbb{R}^2 \rightarrow \mathbb{R}^2$ is defined as:

\[\begin{bmatrix} m_i \\ d_i \end{bmatrix} \oplus \begin{bmatrix} m_j \\ d_j \end{bmatrix} = \begin{bmatrix} m_{ij} \\ d_i \times e^{m_i - m_{ij}} + d_j \times e^{m_j - m_{ij}} \end{bmatrix}\]

where $m_{ij} = \max(m_i, m_j)$
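A small Python sketch of $\oplus$; folding it left to right over the pairs $(x_j, 1)$ should reproduce the sequential Online Softmax values $m_V$ and $d_V$ (the helper name combine is mine):

```python
import numpy as np
from functools import reduce

def combine(a, b):
    """The binary operator ⊕ on (m, d) pairs."""
    (m_a, d_a), (m_b, d_b) = a, b
    m = max(m_a, m_b)
    return m, d_a * np.exp(m_a - m) + d_b * np.exp(m_b - m)

x = np.random.randn(1000)
m_V, d_V = reduce(combine, [(float(x_j), 1.0) for x_j in x])   # left-to-right fold
assert np.isclose(d_V, np.exp(x - m_V).sum())                  # matches the safe normalizer
```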

Analysis

  • Applying $\oplus$ sequentially from left to right is equivalent to running Online Softmax
  • $\oplus$ is associative
  • $\oplus$ is commutative

$\oplus$ is associative, therefore the GPU can execute pairwise operations in parallel batches:

  • First apply pairwise $\oplus$ on elements (1,2), (3,4), … in one parallel batch
  • Then apply pairwise $\oplus$ on the results (12,34), (56,78), … in the next batch, and so on
  • Map-reduce paradigm: time complexity $O(\log V)$ (see the sketch below)
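A plain Python stand-in for that pairwise reduction (on a GPU each round would run across threads in parallel; helper names are mine):

```python
import numpy as np

def combine(a, b):
    """The ⊕ operator from the previous sketch."""
    m = max(a[0], b[0])
    return m, a[1] * np.exp(a[0] - m) + b[1] * np.exp(b[0] - m)

def tree_reduce(pairs):
    """Pairwise reduction in rounds: O(log V) rounds given enough parallel lanes."""
    while len(pairs) > 1:
        nxt = [combine(pairs[i], pairs[i + 1]) for i in range(0, len(pairs) - 1, 2)]
        if len(pairs) % 2:                  # odd element is carried to the next round
            nxt.append(pairs[-1])
        pairs = nxt
    return pairs[0]

x = np.random.randn(1000)
m_V, d_V = tree_reduce([(float(x_j), 1.0) for x_j in x])
assert np.isclose(d_V, np.exp(x - m_V).sum())
```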

$\oplus$ is commutative, therefore the GPU can combine partial results regardless of the order in which input elements are processed:

  • Different processors may complete tasks at varying times
  • The GPU does not need to enforce strict synchronization

Fused operation

  • The goal of this section is to demonstrate how memory I/O can impact GPU performance; a sketch of the fused algorithm follows the list below
  • Algorithm:
    • Initialize
      • $u \in \mathbb{R}^{K+1}$: topK value holder
      • $p \in \mathbb{Z}^{K+1}$: topK index holder
    • In the Online Softmax loop
      • Compute $m_j$ and $d_j$
      • $u_{K+1} \leftarrow x_j$
      • $p_{K+1} \leftarrow j$
      • Sort $u$ and update $p$ according to $u$
    • Return
      • Compute the top-$K$ probabilities $v$ from vector $u$ and scalars $m_V, d_V$
      • Output $v$ and $p$ (ignoring the last dummy element)
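As referenced above, here is a minimal single-pass NumPy sketch of the fused Softmax + TopK idea; it sorts the small $(K+1)$-element buffer each step, whereas a real GPU kernel would keep this buffer in registers:

```python
import numpy as np

def fused_softmax_topk(x: np.ndarray, k: int):
    """One pass over x: online softmax normalizer plus a running top-k buffer."""
    m, d = -np.inf, 0.0
    u = np.full(k + 1, -np.inf)             # top-k score buffer plus one spare slot
    p = np.full(k + 1, -1, dtype=int)       # matching index buffer
    for j, x_j in enumerate(x):             # single O(V) read
        m_new = max(m, x_j)
        d = d * np.exp(m - m_new) + np.exp(x_j - m_new)
        m = m_new
        u[k], p[k] = x_j, j                 # drop the new score into the spare slot ...
        order = np.argsort(-u)              # ... and re-sort the small buffer
        u, p = u[order], p[order]
    v = np.exp(u[:k] - m) / d               # softmax values of the top-k scores only
    return v, p[:k]                         # O(K) write
```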

Analysis

  • Safe Softmax + TopK separately: 4 reads and 1 write over all elements
  • The fused algorithm: 1 read over all elements and $K$ writes

[Figure: softmax-topk] Benchmarking softmax and top-k, Tesla V100, fp32, batch of 4000 vectors. Larger batches result in higher performance.

Memory Efficient Self-attention

  • Exact algorithm, not an approximation
  • Time complexity: $O(n^2)$; Memory complexity: $O(\sqrt{n})$
  • For sequence length 16384, the memory overhead of self-attention is reduced by
    • 59x for inference
    • 32x for differentiation

Algorithm

Consider the query of a single token against all keys:

\[s_i = q \cdot k_i\] \[s_i' = e^{s_i}\] \[\text{softmax}(q, k_i) = \frac{s_i'}{\sum_j s_j'}\] \[\text{Attention}(q,k,v) = \sum_i \text{softmax}(q, k_i) \cdot v_i = \frac{\sum_i v_i s_i'}{\sum_j s_j'}\]

where $q, k_i, v_i \in \mathbb{R}^d$

Hence $s_i, s_i' \in \mathbb{R}$, $\text{softmax}(q, k_i) \in \mathbb{R}$, and $\text{Attention}(q,k,v) \in \mathbb{R}^d$

Note that, compared to plain softmax, the Attention computation produces the weighted sum of all values $[v_1, …, v_n]$, using the softmax of the scores $[s_1, …, s_n]$ as weights.
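A direct NumPy transcription of these formulas for a single query, where K and V stack the keys and values row-wise (this is the overflow-prone form that the incremental update below makes safe):

```python
import numpy as np

def single_query_attention(q, K, V):
    """Attention of one query q (shape [d]) against keys K and values V (shape [n, d])."""
    s = K @ q                               # s_i = q . k_i,   shape [n]
    s_exp = np.exp(s)                       # s_i' = e^{s_i}   (can overflow)
    return (s_exp[:, None] * V).sum(axis=0) / s_exp.sum()   # weighted sum of values
```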

Although the authors do not cite the Online Softmax paper, their incremental update for the numerically problematic $\text{softmax}$ is essentially the same:

  • Initialization
    • $m^* = -\infty$
    • $v^*$ as zero vector
    • $s^* = 0$
  • Loop
    • $m_i = \max(m^*, s_i)$
    • $v^* \leftarrow v^* e^{m^* - m_i} + v_i e^{s_i - m_i}$
    • $s^* \leftarrow s^* e^{m^* - m_i} + e^{s_i - m_i}$
    • $m^* \leftarrow m_i$
  • Result
    • output = $ v^* / s^* $

Note that the $m^*$ step is essential for numerical stability on real hardware, since a score $\geq 89$ will cause $e^{s_i}$ to overflow in bfloat16 and float32.
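A minimal NumPy sketch of this update for a single query (the paper processes chunks of keys and values rather than one position at a time; names are mine):

```python
import numpy as np

def streaming_single_query_attention(q, K, V):
    """Consume (k_i, v_i) pairs one at a time, carrying (m*, v*, s*) as above."""
    m_star, v_star, s_star = -np.inf, np.zeros(V.shape[1]), 0.0
    for k_i, v_i in zip(K, V):
        s_i = float(q @ k_i)
        m_i = max(m_star, s_i)
        v_star = v_star * np.exp(m_star - m_i) + v_i * np.exp(s_i - m_i)
        s_star = s_star * np.exp(m_star - m_i) + np.exp(s_i - m_i)
        m_star = m_i
    return v_star / s_star

# Sanity check against a max-shifted direct computation.
rng = np.random.default_rng(0)
q, K, V = rng.normal(size=64), rng.normal(size=(128, 64)), rng.normal(size=(128, 64))
w = np.exp(K @ q - (K @ q).max())
assert np.allclose(streaming_single_query_attention(q, K, V), (w[:, None] * V).sum(0) / w.sum())
```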

JAX Implementation

  • Support for multiple attention heads and memory-efficient differentiation
  • Computation is split into chunks for parallelism

Please refer to Figure 1 of the paper for the detailed JAX implementation; below is a high-level summary:

Split Query into chunks with jax.lax.dynamic_slice()
jax.lax.map() on each query chunk:
    Split Key and Value into chunks with jax.lax.dynamic_slice()
    jax.lax.map() on each key-value chunk:
        Compute the attention scores and their exponentials for each chunk
        Compute the intermediate m*, v*, s* at chunk level

    Combine chunk-level v*, s* into a final v*, s*
    // this is the embedding output for each query token within the query chunk
    Return v* / s*

Return embedding for all query tokens

Note that jax.lax.map() applies a function to each element along the leading axis of an array and can be parallelized; however, under memory constraints it may behave more like the sequential jax.lax.scan().
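For concreteness, here is a simplified JAX sketch in the same spirit (this is not the paper's Figure 1 code: single head, no masking or checkpointing, and it assumes the sequence length is divisible by both chunk sizes):

```python
import jax
import jax.numpy as jnp

def chunked_attention(q, k, v, query_chunk_size=128, key_chunk_size=128):
    """Simplified memory-efficient attention; q, k, v have shape [n, d]."""
    n, d = q.shape

    def attend_query_chunk(q_chunk):                       # q_chunk: [Cq, d]
        def summarize_kv_chunk(start):
            k_chunk = jax.lax.dynamic_slice(k, (start, 0), (key_chunk_size, d))
            v_chunk = jax.lax.dynamic_slice(v, (start, 0), (key_chunk_size, d))
            s = q_chunk @ k_chunk.T                        # [Cq, Ck] scores
            m = s.max(axis=-1)                             # chunk-level max   (m*)
            p = jnp.exp(s - m[:, None])                    # shifted exponents
            return m, p @ v_chunk, p.sum(axis=-1)          # chunk-level (m*, v*, s*)

        starts = jnp.arange(0, n, key_chunk_size)
        m_all, v_all, s_all = jax.lax.map(summarize_kv_chunk, starts)

        # Combine chunk-level summaries with the same rescaling trick as online softmax.
        m_glob = m_all.max(axis=0)                         # [Cq]
        scale = jnp.exp(m_all - m_glob)                    # [num_kv_chunks, Cq]
        v_star = (v_all * scale[..., None]).sum(axis=0)    # [Cq, d]
        s_star = (s_all * scale).sum(axis=0)               # [Cq]
        return v_star / s_star[:, None]

    q_chunks = q.reshape(-1, query_chunk_size, d)          # [n / Cq, Cq, d]
    return jax.lax.map(attend_query_chunk, q_chunks).reshape(n, d)
```

On small inputs this should agree with standard attention, e.g. jax.nn.softmax(q @ k.T, axis=-1) @ v, up to floating-point error.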

Experiments

  • Compare memory-efficient attention vs FLAX attention
  • Inference
    • FLAX attention can handle sequences of length $2^{14}$, i.e. 16k
    • Memory-efficient attention can handle sequences of length $2^{20}$, i.e. about 1 million
    • Not much slowdown
  • Training
    • 30% training slowdown due to checkpointing
    • No performance degradation as measured by BLEU score
  • Query chunking
    • Small query chunk sizes result in poor performance
    • For large chunk sizes, the loss of performance is less significant

[Figure: jax-inference] Memory and time requirements of self-attention during inference.

[Figure: chuck-size] Relative runtime of self-attention on sequence length $2^{15}$ using query chunking, compared to standard attention.

Limitations

  • The choice of query_chunk_size and key_chunk_size is left to the programmer
  • The authors state that they do not observe speedups similar to FlashAttention's, since TPUs are already optimized for compute FLOPs and memory bandwidth