How to Detect Model Drift in Production

Model drift is one of the most common production ML failure modes — and one of the most underestimated in interviews. A model that scores 94% AUC in offline evaluation can silently degrade to 81% months later with no code change. This guide covers how to detect drift early, distinguish its types, and design monitoring systems that catch it before users do.

What Interviewers Are Testing

  • Do you understand the difference between data drift, concept drift, and prediction drift?
  • Can you design a monitoring system that catches drift without requiring labeled data in real time?
  • Do you know which statistical tests to apply and their practical limitations at scale?
  • Can you describe a retraining pipeline triggered by drift signals?

Types of Model Drift

Data Drift (Covariate Shift)

The distribution of input features P(X) changes, but the relationship P(Y|X) stays the same. Example: a fraud detection model trained on desktop transactions starts seeing mostly mobile transactions. Feature distributions shift, but the learned decision boundary remains valid as long as the model generalizes to regions of feature space that were sparse in training.
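A small simulation makes covariate shift concrete. The sketch assumes a hypothetical one-dimensional feature and a fixed labeling rule (positive when x > 1) that stands in for P(Y|X). The "model" is idealized: it has learned that rule exactly. Shifting the input mean is caught by a KS test and raises the positive-prediction rate. Accuracy does not change, because P(Y|X) did not move.

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(0)

def true_label(x):
    # Hypothetical fixed labeling rule standing in for P(Y|X)
    return (x > 1.0).astype(int)

def model_predict(x):
    # Idealized model that learned the rule exactly
    return (x > 1.0).astype(int)

x_train = rng.normal(loc=0.0, scale=1.0, size=5000)  # training-era inputs
x_prod = rng.normal(loc=1.0, scale=1.0, size=5000)   # production inputs: mean shifted

_, p_value = ks_2samp(x_train, x_prod)
acc_train = float((model_predict(x_train) == true_label(x_train)).mean())
acc_prod = float((model_predict(x_prod) == true_label(x_prod)).mean())
pos_rate_train = float(model_predict(x_train).mean())
pos_rate_prod = float(model_predict(x_prod).mean())

print(f"KS p-value: {p_value:.1e}")
print(f"accuracy: train={acc_train:.3f}, prod={acc_prod:.3f}")
print(f"positive prediction rate: train={pos_rate_train:.3f}, prod={pos_rate_prod:.3f}")
```

A learned model would not match the rule exactly. Its accuracy in the newly dense region would then depend on how well it generalizes there.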

Concept Drift

The relationship P(Y|X) changes: the same inputs now map to different labels. Example: what counts as “spam” evolves as spammers adapt to your classifier. This is the hardest type to detect for two reasons. First, ground truth labels arrive late. Second, P(X) may not change at all, so monitoring inputs or predictions alone can miss it.
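A small simulation shows why input monitoring misses concept drift. In this sketch the inputs keep the same distribution. A hypothetical labeling rule moves its threshold from 1.0 to 0.0, standing in for spammers adapting. The model is a fixed logistic score centered on the old threshold (also an assumption). KS statistics on both the inputs and the scores stay near zero, yet accuracy against the new labels falls.

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(1)

x_ref = rng.normal(0.0, 1.0, size=5000)   # reference-period inputs
x_prod = rng.normal(0.0, 1.0, size=5000)  # production inputs: same P(X)

def model_score(x):
    # Fixed model: logistic score centered on the old decision threshold x = 1
    return 1.0 / (1.0 + np.exp(-3.0 * (x - 1.0)))

def old_rule(x):
    return (x > 1.0).astype(int)   # hypothetical P(Y|X) at training time

def new_rule(x):
    return (x > 0.0).astype(int)   # hypothetical P(Y|X) after the concept changed

# KS statistics (effect sizes) on inputs and on model scores
ks_inputs, _ = ks_2samp(x_ref, x_prod)
ks_scores, _ = ks_2samp(model_score(x_ref), model_score(x_prod))

pred_ref = (model_score(x_ref) > 0.5).astype(int)
pred_prod = (model_score(x_prod) > 0.5).astype(int)
acc_ref = float((pred_ref == old_rule(x_ref)).mean())
acc_prod = float((pred_prod == new_rule(x_prod)).mean())

print(f"KS statistic: inputs={ks_inputs:.3f}, scores={ks_scores:.3f}")
print(f"accuracy: reference={acc_ref:.3f}, production={acc_prod:.3f}")
```

In production, the new labels arrive only later. This failure therefore tends to surface through delayed-label evaluation rather than through input or score monitoring.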

Prediction Drift

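The distribution of model outputs P(Ŷ) shifts, often before its cause is known. This is the easiest type to detect because no labels are needed. A fixed model's outputs depend only on its inputs, so prediction drift signals a change upstream in P(X): either data drift or a bug in the feature pipeline. On its own, it cannot reveal pure concept drift, where P(X) is unchanged.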

Label Drift (Prior Probability Shift)

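The marginal distribution of labels P(Y) changes, while the class-conditional distributions P(X|Y) stay the same. Example: a customer churn model sees a 5% churn rate during COVID vs. 15% post-COVID. P(X) is a mixture of the class conditionals weighted by P(Y), so the score distribution shifts. By Bayes' rule P(Y|X) changes too, so a model calibrated to the old base rate will systematically misestimate probabilities.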
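The sketch below simulates prior probability shift. It assumes hypothetical Gaussian class conditionals, P(X|Y=0) = N(0, 1) and P(X|Y=1) = N(2, 1), which stay fixed while the positive rate moves from 5% to 15%. The model is the exact Bayes posterior under the old 5% prior. Its mean score rises with the new base rate, which a score monitor would catch. However, it stays below the new observed positive rate, because the probabilities are still anchored to the old prior.

```python
import numpy as np

rng = np.random.default_rng(2)
TRAIN_PRIOR = 0.05

def model_score(x, prior=TRAIN_PRIOR):
    # Exact posterior P(Y=1|x) for N(2,1) vs. N(0,1) class conditionals:
    # log-odds = logit(prior) + log N(x; 2, 1) - log N(x; 0, 1) = logit(prior) + 2x - 2
    log_odds = np.log(prior / (1.0 - prior)) + 2.0 * x - 2.0
    return 1.0 / (1.0 + np.exp(-log_odds))

def sample(prior, n=20000):
    # P(X|Y) is fixed; only P(Y) changes between periods
    y = (rng.random(n) < prior).astype(int)
    x = rng.normal(loc=2.0 * y, scale=1.0)
    return x, y

x_old, y_old = sample(0.05)
x_new, y_new = sample(0.15)

mean_score_old = float(model_score(x_old).mean())
mean_score_new = float(model_score(x_new).mean())
pos_rate_new = float(y_new.mean())

print(f"mean score: old={mean_score_old:.3f}, new={mean_score_new:.3f}")
print(f"observed positive rate after shift: {pos_rate_new:.3f}")
```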

Detection Methods

Statistical Tests for Feature Drift

import numpy as np
from scipy.stats import ks_2samp, chi2_contingency

def detect_numerical_drift(reference_data, production_data, feature_name, threshold=0.05):
    """Two-sample Kolmogorov-Smirnov test for numerical features.

    At production sample sizes even tiny, harmless shifts give p < threshold,
    so also consider an effect-size cutoff on the KS statistic itself.
    """
    stat, p_value = ks_2samp(reference_data, production_data)
    is_drifted = p_value < threshold
    return {
        'feature': feature_name,
        'ks_statistic': stat,
        'p_value': p_value,
        'drifted': is_drifted
    }

def detect_categorical_drift(reference_counts, production_counts, feature_name, threshold=0.05):
    """Chi-squared test for categorical features (inputs are {category: count} dicts)."""
    # Align categories; drop any with zero count in both samples, because a zero
    # expected frequency makes chi2_contingency raise a ValueError
    all_cats = [c for c in set(reference_counts) | set(production_counts)
                if reference_counts.get(c, 0) + production_counts.get(c, 0) > 0]
    ref_vals = [reference_counts.get(c, 0) for c in all_cats]
    prod_vals = [production_counts.get(c, 0) for c in all_cats]

    chi2, p_value, _, _ = chi2_contingency([ref_vals, prod_vals])
    return {
        'feature': feature_name,
        'chi2': chi2,
        'p_value': p_value,
        'drifted': p_value < threshold
    }

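def population_stability_index(reference, production, bins=10):
    """PSI measures how much a distribution has shifted.
    Common rule of thumb:
    PSI < 0.1: no significant change
    0.1 <= PSI < 0.2: moderate change
    PSI >= 0.2: significant change, investigate
    """
    reference = np.asarray(reference, dtype=float)
    production = np.asarray(production, dtype=float)
    ref_hist, bin_edges = np.histogram(reference, bins=bins)
    # Clip production values into the reference range; otherwise np.histogram
    # silently drops out-of-range values and PSI underestimates the shift
    production = np.clip(production, bin_edges[0], bin_edges[-1])
    prod_hist, _ = np.histogram(production, bins=bin_edges)

    # Add small epsilon to avoid log(0) for empty bins
    eps = 1e-10
    ref_pct = (ref_hist + eps) / len(reference)
    prod_pct = (prod_hist + eps) / len(production)

    psi = np.sum((prod_pct - ref_pct) * np.log(prod_pct / ref_pct))
    return psi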

Prediction Distribution Monitoring

Monitor the distribution of model output scores over time without needing labels:

from collections import deque
import numpy as np

class PredictionDriftMonitor:
    def __init__(self, reference_predictions, window_size=1000, alert_threshold=0.1):
        self.reference = np.array(reference_predictions)
        self.window = deque(maxlen=window_size)
        self.alert_threshold = alert_threshold

    def add_prediction(self, score):
        self.window.append(score)

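    def check_drift(self):
        """Compare the current window of scores to the reference scores via PSI."""
        if len(self.window) < self.window.maxlen:
            return None  # not enough production scores yet
        production = np.array(self.window)
        # population_stability_index is defined in the statistical-tests code above
        psi = population_stability_index(self.reference, production)
        return {
            'psi': psi,
            'drifted': psi >= self.alert_threshold,
            'mean_shift': production.mean() - self.reference.mean(),
            'window_size': len(self.window)
        }

    def get_score_bucket_distribution(self):
        """Track how prediction buckets shift over time."""
        buckets = {'<0.2': 0, '0.2-0.4': 0, '0.4-0.6': 0, '0.6-0.8': 0, '>=0.8': 0}
        for score in self.window:
            if score < 0.2:
                buckets['<0.2'] += 1
            elif score < 0.4:
                buckets['0.2-0.4'] += 1
            elif score < 0.6:
                buckets['0.4-0.6'] += 1
            elif score < 0.8:
                buckets['0.6-0.8'] += 1
            else:
                buckets['>=0.8'] += 1
        total = len(self.window)
        if total == 0:
            return {k: 0.0 for k in buckets}
        return {k: v / total for k, v in buckets.items()}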

Performance-Based Monitoring (When Labels Are Available)

For tasks where ground truth arrives with a delay (e.g., credit default after 30 days):

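from datetime import datetime, timedelta
from sklearn.metrics import roc_auc_score

class DelayedLabelMonitor:
    def __init__(self, window_days=7, baseline_auc=0.94, min_labels=100):
        self.predictions = []  # dicts: score, timestamp, label, label_received_at
        self.window_days = window_days
        self.baseline_auc = baseline_auc
        self.min_labels = min_labels  # assumed minimum for a stable AUC estimate
        self.alert_threshold = 0.05  # AUC drop of 0.05 (5 points) triggers alert

    def log_prediction(self, score, timestamp):
        self.predictions.append({'score': score, 'timestamp': timestamp,
                                 'label': None, 'label_received_at': None})
        return len(self.predictions) - 1  # index to pass to update_label later

    def update_label(self, prediction_idx, label, received_at=None):
        record = self.predictions[prediction_idx]
        record['label'] = label
        record['label_received_at'] = received_at or datetime.now()

    def compute_recent_auc(self, days=None):
        # Window on when labels arrived, not when predictions were made: with a
        # 30-day label delay, predictions from the last 7 days have no labels yet
        days = days if days is not None else self.window_days
        cutoff = datetime.now() - timedelta(days=days)
        labeled = [(p['score'], p['label'])
                   for p in self.predictions
                   if p['label'] is not None and p['label_received_at'] > cutoff]

        if len(labeled) < self.min_labels:
            return None  # too few labels for a reliable AUC
        scores, labels = zip(*labeled)
        if len(set(labels)) < 2:
            return None  # AUC is undefined when only one class is present
        auc = roc_auc_score(labels, scores)
        return {
            'auc': auc,
            'baseline_auc': self.baseline_auc,
            'drifted': self.baseline_auc - auc > self.alert_threshold
        }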

Monitoring Architecture

A production drift detection system has three layers:

Layer 1 — Real-time stream monitoring:

  • Log every (input features, prediction score, timestamp) to Kafka
  • Flink/Spark Streaming computes rolling statistics every N minutes
  • Alert on prediction distribution shift (PSI, KS test) without labels
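One way to make Layer 1 concrete is to fix the shape of the logged record. The sketch below is an assumption, not a prescribed schema. It defines a hypothetical `PredictionEvent` dataclass and serializes it to JSON bytes, which would become the message value sent by whatever Kafka client you use. The `prediction_id` field (also hypothetical) lets delayed labels be joined back to the exact prediction later.

```python
import json
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone

@dataclass
class PredictionEvent:
    features: dict                 # raw model inputs, e.g. {"amount": 12.5}
    score: float                   # model output
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat())
    # Hypothetical join key for attaching delayed ground-truth labels later
    prediction_id: str = field(default_factory=lambda: uuid.uuid4().hex)

    def to_json(self) -> bytes:
        # Serialized payload to use as the Kafka message value
        return json.dumps(asdict(self)).encode("utf-8")

    @classmethod
    def from_json(cls, payload: bytes) -> "PredictionEvent":
        # Inverse of to_json, used by the stream job that computes statistics
        return cls(**json.loads(payload))

event = PredictionEvent(features={"amount": 12.5, "device": "mobile"}, score=0.87)
print(event.to_json())
```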

Layer 2 — Daily batch analysis:

  • Compare feature distributions: today’s batch vs. reference window (training data or last 30 days)
  • Run per-feature drift tests; rank features by PSI to identify which features are shifting
  • Compute business metrics (click-through rate, conversion) as proxy labels
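Ranking features by PSI can be sketched in a few lines. This version makes one change from the equal-width PSI function earlier: it bins on reference quantiles. Values below or above the reference range then land in the edge bins rather than being dropped. The function names and the dict-of-arrays input format are illustrative assumptions.

```python
import numpy as np

def psi(reference, production, bins=10, eps=1e-6):
    """PSI with quantile bins taken from the reference sample."""
    reference = np.asarray(reference, dtype=float)
    production = np.asarray(production, dtype=float)
    # Interior bin edges at reference quantiles; np.unique drops duplicate
    # edges that appear for low-cardinality features
    inner = np.unique(np.quantile(reference, np.linspace(0, 1, bins + 1)[1:-1]))
    n_bins = len(inner) + 1
    # np.digitize maps values below the first edge to bin 0 and values above
    # the last edge to the final bin, so out-of-range values are still counted
    ref_counts = np.bincount(np.digitize(reference, inner), minlength=n_bins)
    prod_counts = np.bincount(np.digitize(production, inner), minlength=n_bins)
    ref_pct = ref_counts / len(reference) + eps
    prod_pct = prod_counts / len(production) + eps
    return float(np.sum((prod_pct - ref_pct) * np.log(prod_pct / ref_pct)))

def rank_features_by_psi(reference, production):
    """reference/production: dicts mapping feature name -> 1-D array of values."""
    scores = {name: psi(reference[name], production[name]) for name in reference}
    # Most-shifted features first
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)
```

The top of the ranking is where investigation starts. Because quantile bins hold roughly equal reference mass, no bin is nearly empty by construction. This keeps the log term stable.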

Layer 3 — Delayed label evaluation:

  • When ground truth labels arrive, compute offline metrics on the recent prediction window
  • Compare to baseline performance; trigger retraining if degradation exceeds threshold

Retraining Triggers

Trigger                 | Type          | When to use
------------------------|---------------|------------------------------------------------------------------
Scheduled retraining    | Time-based    | Stable domains with predictable drift (e.g., monthly seasonality)
PSI threshold exceeded  | Data drift    | A monitored feature's PSI exceeds 0.2
Performance threshold   | Concept drift | AUC on the labeled window falls more than 0.05 below baseline
Business metric drop    | Proxy signal  | CTR or conversion falls before labels are available
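The table can be read as a set of independent checks, any of which can request a retrain. The minimal sketch below uses hypothetical names. Its thresholds come from the table where it gives them (PSI 0.2, AUC drop 0.05) and from the earlier monitor for the 0.94 baseline. The 30-day schedule and the 10% business-metric drop are invented for illustration.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List, Optional

@dataclass
class DriftSignals:
    last_trained: datetime
    max_feature_psi: float        # largest per-feature PSI from the daily batch job
    recent_auc: Optional[float]   # None until enough delayed labels have arrived
    business_metric_drop: float   # relative CTR/conversion drop vs. baseline (0.1 = 10%)

def retraining_reasons(signals: DriftSignals, now: datetime,
                       baseline_auc: float = 0.94,
                       schedule: timedelta = timedelta(days=30),
                       psi_threshold: float = 0.2,
                       auc_drop: float = 0.05,
                       business_drop: float = 0.10) -> List[str]:
    """Return the triggers that fired; an empty list means no retrain is needed."""
    reasons = []
    if now - signals.last_trained >= schedule:
        reasons.append("scheduled")
    if signals.max_feature_psi > psi_threshold:
        reasons.append("data_drift")
    if signals.recent_auc is not None and baseline_auc - signals.recent_auc > auc_drop:
        reasons.append("concept_drift")
    if signals.business_metric_drop > business_drop:
        reasons.append("business_metric")
    return reasons
```

The function returns the list of reasons rather than a single boolean, so the alert says which signal fired. That matters when deciding whether to retrain on fresh data or first investigate an upstream pipeline.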

Practical Pitfalls

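Multiple testing problem: With 100 features tested at α=0.05, about 5 will be flagged as drifted by chance on average, even if nothing has changed. Use a Bonferroni correction (test each feature at α/100) or control the false discovery rate (FDR) with Benjamini-Hochberg.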
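Both corrections fit in a few lines of NumPy; the function names are illustrative. Bonferroni flags a feature when p < α/m. The Benjamini-Hochberg step-up procedure sorts the p-values, finds the largest rank i with p₍ᵢ₎ ≤ i·q/m, and flags every feature up to that rank. This keeps the expected share of false alarms among flagged features at or below q, assuming the tests are independent or positively dependent.

```python
import numpy as np

def bonferroni(p_values, alpha=0.05):
    """Flag features whose p-value is below alpha / m."""
    p = np.asarray(p_values, dtype=float)
    return p < alpha / p.size

def benjamini_hochberg(p_values, fdr=0.05):
    """Return a boolean mask of features flagged as drifted at the given FDR."""
    p = np.asarray(p_values, dtype=float)
    m = p.size
    order = np.argsort(p)                        # indices that sort p ascending
    thresholds = fdr * np.arange(1, m + 1) / m   # i * q / m for ranks i = 1..m
    passing = np.nonzero(p[order] <= thresholds)[0]
    reject = np.zeros(m, dtype=bool)
    if passing.size > 0:
        k = passing.max()                        # largest rank meeting its threshold
        reject[order[: k + 1]] = True            # step-up: flag all ranks up to k
    return reject

p = [0.001, 0.012, 0.025, 0.039, 0.5]
print(bonferroni(p))          # only the smallest p-value survives alpha / 5
print(benjamini_hochberg(p))  # the four smallest p-values are flagged
```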

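Feedback loops: Your model’s predictions influence future data. A recommendation model shapes what users see, which shapes the next training set. Standard drift detection can miss this, because the model itself produces the shift and that shift gradually becomes part of the reference data. Track correlations between upstream features and the model’s past outputs. Where possible, also compare against a small slice of traffic that the model does not influence.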

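Gradual vs. sudden drift: Fixed-window two-sample tests catch sudden drift quickly but have little power against gradual drift. Consecutive windows differ only slightly, and a rolling reference window (such as the last 30 days) absorbs a slow change. CUSUM (Cumulative Sum Control Chart) and ADWIN (Adaptive Windowing) are designed for sequential detection: they update with each new observation and accumulate evidence of change over time.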
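Below is a minimal two-sided tabular CUSUM for a monitored stream, such as the hourly mean prediction score. The sketch assumes the in-control mean and standard deviation are known from a reference period. The slack k and decision threshold h are in standard-deviation units, and the defaults (k = 0.5, h = 5) are a common textbook choice for detecting mean shifts of about one standard deviation. The function name is illustrative.

```python
from typing import Optional, Sequence

def cusum_alarm(values: Sequence[float], target_mean: float, sigma: float,
                k: float = 0.5, h: float = 5.0) -> Optional[int]:
    """Return the index of the first alarm, or None if no shift is detected."""
    s_hi = 0.0  # accumulated evidence of an upward shift
    s_lo = 0.0  # accumulated evidence of a downward shift
    for t, x in enumerate(values):
        z = (x - target_mean) / sigma          # standardize the observation
        s_hi = max(0.0, s_hi + z - k)          # deviations below k are forgiven
        s_lo = max(0.0, s_lo - z - k)
        if s_hi > h or s_lo > h:
            return t
    return None

# A step of one standard deviation after 200 in-control points
print(cusum_alarm([0.0] * 200 + [1.0] * 50, target_mean=0.0, sigma=1.0))
```

The statistic accumulates small excesses over the slack k. A slow ramp therefore eventually crosses h, even when no single window looks unusual on its own.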

Depth Levels

Junior: Explain the types of drift and how you’d monitor prediction score distribution.

Senior: Design a monitoring pipeline from Kafka ingest to alert. Discuss delayed label scenarios. Choose appropriate statistical tests per feature type.

Staff: Address feedback loops, multiple testing correction, and the trade-off between retraining cost and drift tolerance. Design a champion/challenger infrastructure for safe model rollout.

Related ML Topics

See also: MLOps Interview Questions — drift detection is the monitoring component of the MLOps pipeline; drift triggers feed the automated retraining CI/CD loop.
