This post is mainly based on

Readers are strongly recommended to go through Online Softmax first to understand the scaling factor $l$ and the numerical stability term $m$.
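
As a quick refresher, here is a minimal sketch of the online softmax recurrence (NumPy; the function name and block size are illustrative, not from the paper): $m$ is the running row maximum, $l$ is the running denominator, and already-written outputs are rescaled whenever $m$ or $l$ changes.

```python
import numpy as np

def online_softmax(x, block_size=4):
    """One pass over blocks of x, keeping a running max m (numerical
    stability term) and a running denominator l (scaling factor);
    previously written outputs are rescaled whenever m or l changes."""
    m, l = -np.inf, 0.0
    out = np.zeros_like(x)
    for s in range(0, len(x), block_size):
        blk = x[s:s + block_size]
        m_new = max(m, blk.max())
        l_new = l * np.exp(m - m_new) + np.exp(blk - m_new).sum()
        out[:s] *= (l / l_new) * np.exp(m - m_new)   # rescale earlier outputs
        out[s:s + block_size] = np.exp(blk - m_new) / l_new
        m, l = m_new, l_new
    return out

x = np.random.randn(10)
ref = np.exp(x - x.max()); ref /= ref.sum()
assert np.allclose(online_softmax(x), ref)
```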

Problem with Attention

  • Computing and storing the attention matrix is costly (quadratic complexity in sequence length)
  • Motivation: models trained on longer sequences generally achieve better validation performance
  • Three schools of solutions
    • Hardware-aware training
    • Approximate attention (trades off quality for speed)
      • Sparse attention: Sparse Transformer, Reformer, Routing Transformer
      • Low-rank approximation of attention: Linformer, Linear Transformer, Performer
      • No widespread adoption (due to reduced quality and added computation overhead)
    • MapReduce-style approaches (summarize short sequences and combine them to handle long sequences)

Terminology

  • GPU architecture
    • CPU DRAM
      • Speed: 25.6 GB/s (DDR4-3200)
      • Size: >1 TB
      • PCIe 4.0 x16 bandwidth: 32 GB/s
    • GPU DRAM: Dynamic Random-Access Memory
      • GPU’s off-chip main memory
      • Speed: 1.5 TB/s (HBM)
      • Size: 40 GB
    • SRAM: Static Random-Access Memory
      • GPU’s fast on-chip cache memory
      • Speed: 19 TB/s
      • Size: 20 MB
    • Compute: fundamental processing / where compute actually happens
      • Streaming multiprocessors (SMs) or shader cores
      • Tensor cores
  • Tiling
    • A technique used in GPU computing to optimize memory access patterns and increase data locality
    • Partitioning a data set into smaller, contiguous tiles that fit within SRAM
    • Exploits the high bandwidth and low latency of SRAM and minimizes accesses to DRAM (see the blocked matmul sketch after this list)
  • Kernel fusion
    • A compiler optimization technique used to combine multiple loops or kernel operations into a single loop or kernel
    • Goal: reduce memory access and improve cache utilization
  • Tensor contraction
    • Generalization of matrix multiplication (which contracts two rank-2 tensors)
    • Summing over paired indices in tensors to produce a new tensor with reduced rank
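
A minimal sketch illustrating both tiling and tensor contraction (the tile size here is arbitrary; a real kernel chooses it so the working tiles fit in SRAM): a blocked matmul where each per-tile product is a contraction over the shared index $k$, written with torch.einsum.

```python
import torch

def tiled_matmul(A, B, tile=128):
    """Blocked matmul C = A @ B. Only three small tiles are needed at a
    time (good data locality); each per-tile product is a tensor
    contraction over the paired index k."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = torch.zeros(M, N, dtype=A.dtype, device=A.device)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            for k in range(0, K, tile):
                C[i:i+tile, j:j+tile] += torch.einsum(
                    "mk,kn->mn", A[i:i+tile, k:k+tile], B[k:k+tile, j:j+tile])
    return C

A, B = torch.randn(256, 192), torch.randn(192, 320)
assert torch.allclose(tiled_matmul(A, B), A @ B, atol=1e-4)
```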

Types of Runtime Costs

  • Compute cost
    • Compare CPU and GPU speed
      • A100: 312 trillion FLOPS
      • Python: 32 million FLOPS
      • In the time it takes Python to perform a single FLOP, an A100 could perform 9.75 million FLOPs
  • Data transfer cost
    • Memory bandwidth cost: moving data from DRAM to SRAM
    • Data transfer costs: moving data from CPU DRAM to GPU DRAM
    • Network costs: moving data from one GPU node to another (for distributed training)
  • Overhead cost
    • All remaining cost
    • Examples
      • Time spent in the Python interpreter
      • Time spent in the PyTorch framework
      • Time spent launching CUDA kernels (but not executing them)
  • Example (unary operation, e.g., dropout)
    • Overhead cost: preparing to launch kernels
    • Data transfer cost: Move data from DRAM to SRAM
    • Compute cost: SMs access SRAM and perform a tiny amount of computation
    • Data transfer cost: Move data from SRAM to DRAM
  • Memory bandwidth cost dominates in the example above
  • Remedy: Kernel Fusion
    • Instead of moving data to DRAM just to read it back again, we perform several computations at once
    • Example: GELU takes nearly the same time as ReLU despite involving many more operations (see the timing sketch after this list)
  • Question: do we even need activation checkpointing?
    • Recomputation might be faster than moving data, hence costing less runtime
    • For A100: until we’re doing ~100 unary operator operations, we’ll be spending more time performing memory accesses than actual compute
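
A rough timing sketch of the GELU-vs-ReLU point above (assumes a CUDA GPU; exact numbers will vary): both ops are memory-bound, so their runtime is roughly the cost of reading and writing the tensor once, regardless of GELU's extra arithmetic.

```python
import torch

def gpu_time_ms(fn, x, iters=100):
    """Average GPU time of fn(x) in milliseconds, measured with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn(x)                      # warm-up
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(2**26, device="cuda")   # 256 MB of fp32: clearly memory-bound
print("relu:", gpu_time_ms(torch.relu, x), "ms")
print("gelu:", gpu_time_ms(torch.nn.functional.gelu, x), "ms")
# Both take roughly the time needed to read 256 MB from HBM and write 256 MB back.
```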

PyTorch Efficiency

  • PyTorch has many layers of dispatch before reaching the actual compute kernel
    • A lot of time needs to be spent on “figuring out what to do”
    • Example: when computing a + b, the following steps need to happen
      • Look up what add dispatches to on a
      • Check many attributes of the tensor (dtype, device, whether autograd is needed) to determine which kernel to call
      • Actually launch the kernel
    • Fundamentally, overhead cost comes from the flexibility of being able to do something different at each step
  • Pro: Flexible in eager mode
  • Con: For scientific computing / tiny tensors, PyTorch could be incredibly slow compared to C++
  • Remedy
    • Asynchronous execution to hide overhead cost
    • While running a CUDA kernel, PyTorch can continue and queue up more CUDA kernels behind it (see the sketch after this list)
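
A small sketch of this asynchronous behavior (assumes a CUDA GPU): the Python loop returns almost immediately because each call only enqueues a kernel; the real GPU time only becomes visible after torch.cuda.synchronize().

```python
import time
import torch

x = torch.randn(4096, 4096, device="cuda")

t0 = time.perf_counter()
for _ in range(50):
    y = x @ x                 # each call just queues a CUDA kernel
t_launch = time.perf_counter() - t0

torch.cuda.synchronize()      # block until all queued kernels have finished
t_total = time.perf_counter() - t0

print(f"kernel launches: {t_launch * 1e3:.1f} ms, actual GPU work: {t_total * 1e3:.1f} ms")
```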

Data Movement is All You Need

  • Training becomes memory-bound rather than compute-bound
    • Megatron reports achieving only 30% of peak GPU FLOPS
    • Existing implementations do not efficiently utilize GPUs
  • Findings
    • Key bottleneck when training transformers is data movement
    • BERT training
      • Memory-bound operations: 37% of runtime
      • Tensor contractions: 99% of FLOPs, 61% of runtime
  • New approach
    • Reduce data movement by up to 22.91%
    • 1.3x performance improvement on BERT encoder layer

Dataflow Analysis

  • Tools
    • Data-Centric (DaCe) parallel programming framework
    • Stateful Dataflow multiGraph (SDFG)
  • Breakdown of FLOP & runtime
    • PyTorch implementation of BERT encoder layer
    • Normalization: ~250x lower achieved FLOP/s than tensor contraction/MatMul
    • Element-wise: ~700x lower achieved FLOP/s than tensor contraction/MatMul

flop-vs-runtime

Custom Kernels

  • More efficient than NVFuser and XLA
  • Limitations: CUDA code is hard to write

kernel-speed

FlashAttention

  • Hardware-Aware Optimization
  • Goal: optimize Read/Write to different levels of memory to reduce GPU memory IOs
  • Technique
    • Tiling: load block by block from HBM to SRAM to compute attention
    • Recomputation: don’t store attention matrix, recompute it in backward pass
  • Results on Exact Attention
    • Baseline: vanilla PyTorch implementation
    • 2-4x faster
    • 10-20x more memory efficient (linear complexity)

fa-speed

Forward + backward runtime of standard attention and FlashAttention for GPT-2 medium. Despite the higher FLOP count, fewer HBM accesses greatly improve runtime.

gpu-speed attention-speed

FlashAttention does not read and write the attention matrix to HBM, resulting in a 7.6x speedup on the attention computation.

Standard Attention Implementation

Matrix Computation

  • $Q,K,V \in \mathbb{R}^{N \times d}$
  • Attention Score Matrix: $S=QK^T \in \mathbb{R}^{N \times N}$
  • Attention Matrix: $P=\text{softmax}(S) \in \mathbb{R}^{N \times N}$
  • Output Matrix: $O = PV \in \mathbb{R}^{N \times d}$
  • Where
    • $N$: sequence length
    • $d$: head dimension
    • softmax() on $S \in \mathbb{R}^{N \times N}$ is applied row-wise

Algorithm 0

  1. Requires: Matrices $Q, K, V$ in HBM
  2. Load $Q, K$ by blocks from HBM, compute $S = QK^T$, write $S$ to HBM.
  3. Read $S$ from HBM, compute $P = \text{softmax}(S)$, write $P$ to HBM.
  4. Load $P, V$ by blocks from HBM, compute $O = PV$, write $O$ to HBM.
  5. Return $O$.
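
Algorithm 0 in plain PyTorch (no scaling, masking, or dropout, matching the definitions above); the full $N \times N$ matrices $S$ and $P$ are materialized, and on a GPU each step round-trips through HBM:

```python
import torch

def standard_attention(Q, K, V):
    """Standard (Algorithm 0) attention: materializes S and P as full
    N x N matrices before computing O = P V (row-wise softmax)."""
    S = Q @ K.transpose(-1, -2)      # (N, N), written to HBM
    P = torch.softmax(S, dim=-1)     # (N, N), read from and written back to HBM
    return P @ V                     # (N, d)

N, d = 1024, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
O = standard_attention(Q, K, V)
```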

Analysis

  • Problems of Algorithm 0
    • Memory-bound: storing the attention matrix $P$ requires $O(N^2)$ memory
    • I/O-bound: element-wise operations on $S$ or $P$ (e.g., masking, dropout) require extra round trips to HBM
  • How to reduce Memory requirement and I/O?
    • Idea: perform most of the computation block by block in SRAM
      • Avoid storing intermediate matrices
      • Reduce R/W to HBM
    • Challenge 1: Compute the softmax without access to full rows of $S$
    • Challenge 2: Backward without the large attention matrix from forward
  • Solution
    • Solution to Challenge 1: Tiling
    • Solution to Challenge 2: Recomputation
    • FlashAttention implements a fused CUDA kernel instead of relying on PyTorch's suboptimal computation graph

Tiling (Section 3.1, Appendix B.1, B.3)

Differences between Memory-Efficient Self-Attention (MESA) and FlashAttention (FA)

  • Query chunks are in the outer loop in MESA, but in the inner loop in FA
  • FA has block sizes defined by the algorithm, while MESA requires the user to define the chunk size
  • FA is implemented in CUDA, while MESA is implemented in JAX

Algorithm 1 (forward pass only)

  • Initialization
    • Calculate key-value block size $B_c = \lceil \frac{M}{4d} \rceil$, where $M$ is the on-chip SRAM size (in elements)
    • Calculate query block size $B_r = \min( \lceil \frac{M}{4d} \rceil, d )$
    • Calculate key-value block number $T_c = \lceil \frac{N}{B_c} \rceil$
    • Calculate query block number $T_r = \lceil \frac{N}{B_r} \rceil$
    • Divide $K, V$ into $T_c$ blocks of dim $B_c \times d$
    • Divide $Q$ into $T_r$ blocks of dim $B_r \times d$
    • Divide $O$ (output embedding) into $T_r$ blocks of dim $B_r \times d$
    • Divide $l$ (softmax denominator / scaling factor) into $T_r$ blocks of dim $B_r$
    • Divide $m$ (numerical stability term) into $T_r$ blocks of dim $B_r$
  • Outer loop: Load $K_j$ and $V_j$ into SRAM
  • Inner loop
    • Load $Q_i, O_i, l_i, m_i$ into SRAM
    • Compute Online Softmax
    • $O_i$ is updated multiple times as we loop through $j$ (see the PyTorch sketch below)
      • Rescale the $O_i$ computed in the previous ($j-1$) outer iteration using the updated $l$ and $m$
      • Add the new attention-weighted $V_j$ contribution to this rescaled $O_i$
      • Write the new $O_i$ to HBM (to be updated again in the $j+1$ outer iteration)
    • Update $l_i, m_i$ and write them to HBM
  • Return $O$

Algorithm 1 computes attention $O = \text{softmax}(QK^T)V$ with $O(N^2d)$ FLOPs and requires $O(N)$ additional memory beyond inputs and output.
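
A pure-PyTorch sketch of Algorithm 1's loop structure (illustrative only: the real implementation is a single fused CUDA kernel, and the "SRAM blocks" here are just tensor slices). $M$ is an assumed SRAM size in elements, and no $1/\sqrt{d}$ scaling is applied, matching $S = QK^T$ above.

```python
import math
import torch

def flash_attention_forward(Q, K, V, M=65536):
    """Forward pass following Algorithm 1: block sizes B_c = ceil(M / 4d),
    B_r = min(B_c, d); O, l, m are updated once per (j, i) block pair."""
    N, d = Q.shape
    Bc = math.ceil(M / (4 * d))
    Br = min(Bc, d)
    O = torch.zeros(N, d)
    l = torch.zeros(N)                      # running softmax denominators
    m = torch.full((N,), float("-inf"))     # running row maxima
    for j in range(0, N, Bc):               # outer loop: K, V blocks
        Kj, Vj = K[j:j+Bc], V[j:j+Bc]
        for i in range(0, N, Br):           # inner loop: Q, O, l, m blocks
            Qi, Oi = Q[i:i+Br], O[i:i+Br]
            li, mi = l[i:i+Br], m[i:i+Br]
            Sij = Qi @ Kj.T                              # (B_r, B_c) score block
            mij = Sij.max(dim=-1).values                 # block row max
            Pij = torch.exp(Sij - mij[:, None])          # unnormalized block softmax
            lij = Pij.sum(dim=-1)
            m_new = torch.maximum(mi, mij)               # merged statistics
            l_new = li * torch.exp(mi - m_new) + lij * torch.exp(mij - m_new)
            # rescale the previous partial output, then add this block's contribution
            O[i:i+Br] = (
                (li * torch.exp(mi - m_new) / l_new)[:, None] * Oi
                + (torch.exp(mij - m_new) / l_new)[:, None] * (Pij @ Vj)
            )
            l[i:i+Br], m[i:i+Br] = l_new, m_new
    return O

N, d = 1024, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
ref = torch.softmax(Q @ K.T, dim=-1) @ V
assert torch.allclose(flash_attention_forward(Q, K, V), ref, atol=1e-4)
```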

tiling

Tiling: partitioning a data set into smaller, contiguous tiles that fit within SRAM.

Recomputation (Section 3.1, Appendix B.2, B.4)

  • Goal: avoid storing $O(N^2)$ attention matrix for the backward pass
  • The attention matrices $S, P$ can be recomputed easily in the backward pass from $Q, K, V$ together with the saved $O, l, m$ (see the sketch after this list)
  • Can be seen as a form of selective gradient checkpointing
  • Trades more computation for reduced memory: even with more FLOPs, FlashAttention achieves a speedup in the backward pass due to reduced HBM accesses
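
A sketch of the recomputation idea for one block (not the full backward pass): given the row statistics $m_i$, $l_i$ saved during the forward pass, the block $P_{ij}$ is rebuilt on the fly from $Q_i$ and $K_j$ instead of being read from a stored $N \times N$ matrix.

```python
import torch

def recompute_P_block(Qi, Kj, mi, li):
    """Rebuild the attention-probability block P_ij for the backward pass
    from query/key blocks and the saved row max m_i and denominator l_i."""
    Sij = Qi @ Kj.T
    return torch.exp(Sij - mi[:, None]) / li[:, None]

# Sanity check against a full softmax
N, d, Br, Bc = 256, 64, 64, 64
Q, K = torch.randn(N, d), torch.randn(N, d)
S = Q @ K.T
m = S.max(dim=-1).values                       # per-row max
l = torch.exp(S - m[:, None]).sum(dim=-1)      # per-row denominator
P = torch.softmax(S, dim=-1)
assert torch.allclose(recompute_P_block(Q[:Br], K[:Bc], m[:Br], l[:Br]),
                      P[:Br, :Bc], atol=1e-5)
```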

IO Complexity

  • See paper’s Section 3.2
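  • For reference, the headline bounds stated there (Theorem 2; $M$ is the SRAM size in elements, with $d \le M \le Nd$):
    • Standard attention: $\Theta(Nd + N^2)$ HBM accesses
    • FlashAttention: $\Theta(N^2 d^2 M^{-1})$ HBM accesses, many times fewer since typically $d^2 \ll M$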

Experiments

  • Training Speed
    • 15% faster than the MLPerf 1.1 BERT training speed record
    • Speedup on GPT-2: 3x over HuggingFace and 1.8x over Megatron
  • Quality
    • Training GPT-2 with a 4K context is faster than Megatron's GPT-2 with a 1K context, while achieving 0.7 lower perplexity
    • 6.4 points of lift on two long-document classification tasks: MIMIC-III and ECtHR
    • Beats SOTA Transformers on Path-X and all SOTA sequence models on Path-256 (first to perform better than chance)
  • Memory
    • Memory footprint of FlashAttention scales linearly with seq. length

vs-seq-length

Left: runtime of forward pass + backward pass. Right: attention memory usage.