This post is mainly based on

Distribution shift can cause failures in machine learning systems. Consider a dataset shift where the joint distribution differs between the training and testing datasets:

\[P_{\operatorname{train}}(y,x) \not= P_{\operatorname{test}}(y,x)\]

Types of Distribution Shifts

Decomposing the joint distribution into a conditional and a marginal, we can distinguish several types of distribution shift:

  • Covariate Shift:
    • Model: $P(X, Y) = P(Y | X)P(X)$
    • $P(Y | X)$ is fixed
    • $P(X)$ changes
  • Label Shift (prior shift):
    • Model: $P(X, Y) = P(X | Y)P(Y)$
    • $P(X | Y)$ is fixed
    • $P(Y)$ changes
  • Concept Shift:
    • Model: $P(X, Y) = P(Y | X)P(X)$
    • $P(Y | X)$ changes
    • $P(X)$ is fixed

Note that $P(Y | X)$ implies a discriminative model and $P(X | Y)$ implies a generative model.
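
To make the first decomposition concrete, here is a toy simulation of covariate shift (the logistic labeling rule and all parameters are mine, for illustration only): the conditional $P(Y | X)$ is held fixed while the input marginal $P(X)$ moves between train and test.

```python
import numpy as np

rng = np.random.default_rng(42)

def sample_labels(x):
    # Fixed conditional P(Y=1 | X=x): a logistic rule that never changes.
    p = 1.0 / (1.0 + np.exp(-2.0 * x))
    return rng.binomial(1, p)

# Covariate shift: P(X) moves between train and test, P(Y | X) stays fixed.
x_train = rng.normal(loc=-1.0, scale=1.0, size=10_000)
x_test = rng.normal(loc=1.0, scale=1.0, size=10_000)
y_train = sample_labels(x_train)
y_test = sample_labels(x_test)

# Even though the labeling rule is identical, the label marginal P(Y)
# differs as a side effect of the shifted input marginal P(X).
print(y_train.mean(), y_test.mean())
```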

Dataset shift can be the result of:

  • Sample selection bias
  • Non-stationary environments

Detecting Distribution Shift

I was not able to find a good review article or survey on this topic; it appears that different fields require vastly different ways to measure distribution shift. Below is a summary of methods based on my reading:

  • By statistics on Labels
    • Kolmogorov-Smirnov Statistic
    • Drift Detection Method (DDM)
  • By statistics on Features
    • Kolmogorov-Smirnov statistic (questionable on high-dimensional data due to too many false positives; see the sketch after this list)
    • Population Stability Index (PSI)
    • Time-series drift
    • Novelty and Outlier Detection
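
As promised above, a minimal sketch of the two-sample KS test on a single feature, using `scipy.stats.ks_2samp` (the data and the 0.05 threshold are mine, for illustration):

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(0)

# Reference (training) feature vs. a production feature whose mean drifted.
x_train = rng.normal(loc=0.0, scale=1.0, size=5_000)
x_prod = rng.normal(loc=0.3, scale=1.0, size=5_000)

# Two-sample KS test; H0: both samples come from the same distribution.
stat, p_value = ks_2samp(x_train, x_prod)
if p_value < 0.05:
    print(f"shift detected: KS statistic={stat:.3f}, p-value={p_value:.1e}")
```

The test is univariate, so on high-dimensional data it is typically run once per feature; without a multiple-comparison correction (e.g., Bonferroni) this produces the false positives mentioned above.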

Practical considerations for implementing a distribution shift detector include:

  • Sudden Shifts vs Gradual Shifts
  • Reference distribution size (does it cover enough data / outliers)
  • Cumulative statistics vs sliding-window statistics (see the sketch below)
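
To make the last trade-off concrete, here is a minimal sketch (the class names are mine, not from any library): a sliding-window statistic reacts quickly to sudden shifts but is noisier, while a cumulative statistic is smooth but dilutes gradual shifts over a long history.

```python
from collections import deque

class SlidingErrorRate:
    """Error rate over the last `window` predictions: reacts quickly
    to sudden shifts, but the estimate is noisier."""

    def __init__(self, window=500):
        self.buffer = deque(maxlen=window)

    def update(self, error):
        self.buffer.append(error)
        return sum(self.buffer) / len(self.buffer)


class CumulativeErrorRate:
    """Error rate over everything seen so far: stable, but a gradual
    shift gets diluted by the long history."""

    def __init__(self):
        self.n = 0
        self.n_errors = 0

    def update(self, error):
        self.n += 1
        self.n_errors += error
        return self.n_errors / self.n
```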

Drift Detection Method (DDM)

DDM is an online detection method for classification problems. Given a sequence of predictions of length $i$, let $p_i$ be the error rate $\frac{1}{i}\sum_{j=1}^{i} \mathbb{1}(\hat{y}_j \not= y_j)$. Assume the number of errors follows a binomial distribution. Given $p_i$, we can estimate the standard deviation as $s_i = \sqrt{p_i(1-p_i)/i}$. Compute $p_i + s_i$ and track the minimum of this quantity as $p_{min} + s_{min}$.

For a sufficiently large sample, the binomial distribution can be approximated by a normal distribution, so a $1-\frac{\alpha}{2}$ confidence interval can be constructed around the error rate $p$:

  • Warning level (95% confidence): flag when $p_i + s_i \geq p_{min} + 2s_{min}$
  • Drift level (99% confidence): flag when $p_i + s_i \geq p_{min} + 3s_{min}$
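
Below is a minimal sketch of this procedure; the 30-sample warm-up and the reset after confirmed drift are common implementation choices, not part of the description above.

```python
import math

class DDM:
    """Minimal sketch of the Drift Detection Method.

    Tracks the running error rate p and its standard deviation s,
    remembers the minimum of p + s seen so far, and compares the
    current p + s against the warning and drift thresholds above.
    """

    def __init__(self, min_samples=30):
        self.min_samples = min_samples
        self.reset()

    def reset(self):
        self.n = 0
        self.n_errors = 0
        self.p_min = float("inf")
        self.s_min = float("inf")

    def update(self, error):
        """error: 1 if the prediction was wrong, 0 if it was correct."""
        self.n += 1
        self.n_errors += error
        p = self.n_errors / self.n
        s = math.sqrt(p * (1.0 - p) / self.n)
        if self.n < self.min_samples:
            return "ok"  # warm-up: the estimates are too noisy
        if p + s < self.p_min + self.s_min:
            self.p_min, self.s_min = p, s
        if p + s >= self.p_min + 3.0 * self.s_min:
            self.reset()  # drift confirmed: retrain and start fresh
            return "drift"
        if p + s >= self.p_min + 2.0 * self.s_min:
            return "warning"
        return "ok"
```

Feeding `detector.update(int(y_hat != y))` after each prediction yields a stream of `"ok"` / `"warning"` / `"drift"` statuses.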

Population Stability Index (PSI)

Suppose $x$ is a categorical variable (e.g., credit rating) and let $x_i$ be the $i$-th category/bin. Let $B$ be the total number of bins, let $p_i$ be the fraction of training samples falling inside $x_i$, and let $q_i$ be the fraction of testing samples falling inside $x_i$. The PSI is computed as:

\[\operatorname{PSI} = \sum_{i=1}^B (p_i - q_i)(\log p_i - \log q_i)\]

The PSI can be written as the sum of the KL divergences in both directions:

\[\begin{align} \operatorname{PSI} &= \sum_{i=1}^B (p_i - q_i)(\log p_i - \log q_i) \\ &= \sum_{i=1}^B (p_i - q_i)\log \frac{p_i}{q_i} \\ &= \sum_{i=1}^B p_i \log \frac{p_i}{q_i} + \sum_{i=1}^B q_i \log \frac{q_i}{p_i} \\ &= \operatorname{KL}(p \| q) + \operatorname{KL}(q \| p) \end{align}\]

This makes it obvious why the PSI can measure distribution shift: it is the symmetrized KL divergence between the two binned distributions.

There is a $\chi^2$ test for the null hypothesis $H_0: p_i^* = q_i \; \forall i = 1, \dots, B$; interested readers can refer to the book listed above.
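
A minimal sketch of the PSI formula above; the `eps` clipping for empty bins, where $\log p_i - \log q_i$ is undefined, is a practical tweak of mine rather than part of the definition.

```python
import numpy as np

def psi(p_counts, q_counts, eps=1e-6):
    """Population Stability Index between two binned samples.

    p_counts / q_counts: per-bin counts for the reference (training)
    and comparison (testing) samples.
    """
    p = np.asarray(p_counts, dtype=float)
    q = np.asarray(q_counts, dtype=float)
    p = np.clip(p / p.sum(), eps, None)  # clip to avoid log(0)
    q = np.clip(q / q.sum(), eps, None)
    return float(np.sum((p - q) * (np.log(p) - np.log(q))))

# Example: a credit-score distribution over five rating buckets.
print(psi([200, 300, 300, 150, 50], [150, 250, 300, 200, 100]))
```

A common rule of thumb in credit scoring treats PSI below 0.1 as no significant shift, 0.1 to 0.25 as a moderate shift worth investigating, and above 0.25 as a major shift.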

Time-series drift

The goal is to detect concept drift in an online learning fashion. Consider $k$ samples from a joint distribution over a time interval $[t_1, t_k]$:

\[\mathcal{S} = \{D_{t_1} , ..., D_{t_k} \} \sim F_{[t_1,t_k]}(X, y)\]

where $D_{t_i} = (X_{t_i}, y_{t_i})$. Concept shift occurs at time $t_{k+1}$ when:

\[F_{[t_1, t_k]}(X, y) \not= F_{[t_{k+1}, \infty)}(X, y)\]

A model $f_m$ is trained on existing data $D_1, \dots, D_m$. The residual/error is computed as \(\operatorname{Error}_m = \mathbb{E}[y - f_m(X)]\). Then a z-score $z_m$ is computed on \(\operatorname{Error}_m\). If $z_m$ deviates from the previous $\{z_1, \dots, z_{m-1}\}$, a concept shift is detected. (They do not specify which test is used on $z_m$; I imagine it could be a $\chi^2$ test.)
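
Since the test is unspecified, the sketch below shows one plausible reading (the window size and threshold are my choices, not from the source): z-score the mean error of the latest window against the earlier error history and flag a shift beyond three standard errors.

```python
import numpy as np

def concept_shift_detected(errors, window=50, z_threshold=3.0):
    """Flag concept shift when the mean error of the latest window
    deviates from the earlier error history by more than z_threshold
    standard errors. One plausible reading of the unspecified test.
    """
    errors = np.asarray(errors, dtype=float)
    if errors.size < 2 * window:
        return False  # not enough history to compare against
    history, recent = errors[:-window], errors[-window:]
    mu, sigma = history.mean(), history.std(ddof=1)
    # Standard error of the recent window's mean under the null
    # hypothesis that it was drawn from the historical distribution.
    z = (recent.mean() - mu) / (sigma / np.sqrt(window))
    return abs(z) > z_threshold
```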