This post is mainly based on

The Transformer is an architecture built on Attention.

Attention

Recall that an Attention layer maps an input sequence $x_1, …, x_n$ to a sequence of contextualized embeddings $z_1, …, z_n$.

We need to distinguish between Self-Attention, Cross-Attention, and Causal-Attention:

  • Self-Attention
    • Only consider the target sequence (length $L$)
    • e.g., language model
  • Cross-Attention
    • Consider both the target sequence (length $L$) and the source sequence (length $M$)
    • e.g., story completion
    • The source sequence is the prompt and the target sequence is the generated text
    • Self-Attention can be viewed as a special case of Cross-Attention where the source sequence is the target sequence itself
  • Causal-Attention
    • The model can only observe elements in the sequence at or before the current element
    • e.g., time series prediction
    • Implemented by masking unobserved elements

A forward pass through an attention layer consists of a series of matrix multiplications.

Cross-Attention

Consider the input data:

  • Source sequence: $y_1, …, y_M \in \mathbb{R}^d$, $d$-dimensional vectors
  • Target sequence: $x_1, …, x_L \in \mathbb{R}^d$, $d$-dimensional vectors
  • Stack them into matrices $Y \in \mathbb{R}^{M \times d}, X \in \mathbb{R}^{L \times d}$

The model parameters:

  • Query projection matrix: $W_q \in \mathbb{R}^{d \times d'}$
  • Key projection matrix: $W_k \in \mathbb{R}^{d \times d'}$
  • Value projection matrix: $W_v \in \mathbb{R}^{d \times d}$
  • $d'$ is the dimension of the intermediate representation

Project the input $X,Y$ onto some latent space:

  • Query: $ Q = X W_q \in \mathbb{R}^{L \times d'} $
  • Key: $ K = Y W_k \in \mathbb{R}^{M \times d'} $
  • Value: $ V = Y W_v \in \mathbb{R}^{M \times d} $

Compute the cross-attention matrix $A_{norm}$:

\[A_{norm} = \operatorname{softmax}\left( \frac{Q K^\top}{ \sqrt{ d' } } \right)\]

Alternatively, we can write $A_{norm}$ as:

\[A = \exp \left( \frac{Q K^\top}{ \sqrt{ d' } } \right) \in \mathbb{R}^{L \times M}\] \[\left[ A_{norm} \right]_{ij} = \frac{A_{ij}}{ \sum_{m=1}^M A_{im} }\]

where $\exp$ is applied element-wise. The reason for writing $A_{norm}$ in terms of the unnormalized matrix $A$ is discussed in the Kernel Function section below.

Compute the output matrix $X'$:

\[X' = A_{norm} V \in \mathbb{R}^{L \times d}\]
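The cross-attention forward pass above can be sketched in a few lines of NumPy. This is a minimal illustration, assuming unbatched 2-D inputs and randomly drawn toy weights; the names `cross_attention`, `softmax` and the sizes `L, M, d, d_prime` are ours, not from any library.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtracting the row max before exponentiating avoids overflow
    # and does not change the result.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(X, Y, W_q, W_k, W_v):
    """X: (L, d) target sequence, Y: (M, d) source sequence."""
    Q = X @ W_q                    # (L, d')
    K = Y @ W_k                    # (M, d')
    V = Y @ W_v                    # (M, d)
    d_prime = Q.shape[-1]
    A_norm = softmax(Q @ K.T / np.sqrt(d_prime), axis=-1)  # (L, M), rows sum to 1
    return A_norm @ V, A_norm      # X' has shape (L, d)

rng = np.random.default_rng(0)
L, M, d, d_prime = 4, 6, 8, 5      # toy sizes
X = rng.standard_normal((L, d))
Y = rng.standard_normal((M, d))
W_q = rng.standard_normal((d, d_prime)) / np.sqrt(d)
W_k = rng.standard_normal((d, d_prime)) / np.sqrt(d)
W_v = rng.standard_normal((d, d)) / np.sqrt(d)

X_out, A_norm = cross_attention(X, Y, W_q, W_k, W_v)
print(X_out.shape, A_norm.shape)   # (4, 8) (4, 6)
```

Calling `cross_attention(X, X, W_q, W_k, W_v)` sets $Y = X$ and gives the Self-Attention case, with an $L \times L$ attention matrix. The max-subtraction inside `softmax` is a standard stability trick and leaves $A_{norm}$ unchanged.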

Self-Attention

Setting the source sequence as the target sequence ($Y=X$) yields Self-Attention:

  • $X \in \mathbb{R}^{L \times d}$

Compute the intermediate representations:

  • Query: $ Q = X W_q \in \mathbb{R}^{L \times d'} $
  • Key: $ K = X W_k \in \mathbb{R}^{L \times d'} $
  • Value: $ V = X W_v \in \mathbb{R}^{L \times d} $

The cross-attention matrix reduces to a self-attention matrix. …

Causal-Attention

Causal-Attention is achieved by a causal mask, which is discussed below.

Causal-Attention is not mutually exclusive with Cross-Attention or Self-Attention. For example, the Transformer decoder contains two attention layers:

  • First layer
    • Self-Attention + Causal Attention
    • Auto-regressive: each token attends to the translated sentence, but only to previously generated words
  • Second layer
    • Cross-attention
    • Each token in the translated sentence can attend to any tokens in the source sentence

One can imagine designing an architecture using “Cross-attention + Causal Attention”, i.e., each position in the output sequence can only attend to input positions up to the current step.

Implementation

The three types of attention above can be implemented by a single function. Check Harvard NLP’s implementation for details.
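As a rough illustration of the single-function idea, here is a minimal NumPy sketch (not Harvard NLP's code) of scaled dot-product attention with an optional boolean mask. For brevity it takes already-projected $Q, K, V$; the usage simply reuses the raw inputs as $Q, K, V$ (identity projections), which is an assumption made only for illustration.

```python
import numpy as np

def attention(Q, K, V, mask=None):
    """Scaled dot-product attention.

    Q: (L, d'), K: (M, d'), V: (M, d_v).
    mask: optional boolean (L, M) array; True marks positions that may be attended to.
    """
    scores = Q @ K.T / np.sqrt(Q.shape[-1])            # (L, M)
    if mask is not None:
        # exp(-inf) = 0, so masked positions get exactly zero weight.
        scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

def causal_mask(L):
    # Lower triangular, diagonal included: position i may attend to 0..i.
    return np.tril(np.ones((L, L), dtype=bool))

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 4))    # target sequence, L = 5
Y = rng.standard_normal((7, 4))    # source sequence, M = 7

_, w_self = attention(X, X, X)                           # Self-Attention: (5, 5)
_, w_cross = attention(X, Y, Y)                          # Cross-Attention: (5, 7)
_, w_causal = attention(X, X, X, mask=causal_mask(5))    # Self + Causal: (5, 5)
```

The three cases differ only in where $K, V$ come from and whether a mask is passed. A row in which every position is masked would produce NaN in this sketch; some implementations use a large negative constant instead of `-inf` to sidestep that.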

Analysis

Kernel Function

\[K: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}\]

As discussed in the Attention post, a kernel function can be a dot product (linear kernel):

\[K(x,y) = x^\top y\]

The Transformer uses a modified Softmax kernel:

\[K(x,y) = \exp \left( \frac{x^\top y}{\sqrt{d}} \right)\]

The story I heard is that the researchers first tried a linear kernel on a neural machine translation (NMT) problem, but it did not work well. They suspected that attention in NLP should be sparse: a target token is strongly correlated with a few tokens and has (near) zero correlation with the others. This requires a “sharper” kernel.

A natural way to amplify differences between values is to exponentiate them, and the softmax function is commonly used as a smooth approximation of the $\operatorname{argmax}$ function. However, a naive softmax kernel causes numerical stability problems: for $d=64$, the dot product $x^\top y$ can be a large scalar, and exponentiating it can lead to numerical overflow.

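The modified Softmax kernel normalizes the dot product by $\sqrt{d}$. The constant $\sqrt{d}$ is not just a heuristic design choice. Assume that at the beginning of training (random initialization, normalized inputs) the components of each query $x$ and key $y$ (rows of $Q$ and $K$) are independent with mean 0 and variance 1, e.g., $x_i, y_i \overset{iid}{\sim} N(0,1)$.

Then $x^\top y = \sum_{i=1}^d X_i$ with $X_i = x_i y_i$, where the $X_i$ are i.i.d. with $\mathbb{E}[X_i] = 0$ and $\operatorname{Var}(X_i) = 1$. Hence $\mathbb{E}[x^\top y] = 0$ and $\operatorname{Var}(x^\top y) = d$, i.e., the raw dot product has standard deviation $\sqrt{d}$ and grows with the dimension. Dividing by $\sqrt{d}$ brings it back to unit variance, and by a concentration inequality or the central limit theorem, $P(|x^\top y| / \sqrt{d} \geq t)$ is very small for moderately large $t$, regardless of $d$.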
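The scaling argument can be checked with a quick simulation, assuming i.i.d. $N(0,1)$ entries and $d = 64$; the sample size `n` is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 64, 20000                       # dimension and number of simulated pairs
x = rng.standard_normal((n, d))        # each row: one query, entries iid N(0, 1)
y = rng.standard_normal((n, d))        # each row: one key, entries iid N(0, 1)

dots = np.einsum("nd,nd->n", x, y)     # raw dot products x^T y
scaled = dots / np.sqrt(d)             # dot products divided by sqrt(d)

print(dots.mean(), dots.var())         # mean near 0, variance near d = 64
print(scaled.var())                    # variance near 1
print(np.mean(np.abs(scaled) >= 3))    # tail probability, well under 1%
```

Without the $\sqrt{d}$ factor, typical logits would have a standard deviation of 8 here, and the exponential would amplify that spread accordingly.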

Other than the softmax kernel, one can use the ReLU kernel:

\[K(x,y) = \operatorname{ReLU}(x)^\top \operatorname{ReLU}(y)\]

The ReLU kernel is fast, but less accurate than the softmax kernel in empirical experiments.
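Since attention only needs a non-negative similarity $K(q_i, k_j)$ followed by row normalization, the kernel can be swapped out. The NumPy sketch below (helper names are hypothetical) plugs either the softmax kernel or the ReLU kernel into the same normalization; with the softmax kernel it reproduces standard softmax attention.

```python
import numpy as np

def softmax_kernel(Q, K):
    # K(x, y) = exp(x^T y / sqrt(d)) for every (query, key) pair.
    return np.exp(Q @ K.T / np.sqrt(Q.shape[-1]))

def relu_kernel(Q, K):
    # K(x, y) = ReLU(x)^T ReLU(y) for every (query, key) pair.
    return np.maximum(Q, 0.0) @ np.maximum(K, 0.0).T

def kernel_attention(Q, K, V, kernel, eps=1e-9):
    A = kernel(Q, K)                                  # (L, M), non-negative
    # Row-normalize; eps avoids division by zero if a row is all zeros (ReLU case).
    A_norm = A / (A.sum(axis=-1, keepdims=True) + eps)
    return A_norm @ V

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((6, 8))
V = rng.standard_normal((6, 3))

out_soft = kernel_attention(Q, K, V, softmax_kernel)
out_relu = kernel_attention(Q, K, V, relu_kernel)
```

`softmax_kernel` follows the formula literally (no max-subtraction), so it could overflow for very large logits; the ReLU kernel avoids the exponential entirely.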

Complexity

A major problem of the Transformer is that its computation does not scale linearly with the sequence length $L$.

The computation cost of the attention matrix $Q K^\top$ is $O(LMd')$, where $ Q \in \mathbb{R}^{L \times d'} $ and $ K \in \mathbb{R}^{M \times d'} $; for self-attention ($M = L$) this is $O(L^2d')$. For extremely long sequences (e.g., music data with lengths in the thousands, or DNA sequencing data with lengths in the millions), this quadratic complexity is not acceptable. See Table 1 of the Music Transformer paper for details.

Recent works explore low-rank approximations of the softmax kernel function:

\[A = \exp(QK^\top / \sqrt{d'}) \approx \mathbb{E}[ Q' \cdot (K')^\top ]\]

where $Q'$ and $K'$ are random feature maps of $Q$ and $K$. The advantage is that, instead of first computing $A$ from $QK^\top$ and then $X'=A_{norm}V$, we can first compute $B=(K')^\top V$ and then $X'=Q'B$ (the row normalization is obtained the same way, from $Q'((K')^\top \mathbf{1}_M)$). This results in a complexity of $O(Ldd')$, linear in $L$.

For details, readers can check the Performer paper.
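To make the reordering $B = (K')^\top V$, $X' = Q'B$ concrete, here is a simplified NumPy sketch in the spirit of the Performer paper, not its implementation. It assumes i.i.d. Gaussian projections $w \sim N(0, I)$ and the positive feature map $\phi(x) = \exp(w^\top x - \|x\|^2/2)/\sqrt{m}$, for which $\mathbb{E}[\phi(q)^\top \phi(k)] = \exp(q^\top k)$. The number of features `m` and all helper names are illustrative choices.

```python
import numpy as np

def positive_random_features(X, W):
    # phi(x) = exp(W x - |x|^2 / 2) / sqrt(m), so that E[phi(q)^T phi(k)] = exp(q^T k)
    m = W.shape[0]
    sq_norm = 0.5 * np.sum(X**2, axis=-1, keepdims=True)
    return np.exp(X @ W.T - sq_norm) / np.sqrt(m)

def linear_attention(Q, K, V, W):
    scale = Q.shape[-1] ** -0.25                  # split 1/sqrt(d') between Q and K
    Qp = positive_random_features(Q * scale, W)   # Q': (L, m)
    Kp = positive_random_features(K * scale, W)   # K': (M, m)
    B = Kp.T @ V                                  # (m, d); no L x M matrix is formed
    row_sums = Qp @ Kp.sum(axis=0)                # (L,), equals Q' (K')^T 1_M
    return (Qp @ B) / row_sums[:, None]

def quadratic_attention(Q, K, V):
    # Reference: A = exp(Q K^T / sqrt(d')), then row-normalize and multiply by V.
    A = np.exp(Q @ K.T / np.sqrt(Q.shape[-1]))
    return (A / A.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
L, M, d_prime, d, m = 8, 8, 4, 3, 50000           # toy sizes; m = number of random features
Q = 0.5 * rng.standard_normal((L, d_prime))
K = 0.5 * rng.standard_normal((M, d_prime))
V = rng.standard_normal((M, d))
W = rng.standard_normal((m, d_prime))             # rows w ~ N(0, I)

approx = linear_attention(Q, K, V, W)
exact = quadratic_attention(Q, K, V)
print(np.max(np.abs(approx - exact)))             # small; shrinks as m grows
```

The linear version never materializes the $L \times M$ matrix, so its cost grows linearly in $L$ and $M$ for a fixed number of features `m`; the price is a Monte Carlo approximation error that decreases as `m` grows.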

Masking

Readers should distinguish the attention mask from the input mask in Masked Language Models (MLM).

  • Attention mask: set specific values of the attention matrix to zero
  • Input mask in MLM: set certain tokens in the input sequence to a pre-assigned [MASK] token
    • The encoder’s attention matrix is not affected.

No Mask

As the name suggests, the attention matrix is used as is, so each word can attend to any word in the entire sequence.

Causal Mask

The attention matrix is masked by a lower triangular matrix, so each word can only attend to itself and the words before it.

[Figure: causal mask]
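A causal mask can be applied either to the unnormalized matrix $A = \exp(QK^\top/\sqrt{d'})$ (zeroing entries, then renormalizing each row) or to the logits (setting masked entries to $-\infty$ before the softmax). This NumPy sketch, with toy random $Q, K$, checks that both routes give the same lower-triangular attention weights.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d_prime = 5, 4
Q = rng.standard_normal((L, d_prime))
K = rng.standard_normal((L, d_prime))
scores = Q @ K.T / np.sqrt(d_prime)

mask = np.tril(np.ones((L, L), dtype=bool))    # lower triangular, diagonal included

# Route 1: zero out entries of A = exp(scores), then renormalize each row.
A = np.exp(scores) * mask
weights_zeroed = A / A.sum(axis=-1, keepdims=True)

# Route 2: set masked logits to -inf before the softmax.
masked = np.where(mask, scores, -np.inf)
masked = masked - masked.max(axis=-1, keepdims=True)
weights_inf = np.exp(masked)
weights_inf = weights_inf / weights_inf.sum(axis=-1, keepdims=True)

print(np.round(weights_inf, 2))                # upper triangle is exactly 0
```

The first row always puts all its weight on position 0, since that is the only position it may see.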

Positional Encoding

The Attention operation does not contain ordering information about the sequence: it is a weighted sum of the rows of the value matrix $V$, where the weights are computed from $Q,K$ using a kernel function, so permuting the input tokens simply permutes the outputs. Positional encoding (PE) is used to inject information about relative or absolute positions into the model.
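One way to see the missing ordering information: without PE, self-attention is permutation-equivariant, so shuffling the input tokens only shuffles the outputs in the same way. A small NumPy check, assuming a single head and random toy weights:

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
L, d = 6, 8
X = rng.standard_normal((L, d))
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
perm = rng.permutation(L)                       # a random reordering of the tokens

out = self_attention(X, W_q, W_k, W_v)
out_perm = self_attention(X[perm], W_q, W_k, W_v)
# Shuffling the inputs only shuffles the outputs the same way:
print(np.allclose(out_perm, out[perm]))         # True
```

Because the model cannot tell the two orderings apart, position has to be supplied explicitly.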

A valid PE should assign a unique vector to each position in the sequence.

Two variants have been proposed:

  • Absolute Positional Encoding (APE): inject the positional information into the input sequence
  • Relative Positional Encoding (RPE): inject the positional information into the attention matrix

In the music generation task, RPE significantly outperforms APE.

Absolute Positional Encoding (APE)

APE was proposed in the Attention Is All You Need paper. It directly injects the positional information into the input sequence.

Suppose the input embedding is $d$-dimensional. APE creates a $d$-dimensional vector for each position and adds this vector to the input embedding.

The authors use sine and cosine functions of different frequencies as APE:

\[\operatorname{PE}(pos, 2i) = \sin(\frac{pos}{10000^{ 2i/d} } )\] \[\operatorname{PE}(pos, 2i+1) = \cos(\frac{pos}{10000^{ 2i/d} } )\]

where $pos$ is the position and $2i$ / $2i+1$ is the dimension index. The wavelengths of the sine and cosine functions form a geometric progression from $2\pi$ (dimensions 0 and 1) to nearly $10000 \cdot 2\pi$ (dimensions 510 and 511, for $d = 512$).
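The two formulas can be vectorized directly. A minimal NumPy sketch, assuming an even $d$ (the function name `sinusoidal_pe` is ours):

```python
import numpy as np

def sinusoidal_pe(num_positions, d):
    """Absolute positional encoding; assumes d is even."""
    pos = np.arange(num_positions)[:, None]          # (P, 1)
    two_i = np.arange(0, d, 2)[None, :]              # (1, d/2): 0, 2, 4, ...
    angles = pos / np.power(10000.0, two_i / d)      # pos / 10000^(2i/d)
    pe = np.zeros((num_positions, d))
    pe[:, 0::2] = np.sin(angles)                     # even dimensions 2i
    pe[:, 1::2] = np.cos(angles)                     # odd dimensions 2i+1
    return pe

pe = sinusoidal_pe(10, 64)
print(pe.shape)    # (10, 64)
print(pe[0, :4])   # [0. 1. 0. 1.]  (position 0: sin(0) = 0, cos(0) = 1)
```

The result is added to the input embeddings, e.g., `X + sinusoidal_pe(L, d)` for $X \in \mathbb{R}^{L \times d}$.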


Visualization of positional encoding from Jay Alammar’s blog. y-axis: position 0-9; x-axis: APE vector’s dimension 0-64.

Relative Positional Encoding (RPE)

RPE was proposed and refined in the above two 2018 papers, and adapted to the kernel approximation method in the 2021 paper.

Given the attention matrix $A$ where

\[A_{ij} = \exp(\frac{q_i k_j^\top}{ \sqrt{d} })\]

The idea is to encode the relative positional information directly into the attention matrix:

\[\widetilde{A}_{ij} = A_{ij} \cdot f_\theta(i-j) = \exp(\frac{q_i k_j^\top + \sqrt{d}\log f_\theta(i-j)}{ \sqrt{d} }) = \exp(\frac{q_i k_j^\top + R_{ij}}{ \sqrt{d} })\]

where $f_\theta: \mathbb{R} \rightarrow \mathbb{R}_{>0}$ is a positive, learnable function of the offset $i-j$ with parameter $\theta$, and $R$, with entries $R_{ij} = \sqrt{d}\log f_\theta(i-j)$, is a Toeplitz matrix that encodes the relative positional information. Expressing $f_\theta$ as a matrix $R$ allows the whole computation to be done with parallel matrix operations.

In the original paper, the RPE term also depends on the query $q_i$; there, $R_{ij}$ is a learnable vector (of the same dimension as $q_i$) that depends only on $i-j$:

\[A_{ij} = \exp(\frac{q_i k_j^\top + q_i (R_{ij})^\top}{ \sqrt{d} })\]
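A NumPy sketch of both RPE variants, under illustrative assumptions: the scalar bias is parameterized directly as a hypothetical vector `theta` with one entry per offset $i-j$ (so $R_{ij} = \sqrt{d}\log f_\theta(i-j)$ corresponds to $f_\theta(i-j) = \exp(\theta_{i-j}/\sqrt{d})$), and the query-dependent variant uses a hypothetical table `E` holding one vector per offset. Indexing by the offset matrix is what makes $R$ Toeplitz.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d = 5, 8
q = rng.standard_normal((L, d))                # queries q_i (rows)
k = rng.standard_normal((L, d))                # keys k_j (rows)

idx = np.arange(L)
offsets = idx[:, None] - idx[None, :]          # offsets[i, j] = i - j, in [-(L-1), L-1]

# Variant 1: scalar bias per offset (hypothetical parameter vector theta).
theta = rng.standard_normal(2 * L - 1)
R = theta[offsets + (L - 1)]                   # R[i, j] depends only on i - j: Toeplitz
A_tilde = np.exp((q @ k.T + R) / np.sqrt(d))   # exp((q_i k_j^T + R_ij) / sqrt(d))

# Variant 2: query-dependent term with one vector per offset (hypothetical table E).
E = rng.standard_normal((2 * L - 1, d))
R_vec = E[offsets + (L - 1)]                   # (L, L, d): each R_ij is a vector
S_rel = np.einsum("id,ijd->ij", q, R_vec)      # S_rel[i, j] = q_i (R_ij)^T
A_q = np.exp((q @ k.T + S_rel) / np.sqrt(d))
```

Both variants only add an $L \times L$ term to the logits, so the rest of the attention computation is unchanged.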

Architecture

A detailed discussion of Encoder and Decoder design of Transformer:

Normalization

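The Transformer uses layer normalization, because batch normalization is inefficient for a batch of variable-length sequences: layer normalization computes its statistics over the features of each token separately, so it does not depend on the batch or on padding. More discussions can be found at [TBD].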
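A minimal NumPy sketch of layer normalization (names and sizes are illustrative) showing why it suits variable-length sequences: statistics are computed per token over the feature axis, so appending padding tokens does not change the normalized values of the real tokens. A batch-norm-style mean over all positions would, by contrast, include the padding.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Statistics over the feature axis only: each token is normalized on its own,
    # independent of the batch size and of the sequence length.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
d = 8
gamma, beta = np.ones(d), np.zeros(d)
short = rng.standard_normal((3, d))                  # a sequence of length 3
padded = np.vstack([short, np.zeros((4, d))])        # same tokens, padded to length 7

out_short = layer_norm(short, gamma, beta)
out_padded = layer_norm(padded, gamma, beta)
print(np.allclose(out_short, out_padded[:3]))        # True: padding does not affect real tokens
```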

Multi-Head Attention

The idea of Multi-Head Attention is that each attention head can focus on one type of relationship. This is analogous to having multiple channels in a CNN. Because the heads are randomly initialized, they are likely to learn different patterns during training.

Each of the $h$ attention heads maintains its own $W_q, W_k, W_v$. The Concat layer pools the information from the different heads, and the Linear layer applies a transformation to the combined information.

[Figure: multi-head attention]
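A minimal NumPy sketch of multi-head attention, assuming (as in the original paper's default setup) that each head projects to $d/h$ dimensions, so each per-head $W_v$ here is $d \times d/h$ rather than $d \times d$. Head parameters are passed as a plain list of tuples; this layout is ours, not a library API.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, Y, heads, W_o, mask=None):
    """X: (L, d) source of queries; Y: (M, d) source of keys/values (Y = X for self-attention).
    heads: list of (W_q, W_k, W_v) tuples, one per head.
    W_o: (h * d_v, d) output projection (the Linear layer)."""
    outputs = []
    for W_q, W_k, W_v in heads:
        Q, K, V = X @ W_q, Y @ W_k, Y @ W_v
        scores = Q @ K.T / np.sqrt(Q.shape[-1])
        if mask is not None:
            scores = np.where(mask, scores, -np.inf)
        outputs.append(softmax(scores) @ V)          # (L, d_v) for this head
    concat = np.concatenate(outputs, axis=-1)        # Concat: (L, h * d_v)
    return concat @ W_o                              # Linear: (L, d)

rng = np.random.default_rng(0)
L, d, h = 5, 16, 4
d_k = d_v = d // h                                   # per-head size (assumption)

def init(*shape):
    return rng.standard_normal(shape) / np.sqrt(shape[0])

heads = [(init(d, d_k), init(d, d_k), init(d, d_v)) for _ in range(h)]
W_o = init(h * d_v, d)
X = rng.standard_normal((L, d))

out = multi_head_attention(X, X, heads, W_o)         # multi-head self-attention
print(out.shape)                                     # (5, 16)
```

The explicit loop keeps the per-head structure visible; batched implementations usually reshape a single large projection into $h$ heads instead.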

Encoder - Decoder

The encoder and decoder contain stacks of $N$ EncoderLayers and DecoderLayers, respectively.


The Transformer - model architecture

EncoderLayer

  • Multi-Head Self-Attention
  • Skip connection + Layer Normalization
  • Feed-Forward Networks
  • Skip connection + Layer Normalization
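The EncoderLayer steps above can be sketched in NumPy as follows. This is a simplified post-norm variant (LayerNorm after each skip connection, as in the list) with assumptions for brevity: a single attention head, no dropout (inference mode), and LayerNorm without learnable gain and bias.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Learnable gain and bias omitted (taken as 1 and 0) for brevity.
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def self_attention(x, W_q, W_k, W_v):
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def feed_forward(x, W1, b1, W2, b2):
    return np.maximum(x @ W1 + b1, 0.0) @ W2 + b2     # Linear -> ReLU -> Linear

def encoder_layer(x, p):
    x = layer_norm(x + self_attention(x, p["W_q"], p["W_k"], p["W_v"]))     # attention, skip + LN
    x = layer_norm(x + feed_forward(x, p["W1"], p["b1"], p["W2"], p["b2"]))  # FFN, skip + LN
    return x

rng = np.random.default_rng(0)
L, d, d_ff, N = 5, 8, 32, 2                          # toy sizes

def init(*shape):
    return rng.standard_normal(shape) / np.sqrt(shape[0])

layers = [{"W_q": init(d, d), "W_k": init(d, d), "W_v": init(d, d),
           "W1": init(d, d_ff), "b1": np.zeros(d_ff),
           "W2": init(d_ff, d), "b2": np.zeros(d)} for _ in range(N)]

x = rng.standard_normal((L, d))
for p in layers:                                     # a stack of N EncoderLayers
    x = encoder_layer(x, p)
print(x.shape)                                       # (5, 8)
```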

DecoderLayer

  • Multi-Head Causal-Attention
  • Skip connection + Layer Normalization
  • Multi-Head Cross-Attention
  • Skip connection + Layer Normalization
  • Feed-Forward Networks
  • Skip connection + Layer Normalization
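Similarly, a simplified DecoderLayer sketch (single head, no dropout, LayerNorm without learnable parameters; `memory` is our name for the encoder output). The usage checks causality: perturbing the last target token leaves the outputs at earlier positions unchanged.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Learnable gain and bias omitted for brevity.
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def attention(x, y, W_q, W_k, W_v, mask=None):
    # Queries come from x; keys and values come from y (y = x for self-attention).
    Q, K, V = x @ W_q, y @ W_k, y @ W_v
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    if mask is not None:
        s = np.where(mask, s, -np.inf)
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def decoder_layer(x, memory, p):
    L = x.shape[0]
    causal = np.tril(np.ones((L, L), dtype=bool))
    x = layer_norm(x + attention(x, x, *p["self"], mask=causal))   # causal self-attention
    x = layer_norm(x + attention(x, memory, *p["cross"]))          # cross-attention
    ffn = np.maximum(x @ p["W1"] + p["b1"], 0.0) @ p["W2"] + p["b2"]
    return layer_norm(x + ffn)                                      # FFN, skip + LN

rng = np.random.default_rng(0)
L, M, d, d_ff = 4, 6, 8, 32                                         # toy sizes

def init(*shape):
    return rng.standard_normal(shape) / np.sqrt(shape[0])

p = {"self": (init(d, d), init(d, d), init(d, d)),
     "cross": (init(d, d), init(d, d), init(d, d)),
     "W1": init(d, d_ff), "b1": np.zeros(d_ff),
     "W2": init(d_ff, d), "b2": np.zeros(d)}
x = rng.standard_normal((L, d))           # embedded target tokens
memory = rng.standard_normal((M, d))      # encoder output for the source sentence

out = decoder_layer(x, memory, p)
x2 = x.copy()
x2[-1] += np.linspace(-1.0, 1.0, d)       # perturb only the last target token
out2 = decoder_layer(x2, memory, p)
print(np.allclose(out[:-1], out2[:-1]))   # True: earlier positions cannot see the last token
```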

Feed-Forward Networks

  • Linear
  • ReLU
  • Dropout
  • Linear
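The Feed-Forward Network above, as a NumPy sketch with inverted dropout applied after the ReLU. The dropout placement follows the list; the toy sizes are assumptions for illustration (the original paper used $d = 512$ and an inner size of 2048).

```python
import numpy as np

def feed_forward(x, W1, b1, W2, b2, p_drop=0.1, training=False, rng=None):
    h = np.maximum(x @ W1 + b1, 0.0)                 # Linear, then ReLU
    if training and p_drop > 0:
        rng = rng if rng is not None else np.random.default_rng()
        # Inverted dropout: zero each unit with probability p_drop and rescale
        # the survivors, so no rescaling is needed at inference time.
        keep = rng.random(h.shape) >= p_drop
        h = h * keep / (1.0 - p_drop)
    return h @ W2 + b2                               # Linear

rng = np.random.default_rng(0)
d, d_ff = 8, 32                                      # toy sizes
W1, b1 = rng.standard_normal((d, d_ff)) / np.sqrt(d), np.zeros(d_ff)
W2, b2 = rng.standard_normal((d_ff, d)) / np.sqrt(d_ff), np.zeros(d)
x = rng.standard_normal((5, d))

y_eval = feed_forward(x, W1, b1, W2, b2)                            # inference: deterministic
y_train = feed_forward(x, W1, b1, W2, b2, training=True, rng=rng)   # training: random mask
```

The network is applied to each position independently, so the sequence length never enters the weight shapes.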

The Attention class also contains a Dropout parameter, which applies dropout to the attention matrix.

Different configurations of encoder/decoder blocks result in different architectures:

  • Transformer: stacked encoder blocks and decoder blocks
  • BERT: stacked encoder blocks
  • GPT: stacked decoder blocks