Low Level Design: ML Training Pipeline

A machine learning training pipeline is the infrastructure that moves raw data through feature engineering, distributed model training, experiment tracking, and model promotion into a registry. Getting this right is the difference between a team that iterates quickly and one that spends most of its time debugging non-reproducible experiments.

Requirements

Functional

  • Version and snapshot datasets so any past experiment can be reproduced exactly.
  • Distribute training across multiple GPUs / nodes to handle large models and datasets.
  • Track every hyperparameter, metric, artifact, and code commit for each run.
  • Promote validated models to a registry with semantic versioning and approval gates.

Non-Functional

  • Training throughput: saturate 8 × A100 GPUs with > 90% GPU utilization.
  • Experiment metadata write latency: < 100 ms per metric log call.
  • Registry availability: 99.9% (model artifacts must always be fetchable by serving).
  • Reproducibility: identical code + dataset version + seed must produce bit-identical weights. This also requires deterministic kernels and an unchanged hardware and worker configuration.

Data Model

-- Dataset versions (pointer to immutable snapshot)
CREATE TABLE dataset_version (
    id              BIGSERIAL   PRIMARY KEY,
    dataset_name    VARCHAR(128) NOT NULL,
    version         VARCHAR(32)  NOT NULL,        -- semver or hash
    storage_uri     TEXT         NOT NULL,         -- s3://bucket/datasets/v1.2.3/
    row_count       BIGINT,
    size_bytes      BIGINT,
    schema_hash     CHAR(64)     NOT NULL,         -- SHA-256 of column schema
    split_config    JSONB        NOT NULL DEFAULT '{}',  -- train/val/test ratios
    created_at      TIMESTAMPTZ  NOT NULL DEFAULT NOW(),
    created_by      VARCHAR(128) NOT NULL,
    UNIQUE (dataset_name, version)
);

-- Experiment run
CREATE TABLE experiment_run (
    id              BIGSERIAL   PRIMARY KEY,
    experiment_name VARCHAR(128) NOT NULL,
    run_uuid        UUID         NOT NULL DEFAULT gen_random_uuid(),
    status          VARCHAR(16)  NOT NULL CHECK (status IN ('queued','running','completed','failed','killed')),
    dataset_id      BIGINT       NOT NULL REFERENCES dataset_version(id),
    code_commit     CHAR(40)     NOT NULL,         -- git SHA
    hyperparams     JSONB        NOT NULL DEFAULT '{}',
    resource_spec   JSONB        NOT NULL DEFAULT '{}',  -- {gpu_count, gpu_type, memory_gb}
    started_at      TIMESTAMPTZ,
    finished_at     TIMESTAMPTZ,
    artifact_uri    TEXT,                          -- s3://bucket/runs/{run_uuid}/
    created_at      TIMESTAMPTZ  NOT NULL DEFAULT NOW(),
    created_by      VARCHAR(128) NOT NULL
);
CREATE INDEX idx_run_experiment ON experiment_run(experiment_name, status);

-- Metrics logged during training (time-series per run)
CREATE TABLE run_metric (
    id              BIGSERIAL   PRIMARY KEY,
    run_id          BIGINT       NOT NULL REFERENCES experiment_run(id),
    metric_name     VARCHAR(128) NOT NULL,
    step            INT          NOT NULL,
    value           DOUBLE PRECISION NOT NULL,
    logged_at       TIMESTAMPTZ  NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_metric_run_step ON run_metric(run_id, metric_name, step);

-- Model registry
CREATE TABLE model_version (
    id              BIGSERIAL   PRIMARY KEY,
    model_name      VARCHAR(128) NOT NULL,
    version         VARCHAR(32)  NOT NULL,
    stage           VARCHAR(16)  NOT NULL CHECK (stage IN ('staging','champion','archived')),
    run_id          BIGINT       NOT NULL REFERENCES experiment_run(id),
    artifact_uri    TEXT         NOT NULL,
    framework       VARCHAR(32)  NOT NULL,         -- pytorch, tensorflow, sklearn
    input_schema    JSONB        NOT NULL DEFAULT '{}',
    output_schema   JSONB        NOT NULL DEFAULT '{}',
    promoted_at     TIMESTAMPTZ,
    promoted_by     VARCHAR(128),
    notes           TEXT,
    UNIQUE (model_name, version)
);
CREATE INDEX idx_model_stage ON model_version(model_name, stage);

Core Workflow

User submits TrainingJob spec (YAML)
    |
    v
[Job Scheduler]
  - Validates dataset_version exists and is healthy
  - Resolves resource_spec → available cluster nodes
  - Creates experiment_run row (status='queued')
  - Enqueues job to work queue (Redis / SQS)
    |
    v
[Worker Pool — one coordinator + N workers]
  - Coordinator pulls job, sets status='running'
  - Initializes distributed process group (NCCL / Gloo)
  - Each worker downloads its data shard from object storage
    |
    v
[Training Loop — each worker]
  step = 0
  for epoch in range(config.epochs):
      for batch in dataloader:            # prefetched, pinned memory
          optimizer.zero_grad()
          loss = model(batch)             # forward pass returns the loss
          loss.backward()                 # gradients computed locally
          all_reduce(gradients, op=SUM)   # NCCL ring all-reduce
          gradients /= world_size         # sum → mean across workers
          optimizer.step()
          step += 1
          if step % log_interval == 0:
              log_metric('train_loss', loss.item(), step)  # async write to metrics DB
      validate()
      checkpoint_if_best()
    |
    v
[Post-training]
  - Coordinator uploads final checkpoint to artifact_uri
  - Sets experiment_run.status = 'completed'
  - Emits event → Model Registry service
    |
    v
[Model Registry]
  - Creates model_version row (stage='staging')
  - Triggers automated evaluation suite (accuracy, latency, bias checks)
  - Human reviewer approves → stage transitions to 'champion'
  - Previous champion transitions to 'archived'
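
The registry transition in the last two steps can be sketched in memory as follows. This is a minimal illustration, not the registry service itself: the ModelVersion dataclass mirrors a few model_version columns, and promote_to_champion, the model name "ranker", and the approver address are hypothetical names introduced for the example.

```python
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import List, Optional


@dataclass
class ModelVersion:
    model_name: str
    version: str
    stage: str  # 'staging', 'champion', or 'archived'
    promoted_at: Optional[datetime] = None
    promoted_by: Optional[str] = None


def promote_to_champion(rows: List[ModelVersion], model_name: str,
                        version: str, approver: str) -> ModelVersion:
    """Move one staging version to 'champion' and archive the old champion."""
    candidate = next(
        (r for r in rows if r.model_name == model_name and r.version == version),
        None,
    )
    if candidate is None:
        raise LookupError(f"{model_name} {version} is not registered")
    if candidate.stage != "staging":
        raise ValueError(f"only staging versions can be promoted, got {candidate.stage!r}")
    # Archive the current champion (if any) before promoting the candidate.
    for r in rows:
        if r.model_name == model_name and r.stage == "champion":
            r.stage = "archived"
    candidate.stage = "champion"
    candidate.promoted_at = datetime.now(timezone.utc)
    candidate.promoted_by = approver
    return candidate


# Hypothetical registry contents for illustration.
registry = [
    ModelVersion("ranker", "1.0.0", "champion"),
    ModelVersion("ranker", "1.1.0", "staging"),
]
promote_to_champion(registry, "ranker", "1.1.0", approver="reviewer@example.com")
stages = {r.version: r.stage for r in registry}
```

Against the real model_version table, the two stage updates would presumably run inside a single transaction so that readers never observe zero or two champions for a model.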

Dataset Versioning

Every dataset is stored as an immutable snapshot in object storage. The dataset_version table records a schema_hash (SHA-256 of column names and types) so that pipeline code can assert schema compatibility at job submission time. Splits (train/val/test) are stored as index files rather than copying data, keeping storage overhead negligible. A dataset is never deleted; it is only marked deprecated in the registry.
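
One way the schema_hash check could work is sketched below. The canonical form (column name and type pairs, sorted by name, serialized as JSON) is an assumption, since the text only says the hash covers column names and types; schema_hash and assert_schema_compatible are hypothetical helper names.

```python
import hashlib
import json
from typing import Dict


def schema_hash(columns: Dict[str, str]) -> str:
    """SHA-256 over column names and types, independent of column order."""
    # Sort by column name so the same schema always serializes identically.
    canonical = json.dumps(sorted(columns.items()), separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


def assert_schema_compatible(expected_hash: str, columns: Dict[str, str]) -> None:
    """Raise at job submission time if the dataset schema has drifted."""
    actual = schema_hash(columns)
    if actual != expected_hash:
        raise ValueError(f"schema mismatch: expected {expected_hash}, got {actual}")


registered = schema_hash({"user_id": "int64", "clicked": "bool", "price": "float64"})
# Same columns in a different order hash to the same value.
reordered = schema_hash({"price": "float64", "user_id": "int64", "clicked": "bool"})
# A changed column type produces a different hash.
drifted = schema_hash({"user_id": "int64", "clicked": "bool", "price": "float32"})
```

Treating column order as irrelevant is a design choice; a pipeline that reads columns positionally would want to hash them in their stored order instead.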

Distributed Training Architecture

Data Parallelism (DDP)

Each GPU holds a full copy of the model. The global batch is sharded across GPUs. After each backward pass, an all-reduce operation averages gradients across all workers. This is the dominant strategy for models that fit in a single GPU’s memory.
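
The reason averaging works can be shown without any GPU: for equally sized shards, the mean of the per-worker gradients equals the gradient a single device would compute on the whole global batch. The sketch below uses a hypothetical one-parameter model (y_hat = w * x with mean squared error) and a plain-Python stand-in for all-reduce; it is not NCCL or PyTorch DDP.

```python
import math
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y)


def local_gradient(w: float, shard: List[Sample]) -> float:
    """d/dw of mean squared error for the model y_hat = w * x on one shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values: List[float]) -> float:
    """Stand-in for all-reduce: sum every worker's value, divide by world size."""
    return sum(values) / len(values)


w = 0.5
global_batch = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
world_size = 2
shard_size = len(global_batch) // world_size
shards = [global_batch[i * shard_size:(i + 1) * shard_size] for i in range(world_size)]

per_worker = [local_gradient(w, s) for s in shards]   # computed independently
synced = all_reduce_mean(per_worker)                  # what every worker applies
single_device = local_gradient(w, global_batch)       # reference result
```

The equality only holds when every shard has the same number of samples, which is one reason uneven shards need special handling.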

Model Parallelism (Pipeline / Tensor)

For models too large for one GPU (e.g., large language models), the model itself is split across devices. Pipeline parallelism assigns contiguous layer groups to each GPU and uses micro-batching to keep the pipeline filled. Tensor parallelism splits individual weight matrices across GPUs; it adds communication inside every layer, so it is usually confined to GPUs linked by a fast interconnect and often relies on specialized kernels for performance.

Gradient Checkpointing

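Recomputes activations during the backward pass instead of storing them, trading compute time for GPU memory. Because it shrinks activation memory only (weights, gradients, and optimizer state must still fit), it lets a given GPU train models or batch sizes roughly 3–8x larger than would otherwise fit in VRAM, at a ~30% throughput cost.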

Experiment Tracking

Every metric log call is a non-blocking async write: the training process pushes to an in-process queue; a background thread batches inserts into run_metric every 500 ms. This keeps the hot path unaffected by database latency. Hyperparameters are serialized to JSONB in experiment_run.hyperparams at job start; they are never mutated after the run begins.
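
A minimal sketch of that logging path, using only the standard library, is shown below. AsyncMetricLogger and its sink callback are hypothetical names; in a real deployment the sink would presumably issue a batched INSERT into run_metric, whereas here it simply appends each batch to a list.

```python
import queue
import threading
from typing import Callable, List, Tuple

Metric = Tuple[str, float, int]  # (metric_name, value, step)


class AsyncMetricLogger:
    """Buffers metrics in memory and flushes them in batches on a background thread."""

    def __init__(self, sink: Callable[[List[Metric]], None],
                 flush_interval_s: float = 0.5):
        self._sink = sink                      # e.g. a batched INSERT into run_metric
        self._interval = flush_interval_s
        self._queue: "queue.Queue[Metric]" = queue.Queue()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def log_metric(self, name: str, value: float, step: int) -> None:
        # Hot path: an in-memory enqueue only, no I/O.
        self._queue.put_nowait((name, value, step))

    def _drain(self) -> None:
        batch: List[Metric] = []
        while True:
            try:
                batch.append(self._queue.get_nowait())
            except queue.Empty:
                break
        if batch:
            self._sink(batch)

    def _run(self) -> None:
        # Event.wait returns True once close() is called, which ends the loop.
        while not self._stop.wait(self._interval):
            self._drain()

    def close(self) -> None:
        """Stop the background thread, then flush anything still queued."""
        self._stop.set()
        self._thread.join()
        self._drain()


batches: List[List[Metric]] = []
logger = AsyncMetricLogger(sink=batches.append, flush_interval_s=0.05)
for step in range(1, 6):
    logger.log_metric("train_loss", 1.0 / step, step)
logger.close()
written = [m for batch in batches for m in batch]
```

The sketch does not handle a failing sink or an unbounded backlog; a production version would likely bound the queue and retry or drop batches when the database is unavailable.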

Artifacts (checkpoints, plots, ONNX exports) are uploaded to object storage under a deterministic path: s3://bucket/runs/{run_uuid}/artifacts/. The experiment tracking UI resolves artifact paths via the database; no artifact paths are hard-coded in application code.

Key Design Decisions & Trade-offs

Decision | Choice | Alternative | Reason
Gradient sync | All-reduce (ring) | Parameter server | No single bottleneck; bandwidth scales with worker count
Checkpointing | Best-metric checkpoint | Every N steps | Reduces artifact storage; resumes from best known state
Metric storage | Relational + JSONB | Time-series DB (InfluxDB) | Simpler ops; JSONB handles arbitrary metric names without schema changes
Dataset storage | Immutable object store snapshots | Delta Lake / Iceberg | Simpler; ML workloads are read-only; no need for ACID mutation
Model staging | staging → champion | Direct to prod | Mandatory evaluation gate prevents bad models reaching serving

Failure Handling & Edge Cases

  • Worker node crash mid-training: the coordinator detects a heartbeat timeout, kills the remaining workers, and sets the job status to 'failed'. The orchestrator then retries from the last saved checkpoint (saved every N steps or on best metric). The checkpoint URI is recorded in experiment_run, so the restart is automatic.
  • Gradient explosion: the global gradient norm is clipped before the optimizer step (clip_grad_norm_(model.parameters(), max_norm=1.0)). If the loss becomes NaN, the training loop halts and marks the run 'failed', preserving the last valid checkpoint.
  • NCCL deadlock: all collective operations use a timeout; if a worker hangs past the timeout the coordinator detects it via the heartbeat and triggers a full restart.
  • Data shard imbalance: if one shard has far more samples than others (e.g., due to class imbalance or a corrupted split), the worker that owns it becomes a straggler. Mitigation: pre-compute shard sizes at dataset version creation time and rebalance if any shard deviates from the mean shard size by more than 5%.
  • Disk full on checkpoint node: checkpoints are streamed directly to object storage rather than written to local disk first. A streaming multipart upload ensures no local disk pressure.
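
The gradient-explosion safeguards above can be illustrated in plain Python. This is a sketch of the idea behind global-norm clipping and a non-finite loss check, not PyTorch's clip_grad_norm_ implementation; clip_by_global_norm, check_loss, and NonFiniteLossError are hypothetical names, and gradients are represented as nested lists of floats.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float) -> float:
    """Scale all gradients in place so their combined L2 norm is at most max_norm.

    Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for tensor in grads:
            for i in range(len(tensor)):
                tensor[i] *= scale
    return total_norm


class NonFiniteLossError(RuntimeError):
    """Raised so the caller can mark the run failed and keep the last checkpoint."""


def check_loss(loss: float, step: int) -> None:
    """Halt training as soon as the loss is NaN or infinite."""
    if not math.isfinite(loss):
        raise NonFiniteLossError(f"loss is {loss} at step {step}")


grads = [[3.0, 4.0], [0.0, 12.0]]     # global norm = sqrt(9 + 16 + 144) = 13
norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for tensor in grads for g in tensor))
```

Clipping by the global norm rescales every gradient by the same factor, so the update direction is preserved and only its magnitude is capped.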

Scalability Considerations

  • Elastic worker pools: use spot / preemptible instances for training workers. The checkpoint-and-resume design means a spot interruption costs at most one checkpoint interval (N steps) of recomputed work.
  • Data prefetching: each worker runs a multi-process DataLoader that prefetches the next batch to CPU RAM while the GPU trains on the current batch. Target: GPU utilization > 90% with zero GPU-stall waiting for data.
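  • Mixed-precision training (FP16/BF16): roughly halves the memory traffic for activations and gradients and can as much as double throughput on Tensor Core hardware. Use loss scaling to prevent FP16 underflow in gradients; BF16 keeps the FP32 exponent range and generally does not need it.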
  • Metric write throughput: if thousands of runs log simultaneously, metric inserts can saturate a single Postgres instance. Shard metric writes across multiple write nodes by run_id % N, or use a time-series store (InfluxDB, Prometheus remote write) as the hot tier with periodic export to the relational store.
  • Artifact deduplication: identical model weights (same hash) produced by reruns are stored once and referenced by multiple model_version rows. A content-addressable storage layer (based on SHA-256 of the file) handles deduplication transparently.
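
The artifact deduplication idea can be sketched with an in-memory stand-in for the object store. ContentAddressableStore and the cas://sha256/ URI scheme are hypothetical and exist only for this illustration; a real layer would write to object storage and stream the hash rather than hold bytes in memory.

```python
import hashlib
from typing import Dict


class ContentAddressableStore:
    """In-memory stand-in for an object store keyed by SHA-256 of the content."""

    def __init__(self) -> None:
        self._blobs: Dict[str, bytes] = {}

    def put(self, data: bytes) -> str:
        digest = hashlib.sha256(data).hexdigest()
        # Identical bytes map to the same key, so a rerun stores nothing new.
        self._blobs.setdefault(digest, data)
        return f"cas://sha256/{digest}"

    def get(self, uri: str) -> bytes:
        # The digest is the last path segment of the URI.
        return self._blobs[uri.rsplit("/", 1)[-1]]

    def blob_count(self) -> int:
        return len(self._blobs)


store = ContentAddressableStore()
uri_run_a = store.put(b"weights-from-run-a")
uri_rerun = store.put(b"weights-from-run-a")   # bit-identical rerun
uri_run_b = store.put(b"weights-from-run-b")
```

Two model_version rows can then carry the same artifact_uri, and the blob is removable only once no row references it.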

Summary

A production ML training pipeline is built on three guarantees: reproducibility (pinned dataset versions, code commits, seeds), observability (async experiment tracking with full hyperparameter and metric history), and resilience (checkpoint-and-resume, automatic retry, spot-instance tolerance). Distributed training adds coordination complexity — gradient all-reduce, straggler mitigation, NCCL timeout handling — but the data-parallel DDP pattern handles the vast majority of real-world models cleanly. The model registry’s staging gate is the most important safety mechanism: it ensures no model reaches production without passing an automated evaluation suite and a human approval step.

Frequently Asked Questions

What is an ML training pipeline and what are its main stages?

An ML training pipeline is an automated, reproducible workflow that takes raw data and produces a trained, validated model artifact. Its main stages are: data ingestion and validation (sourcing and checking data quality), feature engineering (transforming raw inputs into model-ready features), model training (fitting the model on processed data), evaluation (measuring performance against held-out data using defined metrics), and registration or promotion (storing the model artifact in a registry if it meets quality thresholds). Orchestration tools such as Kubeflow, MLflow, or Apache Airflow are commonly used to coordinate these stages.

How is dataset versioning handled in an ML training pipeline?

Dataset versioning ensures that a specific snapshot of training data can be reproduced for debugging, auditing, or retraining. Common approaches include storing datasets as immutable, content-addressed objects in object storage (e.g., S3) with a manifest file that records paths and checksums, using a dedicated data versioning tool such as DVC or Delta Lake’s time-travel feature, or appending a version tag or timestamp to dataset identifiers. The chosen version is logged alongside hyperparameters and code commit hashes in the experiment tracker so that any run can be fully reconstructed.

How does distributed training work across multiple GPUs or nodes?

Distributed training splits the computational workload of model training across multiple GPUs or machines to reduce wall-clock time. The two primary strategies are data parallelism, where each worker holds a full model copy and processes a different mini-batch, then synchronizes gradients via all-reduce (e.g., NCCL ring all-reduce) before updating weights; and model parallelism, where the model itself is split across devices, used when it’s too large to fit on a single GPU. Frameworks such as PyTorch DDP, Horovod, and DeepSpeed implement these strategies, and parameter servers or peer-to-peer collective communication are used to keep replicas consistent.

How are experiments tracked and models promoted to production?

Experiment tracking tools (MLflow, Weights & Biases, Neptune) log each training run’s hyperparameters, dataset version, code commit, metrics, and output artifacts to a central store, making runs comparable and reproducible. Once a run completes, automated evaluation gates compare its metrics (accuracy, AUC, latency) against a baseline or champion model. If it passes, the artifact is registered in a model registry with a version tag. Promotion to production is then triggered manually or automatically by moving the registry alias (e.g., ‘staging’ -> ‘production’), which downstream serving infrastructure watches to load the new weights.
