ML Model Serving System Low-Level Design: Versioned Deployment, A/B Testing, Shadow Mode, and Monitoring

Why Model Serving Is Not Trivial

Deploying a model to production is an engineering discipline separate from training one. A serving system must handle versioned artifacts, route traffic across model variants, validate inputs and outputs, cache predictions efficiently, and detect when a deployed model starts behaving differently from its training baseline — all while keeping latency in the low tens of milliseconds. This post designs those components at the low level.

Model Registry

The registry is the catalog of all trained model artifacts. Each ModelVersion record stores:

  • artifact_path: S3 URI to the serialized model (TorchScript, SavedModel, ONNX, or joblib-serialized sklearn pipeline).
  • framework: pytorch, tensorflow, sklearn, xgboost.
  • input_schema: JSONB describing expected feature names, dtypes, and shapes — used for input validation at serving time.
  • metrics: JSONB with evaluation metrics (AUC, RMSE, precision@k) recorded at training time — the baseline for drift comparison.
  • status: staging | champion | challenger | shadow | retired.

Serving Architecture

The serving stack has three layers:

  1. Model server: TorchServe or TF Serving instances that load model artifacts from S3 on startup and expose a gRPC or HTTP inference endpoint. Each instance serves one model version.
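  2. Routing proxy: A lightweight service (Go or Python FastAPI) that reads the ServingPolicy table, decides which model version(s) receive the request, fans out calls, and returns the result from the version that owns the request (the champion, or the challenger for entities in the challenger bucket). Shadow results are logged but never returned to the caller.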
  3. Feature fetch layer: Before calling the model server, the proxy fetches online features from the feature store using the entity_id, assembles the input tensor, and validates it against the input schema.
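The validation step in the feature fetch layer might look like the sketch below. The source does not say how input_schema is laid out, so the flat mapping from feature name to a dtype string ("float", "int", "str") and the helper name validate_input are assumptions made for illustration only; shape checks are omitted.

```python
# Hypothetical schema layout: {"feature_name": "float" | "int" | "str"}.
_DTYPES = {"float": (int, float), "int": (int,), "str": (str,)}


def validate_input(features: dict, input_schema: dict) -> list[str]:
    """Return a list of validation errors; an empty list means the input is valid."""
    errors = []
    for name, dtype in input_schema.items():
        if name not in features:
            errors.append(f"missing feature: {name}")
            continue
        value = features[name]
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, _DTYPES[dtype]):
            errors.append(f"{name}: expected {dtype}, got {type(value).__name__}")
    # Features the model was not trained on are flagged rather than silently dropped.
    for name in sorted(features.keys() - input_schema.keys()):
        errors.append(f"unexpected feature: {name}")
    return errors
```

Returning every error at once, rather than raising on the first, tends to make rejected requests easier to debug from a single log line.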

Versioned Deployment Strategies

Blue-Green Swap

The current champion is blue. A new version is deployed as green (status=staging). Integration tests run against green. When tests pass, the ServingPolicy record is updated atomically: champion_version = green_id. All new requests go to green. Blue is kept alive for 30 minutes to handle in-flight requests, then retired.
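A minimal sketch of the atomic swap follows. It uses an in-memory SQLite database as a stand-in for the production database so the example runs on its own, and it trims the tables to the columns involved. The function name promote and the sample row values are hypothetical.

```python
import sqlite3

conn = sqlite3.connect(":memory:")  # stand-in for the real policy database
conn.executescript("""
CREATE TABLE ModelVersion (id INTEGER PRIMARY KEY, status TEXT NOT NULL);
CREATE TABLE ServingPolicy (
    model_name TEXT PRIMARY KEY,
    champion_version INTEGER NOT NULL REFERENCES ModelVersion(id)
);
INSERT INTO ModelVersion (id, status) VALUES (1, 'champion'), (2, 'staging');
INSERT INTO ServingPolicy (model_name, champion_version) VALUES ('pricing', 1);
""")


def promote(conn: sqlite3.Connection, model_name: str, green_id: int) -> int:
    """Point the policy at green in one transaction; return the old champion id."""
    with conn:  # commits on success, rolls back if any statement raises
        row = conn.execute(
            "SELECT champion_version FROM ServingPolicy WHERE model_name = ?",
            (model_name,),
        ).fetchone()
        if row is None:
            raise KeyError(f"no serving policy for {model_name}")
        conn.execute(
            "UPDATE ServingPolicy SET champion_version = ? WHERE model_name = ?",
            (green_id, model_name),
        )
        conn.execute(
            "UPDATE ModelVersion SET status = 'champion' WHERE id = ?", (green_id,)
        )
    # Blue's status is left untouched here; the caller retires it after the drain window.
    return row[0]
```

Returning the previous champion id gives the caller what it needs to retire blue later, or to call promote again with that id if green has to be backed out.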

Canary Deployment

Set challenger_version = new_id and challenger_weight = 5 (5% of traffic). Monitor latency, error rate, and prediction distribution. If they hold steady, ramp challenger_weight to 50, then 100. If metrics degrade at any step, set challenger_weight back to 0 to roll back.
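The ramp logic can be sketched as a small pure function. The step values come from the text above; how the healthy flag is computed (which metrics, which thresholds) is not specified in the source and is left to the caller. The names RAMP_STEPS and next_challenger_weight are hypothetical.

```python
RAMP_STEPS = [5, 50, 100]  # percent of traffic at each canary stage


def next_challenger_weight(current: int, healthy: bool) -> int:
    """Advance one ramp step while metrics are healthy; drop to 0 otherwise."""
    if not healthy:
        return 0  # send all traffic back to the champion
    for step in RAMP_STEPS:
        if step > current:
            return step
    return current  # already at full traffic
```

A scheduler could call this once per evaluation window and write the result to ServingPolicy.challenger_weight.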

Rollback

Keep the previous champion record. On metric degradation — latency p99 spike, error rate increase, or prediction drift alert — the routing proxy flips champion_version back to the previous ID. Rollback is a database write; no artifact redeployment needed.

A/B Testing

Champion-challenger routing supports controlled experiments. The routing proxy hashes entity_id + model_name to assign users deterministically to a bucket 0-99. Users in buckets below challenger_weight receive predictions from the challenger; others receive champion predictions. Both paths write to PredictionLog with their model_version_id. Outcome events (conversions, ratings, downstream labels) are joined on entity_id to measure model impact.
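The bucketing rule can be shown in isolation. This sketch uses MD5 purely as a fast, stable hash (not for security), matching the Python interface later in this post; the function names assign_bucket and serves_challenger are hypothetical.

```python
import hashlib


def assign_bucket(model_name: str, entity_id: str) -> int:
    """Map an entity to a stable bucket in 0..99 for a given model."""
    digest = hashlib.md5(f"{model_name}:{entity_id}".encode()).hexdigest()
    return int(digest, 16) % 100


def serves_challenger(model_name: str, entity_id: str, challenger_weight: int) -> bool:
    """True when the entity falls in the challenger's share of buckets."""
    return assign_bucket(model_name, entity_id) < challenger_weight
```

Including model_name in the hash means the same user can land in different buckets for different models, so experiments on separate models do not share the same treatment group.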

Shadow Mode

Shadow mode evaluates a new model against real traffic without affecting users:

  1. Routing proxy receives a request.
  2. Sends request to champion synchronously — awaits response.
  3. Sends the same request to the shadow model asynchronously (fire and forget, with a short timeout).
  4. Returns champion prediction to caller.
  5. Logs both champion and shadow predictions to PredictionLog with respective model_version_ids.

Offline analysis compares champion vs. shadow prediction distributions, latency distributions, and error rates before any live traffic is shifted.
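One way the offline comparison could be sketched, assuming scalar predictions and assuming champion and shadow log rows can be paired by a shared request key such as input_hash. The function name compare_shadow and the 0.05 tolerance are illustrative choices, not values from the source.

```python
from statistics import mean


def compare_shadow(
    champion: dict[str, float], shadow: dict[str, float], tolerance: float = 0.05
) -> dict:
    """Compare predictions paired by request key; each input maps key -> score."""
    shared = champion.keys() & shadow.keys()
    if not shared:
        return {"paired": 0, "mean_abs_diff": None, "disagreement_rate": None}
    diffs = [abs(champion[key] - shadow[key]) for key in shared]
    return {
        "paired": len(diffs),
        "mean_abs_diff": mean(diffs),
        # Share of paired requests where the two models differ by more than tolerance.
        "disagreement_rate": sum(d > tolerance for d in diffs) / len(diffs),
    }
```

Requests seen by only one of the two models (for example, because the shadow call timed out) are excluded from the pairing, so the paired count is worth tracking alongside the other two numbers.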

Inference Caching

For deterministic models with frequently repeated inputs (e.g., pricing a fixed catalog of items), a Redis cache lets repeat requests skip the model server entirely:

  • Cache key: infer:{model_version_id}:{sha256(input_json)}
  • TTL: set based on model update frequency and feature freshness requirements.
  • On model version change, the new model_version_id in the key automatically bypasses stale cache entries without explicit invalidation.
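The key scheme above can be sketched without Redis. Serializing with sort_keys=True is what makes two dicts with the same contents hash identically regardless of key order; the function name cache_key is hypothetical.

```python
import hashlib
import json


def cache_key(model_version_id: int, input_data: dict) -> str:
    """Build the cache key infer:{model_version_id}:{sha256(input_json)}."""
    canonical = json.dumps(input_data, sort_keys=True)  # stable key order
    input_hash = hashlib.sha256(canonical.encode()).hexdigest()
    return f"infer:{model_version_id}:{input_hash}"
```

For example, cache_key(7, {"sku": "A1", "qty": 2}) and cache_key(7, {"qty": 2, "sku": "A1"}) produce the same key, while the same input under version 8 produces a different one, so entries written for the old version are never read again and simply expire via TTL.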

SQL Schema

CREATE TABLE ModelVersion (
    id              BIGSERIAL PRIMARY KEY,
    model_name      VARCHAR(255) NOT NULL,
    version         VARCHAR(100) NOT NULL,
    artifact_path   TEXT         NOT NULL,  -- s3://bucket/path/model.pt
    framework       VARCHAR(50)  NOT NULL,
    input_schema    JSONB        NOT NULL,
    metrics         JSONB,
    status          VARCHAR(50)  NOT NULL DEFAULT 'staging',
    deployed_at     TIMESTAMPTZ,
    created_at      TIMESTAMPTZ  NOT NULL DEFAULT NOW(),
    UNIQUE (model_name, version)
);

CREATE TABLE ServingPolicy (
    model_name          VARCHAR(255) PRIMARY KEY,
    champion_version    BIGINT NOT NULL REFERENCES ModelVersion(id),
    challenger_version  BIGINT REFERENCES ModelVersion(id),
    challenger_weight   INT    NOT NULL DEFAULT 0,  -- percent 0-100
    shadow_version      BIGINT REFERENCES ModelVersion(id)
);

CREATE TABLE PredictionLog (
    id               BIGSERIAL PRIMARY KEY,
    model_version_id BIGINT       NOT NULL REFERENCES ModelVersion(id),
    entity_id        VARCHAR(255) NOT NULL,
    input_hash       CHAR(64)     NOT NULL,  -- sha256 of input JSON
    prediction       JSONB        NOT NULL,
    latency_ms       INT          NOT NULL,
    predicted_at     TIMESTAMPTZ  NOT NULL DEFAULT NOW()
);
CREATE INDEX ON PredictionLog (model_version_id, predicted_at DESC);

CREATE TABLE PredictionMetric (
    model_version_id BIGINT      NOT NULL REFERENCES ModelVersion(id),
    metric_name      VARCHAR(100) NOT NULL,
    value            FLOAT        NOT NULL,
    computed_at      TIMESTAMPTZ  NOT NULL DEFAULT NOW(),
    PRIMARY KEY (model_version_id, metric_name, computed_at)
);

Python Interface

import hashlib, json, time, threading
import redis
import requests

r = redis.Redis(host="redis-serving", port=6379, decode_responses=True)

def predict(model_name: str, entity_id: str, input_data: dict) -> dict:
    """Main entry point: route request, check cache, call model server."""
    version_id, endpoint = route_request(model_name, entity_id)
    # Canonical JSON (sorted keys) so equal inputs always hash to the same key.
    input_hash = hashlib.sha256(json.dumps(input_data, sort_keys=True).encode()).hexdigest()
    cache_key = f"infer:{version_id}:{input_hash}"
    cached = r.get(cache_key)
    if cached is not None:
        return json.loads(cached)  # note: cache hits are not written to PredictionLog
    t0 = time.perf_counter()
    resp = requests.post(endpoint + "/predictions", json=input_data, timeout=0.5)
    resp.raise_for_status()
    prediction = resp.json()
    latency_ms = int((time.perf_counter() - t0) * 1000)
    r.set(cache_key, json.dumps(prediction), ex=300)  # 5-minute TTL
    log_prediction(version_id, entity_id, input_hash, prediction, latency_ms)
    return prediction

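def route_request(model_name: str, entity_id: str) -> tuple[int, str]:
    """Determine which model version serves this entity_id."""
    policy = _fetch_policy(model_name)
    if policy.get("challenger_version") and policy.get("challenger_weight", 0) > 0:
        # Stable bucket in 0..99; buckets below challenger_weight go to the challenger.
        bucket = int(hashlib.md5(f"{model_name}:{entity_id}".encode()).hexdigest(), 16) % 100
        if bucket < policy["challenger_weight"]:
            return policy["challenger_version"], _endpoint(policy["challenger_version"])
    return policy["champion_version"], _endpoint(policy["champion_version"])

def log_prediction(model_version_id: int, entity_id: str, input_hash: str, prediction: dict, latency_ms: int) -> None:
    # INSERT INTO PredictionLog (model_version_id, entity_id, input_hash, prediction, latency_ms)
    pass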

def detect_prediction_drift(model_version_id: int, recent_predictions: list[float], baseline_mean: float, baseline_std: float) -> bool:
    """Flag drift if recent prediction mean deviates more than 3 sigma from training baseline."""
    if not recent_predictions:
        return False
    import statistics
    sample_mean = statistics.mean(recent_predictions)
    if baseline_std == 0:
        return sample_mean != baseline_mean
    return abs(sample_mean - baseline_mean) / baseline_std > 3.0

def shadow_score(model_name: str, entity_id: str, input_data: dict) -> None:
    """Fire-and-forget shadow scoring; called from routing proxy."""
    policy = _fetch_policy(model_name)
    if not policy.get("shadow_version"):
        return
    # Same hash as predict(), so shadow rows can be paired with champion rows offline.
    input_hash = hashlib.sha256(json.dumps(input_data, sort_keys=True).encode()).hexdigest()
    def _score():
        try:
            endpoint = _endpoint(policy["shadow_version"])
            t0 = time.time()
            resp = requests.post(endpoint + "/predictions", json=input_data, timeout=0.5)
            resp.raise_for_status()  # do not log error bodies as predictions
            prediction = resp.json()
            latency_ms = int((time.time() - t0) * 1000)
            log_prediction(policy["shadow_version"], entity_id, input_hash, prediction, latency_ms)
        except Exception:
            pass  # shadow failures must never affect the caller
    threading.Thread(target=_score, daemon=True).start()

Prediction Monitoring Pipeline

A scheduled job reads recent rows from PredictionLog for each active model version, computes the prediction distribution (mean, p50, p95, positive rate for classifiers), writes to PredictionMetric, and compares against the baseline stored in ModelVersion.metrics. Alerts route to PagerDuty on threshold breach. Dashboards show prediction distribution over time alongside upstream feature drift signals from the feature store monitoring system.
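The summary step of that job might look like the following sketch, assuming scalar scores and a 0.5 decision threshold for the positive rate (both assumptions; the source does not fix them). The function name summarize_predictions is hypothetical. Each key in the returned dict would become one PredictionMetric row.

```python
import statistics


def summarize_predictions(scores: list[float], threshold: float = 0.5) -> dict[str, float]:
    """Compute distribution metrics for one model version's recent predictions."""
    if len(scores) < 2:
        raise ValueError("need at least two predictions to compute percentiles")
    # quantiles(n=100) returns 99 cut points; index 94 is the 95th percentile.
    percentiles = statistics.quantiles(scores, n=100, method="inclusive")
    return {
        "mean": statistics.fmean(scores),
        "p50": statistics.median(scores),
        "p95": percentiles[94],
        "positive_rate": sum(s >= threshold for s in scores) / len(scores),
    }
```

The mean from this summary is the kind of value that can be fed to detect_prediction_drift along with the baseline statistics recorded at training time.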

Key Design Decisions

  • Routing policy is stored in the database, not in config files, so changes take effect without redeploying the proxy.
  • Shadow scoring uses fire-and-forget threads to avoid adding latency to the critical path. Shadow timeouts are strictly bounded.
  • Cache keys include model_version_id so version swaps automatically bypass stale cached predictions without explicit cache invalidation logic.
  • PredictionLog is write-heavy; partition by predicted_at monthly and archive to S3 after 30 days to control storage costs.
