Distribution Shift
This post is mainly based on
Distribution shift can cause failure in machine learning systems. Consider a dataset shift where the joint distribution differs between the training and test datasets:
\[P_{\operatorname{train}}(y,x) \not= P_{\operatorname{test}}(y,x)\]

Types of Distribution Shifts
Decomposing the joint distribution into a conditional and a marginal reveals several types of distribution shift:
- Covariate Shift:
- Model: $P(X, Y) = P(Y | X)P(X)$
- $P(Y | X)$ is fixed
- $P(X)$ changes
- Label Shift (prior shift):
- Model: $P(X, Y) = P(X | Y)P(Y)$
- $P(X | Y)$ is fixed
- $P(Y)$ changes
- Concept Shift:
- Model: $P(X, Y) = P(Y | X)P(X)$
- $P(Y | X)$ changes
- $P(X)$ is fixed
Note that $P(Y | X)$ implies a discriminative model and $P(X | Y)$ implies a generative model. More details can be found here.
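To make the taxonomy concrete, here is a minimal sketch of covariate shift on synthetic data (the threshold concept and the Gaussian feature distributions are illustrative assumptions, not from any real dataset): $P(Y \mid X)$ is held fixed while $P(X)$ moves.

```python
import numpy as np

rng = np.random.default_rng(0)

def label(X):
    # The concept P(Y | X) is fixed: y = 1 exactly when x > 0
    return (X > 0).astype(int)

# Covariate shift: P(X) moves between train and test, P(Y | X) stays fixed
X_train = rng.normal(loc=-1.0, scale=1.0, size=10_000)
X_test = rng.normal(loc=1.0, scale=1.0, size=10_000)
y_train, y_test = label(X_train), label(X_test)

# The feature marginal P(X) differs, and as a side effect the label
# prior P(Y) drifts too, even though the concept never changed
print("mean x:", X_train.mean(), X_test.mean())
print("P(y=1):", y_train.mean(), y_test.mean())
```

Note that shifting $P(X)$ alone already changes $P(Y)$ here, which is why detectors on labels and detectors on features can both fire for the same underlying shift.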
Dataset Shift can be the result of:
- Sample selection bias
- Non-stationary environments
Detecting Distribution Shift
I was not able to find a good review article or survey on this topic. Different fields appear to require vastly different ways of measuring distribution shift. Below is a summary of methods based on my reading:
- By statistics on Labels
- Kolmogorov-Smirnov Statistic
- Drift Detection Method (DDM)
- By statistics on Features
- Kolmogorov-Smirnov statistic (questionable on high-dimensional data due to too many false positives)
- Population Stability Index (PSI)
- Time-series drift
- Novelty and Outlier Detection
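As an example of a feature-level statistic, the two-sample Kolmogorov-Smirnov test is readily available as `scipy.stats.ks_2samp`. The sketch below (synthetic, one-dimensional data) compares a reference feature sample against an unshifted batch and a mean-shifted batch:

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(42)
reference = rng.normal(0.0, 1.0, size=5_000)  # training-time feature values
same = rng.normal(0.0, 1.0, size=5_000)       # no shift
shifted = rng.normal(0.5, 1.0, size=5_000)    # mean-shifted feature

# ks_2samp returns the KS statistic and a p-value; a small p-value
# rejects the null hypothesis that the two samples share a distribution
stat_same, p_same = ks_2samp(reference, same)
stat_shift, p_shift = ks_2samp(reference, shifted)

print(f"no shift:   D={stat_same:.3f}, p={p_same:.3f}")
print(f"mean shift: D={stat_shift:.3f}, p={p_shift:.3g}")
```

As noted above, this works per feature; running it independently on many features multiplies the false-positive rate, so some multiple-testing correction is usually needed in high dimensions.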
Practical considerations for implementing a distribution shift detector include:
- Sudden Shifts vs Gradual Shifts
- Reference distribution size (does it cover enough data / outliers)
- Cumulative statistics vs Sliding Window statistics
Drift Detection Method (DDM)
DDM is an online detection method for classification problems. Given a sequence of predictions of length $i$, let $p_i$ be the error rate $\frac{\sum_{j=1}^i \mathbb{1}(\hat{y}_j \not= y_j)}{i}$. Assume the error events follow a binomial distribution. Given $p_i$, the standard deviation can be estimated as $s_i = \sqrt{p_i(1-p_i)/i}$. Compute $p_i + s_i$ and track the minimum of this quantity as $p_{min} + s_{min}$.
For a sufficiently large sample, the binomial distribution can be approximated by a normal, so confidence intervals on the error rate can be constructed from $p_i$ and $s_i$:
- With 95% confidence, detect distribution shift when $p_i + s_i \geq p_{min} + 2s_{min}$
- With 99% confidence, detect distribution shift when $p_i + s_i \geq p_{min} + 3s_{min}$
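The rules above can be sketched as a small online detector. The warm-up period and the guard that waits for the first error before tracking the minimum are my own implementation choices (otherwise $s_{min}$ can degenerate to zero), not part of the method as stated:

```python
import math

class DDM:
    """Sketch of the Drift Detection Method described above: track the
    running error rate p and its standard deviation s = sqrt(p(1-p)/i),
    remember the minimum of p + s, and signal warning / drift when the
    current p + s exceeds that minimum by 2 or 3 s_min."""

    def __init__(self, min_samples=30):
        self.i = 0              # predictions seen so far
        self.errors = 0         # misclassifications so far
        self.p_min = math.inf
        self.s_min = math.inf
        self.min_samples = min_samples  # warm-up for the normal approximation

    def update(self, is_error: bool) -> str:
        self.i += 1
        self.errors += int(is_error)
        p = self.errors / self.i
        s = math.sqrt(p * (1 - p) / self.i)
        if self.i < self.min_samples:
            return "ok"
        # Guard: only track the minimum once at least one error has been
        # seen, so s_min cannot degenerate to zero
        if self.errors > 0 and p + s < self.p_min + self.s_min:
            self.p_min, self.s_min = p, s
        if p + s >= self.p_min + 3 * self.s_min:
            return "drift"      # 99% confidence level
        if p + s >= self.p_min + 2 * self.s_min:
            return "warning"    # 95% confidence level
        return "ok"
```

Feeding it a stream whose error rate jumps (say from 5% to 50%) should move it through "warning" into "drift"; because the error rate is cumulative, DDM responds faster to sudden shifts than to gradual ones.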
Population Stability Index (PSI)
Consider a categorical variable $x$ (e.g., credit rating) and let $x_i$ be the $i$-th category/bin. Let $B$ be the total number of bins, $p_i$ the fraction of training samples falling in $x_i$, and $q_i$ the fraction of test samples falling in $x_i$. The PSI is computed as:
\[\operatorname{PSI} = \sum_{i=1}^B (p_i - q_i)(\log p_i - \log q_i)\]The PSI can be written as the sum of the KL divergences in both directions:
\[\begin{align} \operatorname{PSI} &= \sum_{i=1}^B (p_i - q_i)(\log p_i - \log q_i) \\ &= \sum_i (p_i - q_i)\log \frac{p_i}{q_i} \\ &= \sum_i p_i \log \frac{p_i}{q_i} + \sum_i q_i \log \frac{q_i}{p_i} \\ &= \operatorname{KL}(p \| q) + \operatorname{KL}(q \| p) \end{align}\]This makes it clear why the PSI can measure distribution shift in the labels.
There is a $\chi^2$ test for the null hypothesis $H_0: p_i = q_i \ \forall i = 1, \ldots, B$; interested readers may refer to the book listed above.
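A minimal sketch of the PSI computation follows; the `eps` smoothing for empty bins is my own addition (the raw formula is undefined when a bin has zero mass), and the bin shares are made-up numbers:

```python
import numpy as np

def psi(p, q, eps=1e-12):
    """Population Stability Index between two binned distributions.
    `eps` smooths empty bins so the logarithm stays finite."""
    p = np.asarray(p, dtype=float) + eps
    q = np.asarray(q, dtype=float) + eps
    p, q = p / p.sum(), q / q.sum()
    return float(np.sum((p - q) * (np.log(p) - np.log(q))))

# Share of samples per credit-rating bin at train vs. test time
train_bins = [0.30, 0.40, 0.20, 0.10]
test_bins = [0.25, 0.35, 0.25, 0.15]
print(psi(train_bins, test_bins))
```

Being a symmetrized KL divergence, the PSI is non-negative, zero only when the two distributions match, and symmetric in its arguments. A common rule of thumb in credit scoring (not from this derivation) treats PSI below 0.1 as stable, 0.1 to 0.25 as moderate shift, and above 0.25 as major shift.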
Time-series drift
The goal is to detect concept drift in an online learning fashion. Consider $k$ samples drawn from a joint distribution over the time interval $[t_1, t_k]$:
\[\mathcal{S} = \{D_{t_1}, ..., D_{t_k}\} \sim F_{[t_1,t_k]}(X, y)\]where $D_{t_i} = (X_{t_i}, y_{t_i})$. Suppose a concept shift occurs at time $t_{k+1}$:
\[F_{[t_1, t_k]}(X, y) \not= F_{[t_{k+1}, \infty)}(X, y)\]A model $f_m$ is trained on the existing data $D_1, \ldots, D_m$. The residual/error is computed as \(\operatorname{Error}_m = \mathbb{E}[y - f_m(X)]\), and a z-score $z_m$ is computed on \(\operatorname{Error}_m\). If $z_m$ deviates from the previous $\{z_1, \ldots, z_{m-1}\}$, a concept shift is detected. (They do not specify which test is used on $z_m$; I imagine it could be a $\chi^2$ test.)
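Since the test on $z_m$ is left unspecified, the sketch below substitutes a simple $|z| > 3$ rule: the mean residual of the latest window is compared against the history of earlier residuals, in standard-error units. The window size, threshold, and synthetic residual stream are all my own choices for illustration:

```python
import numpy as np

def zscore_drift(errors, window=30, threshold=3.0):
    """Flag time steps where the mean error of the latest `window`
    observations deviates from the earlier history by more than
    `threshold` standard errors (a stand-in for the unspecified test)."""
    errors = np.asarray(errors, dtype=float)
    alarms = []
    for m in range(2 * window, len(errors) + 1):
        hist = errors[: m - window]       # reference residuals
        recent = errors[m - window : m]   # latest window
        se = hist.std() / np.sqrt(window) + 1e-12
        z = (recent.mean() - hist.mean()) / se
        alarms.append(bool(abs(z) > threshold))
    return alarms

rng = np.random.default_rng(1)
residuals = np.concatenate([rng.normal(0, 1, 200),   # stable concept
                            rng.normal(2, 1, 100)])  # concept shift at t=200
alarms = zscore_drift(residuals)
print("alarm raised:", any(alarms))
```

A sliding window like this trades detection delay (roughly one window length) for robustness: cumulative statistics would react more slowly but smooth out transient noise, which is the sudden-vs-gradual trade-off noted earlier.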