System Design Interview: Machine Learning Training Infrastructure

Why ML Training Infrastructure Is a System Design Interview Topic

Companies building AI products at scale — OpenAI, Google, Meta, Databricks, and any company serious about ML — need robust training infrastructure. Interview candidates for ML platform, MLOps, and senior ML engineering roles are expected to design systems that schedule GPU jobs, coordinate distributed training across hundreds of nodes, manage failures gracefully, and enable reproducible experiments. This post covers the core components and key design decisions.

Scale Estimates

Training GPT-3 (175B parameters) is estimated to have taken roughly 10,000 V100 GPUs for about two weeks (~355 V100-years of compute). Modern LLM training runs use 1,000–16,000 A100/H100 GPUs. At roughly 1 petaflop of dense BF16 compute per H100, 8,000 GPUs deliver on the order of 8 exaflops of peak compute — roughly a million times a consumer laptop. For smaller-scale ML teams: a typical company trains models on clusters of 8–128 GPUs, running experiments that take minutes to days. The infrastructure must handle hundreds of concurrent jobs from dozens of researchers.
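
The back-of-envelope arithmetic can be checked directly (the per-device figures are rough assumptions for illustration, not vendor-guaranteed numbers):

```python
# Back-of-envelope cluster compute estimate.
# Assumed figures (approximate): ~1 PFLOP/s dense BF16 per H100,
# ~10 TFLOP/s for a high-end consumer laptop GPU.
H100_FLOPS = 1e15      # ~1 petaflop/s dense BF16 (assumption)
LAPTOP_FLOPS = 1e13    # ~10 teraflop/s (assumption)
N_GPUS = 8_000

cluster_flops = H100_FLOPS * N_GPUS
print(f"Aggregate peak: {cluster_flops / 1e18:.0f} exaflop/s")  # 8 exaflop/s
print(f"Laptop ratio:   {cluster_flops / LAPTOP_FLOPS:,.0f}x")  # 800,000x
```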

Key Components

1. Cluster Scheduler (Kubernetes + GPU Operator)

GPU jobs are scheduled on Kubernetes with NVIDIA’s GPU operator. Each training job is a Kubernetes Job or custom resource (Kubeflow Training Operator: TFJob, PyTorchJob). The scheduler allocates GPU nodes to jobs based on requested resources (e.g., 8xA100-80GB, 100GB shared memory for fast inter-GPU communication). Key scheduling features:

  • Gang scheduling: distributed training requires ALL nodes to start simultaneously — partial allocation is useless. Gang scheduling waits until all requested GPUs are available before starting any pod.
  • Topology-aware scheduling: prefer co-locating pods on GPUs connected by NVLink (600 GB/s GPU-to-GPU on A100; 900 GB/s on H100) over PCIe (~64 GB/s). Cross-node communication uses InfiniBand (up to 400 Gb/s per link). The scheduler uses node labels to maximize interconnect bandwidth.
  • Preemption: high-priority jobs preempt lower-priority ones. Fine-tuning a production model preempts a researcher’s exploratory run.

2. Distributed Training Strategies

Large models cannot fit in a single GPU’s memory (80GB for A100). Three parallelism strategies:

Data Parallelism (DDP): the model is copied on each GPU; each GPU processes a different mini-batch. After each backward pass, gradients are all-reduced (summed across all GPUs using NCCL’s ring-allreduce) and averaged. Each GPU then applies the identical update to its copy of the weights. Works well when the model fits in a single GPU. PyTorch DDP is the standard implementation. With N GPUs, the effective batch size is N × batch_per_gpu and throughput scales near-linearly with N.
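
The gradient all-reduce at the heart of DDP can be simulated in plain Python (lists stand in for per-GPU gradient tensors; real DDP runs NCCL collectives on device memory):

```python
def allreduce_mean(worker_grads):
    """Average gradients element-wise across workers, as DDP's
    all-reduce does after each backward pass. Every worker ends up
    with the identical averaged gradient, hence identical updates."""
    n = len(worker_grads)
    dim = len(worker_grads[0])
    summed = [sum(g[i] for g in worker_grads) for i in range(dim)]
    avg = [s / n for s in summed]
    return [list(avg) for _ in range(n)]  # each worker gets the same result

# 4 "GPUs", each with gradients computed from its own mini-batch
grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(allreduce_mean(grads)[0])  # [4.0, 5.0] on every worker
```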

Model (Tensor) Parallelism: each layer’s weight matrix is split across GPUs. For a transformer attention layer with 4 attention heads on 4 GPUs, each GPU computes 1 head. This requires collective communication (all-reduce or all-gather, depending on how the matrix is split) at each layer boundary. Megatron-LM implements efficient tensor parallelism for transformers. Used when a single layer’s weights exceed GPU memory.
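
The column-wise weight split can be demonstrated with plain-Python matrices (a sketch of the idea, not Megatron’s implementation):

```python
def matmul(X, W):
    """Naive matrix multiply on lists-of-lists."""
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*W)]
            for row in X]

X = [[1, 2, 3, 4]]
W = [[1, 0, 2, 0],
     [0, 1, 0, 2],
     [1, 1, 0, 0],
     [0, 0, 1, 1]]
W0 = [row[:2] for row in W]  # "GPU 0" holds the first 2 columns
W1 = [row[2:] for row in W]  # "GPU 1" holds the last 2 columns

# Each GPU computes its slice of the output; slices are concatenated
# (an all-gather in the real system; a row-wise split would instead
# produce partial sums that need an all-reduce).
Y_parallel = [y0 + y1 for y0, y1 in zip(matmul(X, W0), matmul(X, W1))]
assert Y_parallel == matmul(X, W)
print(Y_parallel)  # [[4, 5, 6, 8]]
```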

Pipeline Parallelism: different layers are placed on different GPUs (GPU 0: layers 1-8, GPU 1: layers 9-16, etc.). Each GPU processes a micro-batch and passes activations forward to the next GPU — like a factory pipeline. Reduces memory per GPU but introduces pipeline bubbles (idle time while waiting for micro-batches). GPipe and PipeDream manage micro-batch scheduling to maximize utilization.
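
The bubble overhead has a simple closed form: with p pipeline stages and m micro-batches per batch, a GPipe-style schedule loses roughly (p − 1) / (m + p − 1) of GPU time to fill/drain bubbles. A quick sketch:

```python
def bubble_fraction(stages: int, micro_batches: int) -> float:
    """Fraction of GPU time lost to pipeline fill/drain bubbles
    in a GPipe-style schedule: (p - 1) / (m + p - 1)."""
    return (stages - 1) / (micro_batches + stages - 1)

# More micro-batches amortize the fill/drain cost:
for m in (1, 4, 16, 64):
    print(f"8 stages, {m:3d} micro-batches -> "
          f"{bubble_fraction(8, m):.1%} bubble")
```

This is why pipeline-parallel training pushes the micro-batch count well above the stage count.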

3D Parallelism: combine all three. Megatron-Turing NLG and PaLM use data × tensor × pipeline parallelism simultaneously, each dimension handled by a different level of the GPU hierarchy (node × GPU × layer).
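
The three degrees multiply: total GPUs = tensor × pipeline × data parallel sizes. A hypothetical layout (illustrative numbers, not any specific published run):

```python
# Tensor parallelism within a node (NVLink), pipeline parallelism
# across nodes in one replica, data parallelism across replicas.
tensor_parallel = 8      # GPUs per node sharing each layer
pipeline_parallel = 16   # nodes forming one model replica's pipeline
data_parallel = 64       # independent model replicas

total_gpus = tensor_parallel * pipeline_parallel * data_parallel
print(total_gpus)  # 8192
```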

3. Checkpointing and Fault Tolerance

A 14-day training run will experience node failures. Without checkpointing, a single GPU failure loses the entire run. Checkpointing saves model weights, optimizer state (Adam momentum buffers — same size as weights), and training metadata (step number, learning rate schedule) to shared storage (distributed filesystem like Lustre, or object storage like S3).

Checkpoint interval tradeoff: checkpoint every hour → at most 1 hour of work lost on failure; but checkpoint overhead (saving 100GB+ of state) can take minutes for very large models. Asynchronous checkpointing: the weights are first snapshotted (e.g., copied to CPU memory), then a background thread writes the snapshot to storage while training continues, so forward and backward passes are not blocked. PyTorch FSDP (Fully Sharded Data Parallel) integrates async checkpointing.
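
The snapshot-then-write pattern can be sketched in plain Python (threads and dict copies stand in for CUDA streams and GPU→CPU transfers; this is not the FSDP API):

```python
import copy
import threading

def async_checkpoint(weights: dict, step: int, saved: list):
    """Snapshot weights synchronously (cheap), then persist the
    snapshot in a background thread so training is not blocked."""
    snapshot = copy.deepcopy(weights)  # point-in-time copy
    def write():
        # Stand-in for torch.save(snapshot, path) / upload to S3.
        saved.append((step, snapshot))
    t = threading.Thread(target=write)
    t.start()
    return t

weights = {"layer1": [0.1, 0.2]}
saved = []
t = async_checkpoint(weights, step=100, saved=saved)
weights["layer1"][0] = 999.0  # training keeps mutating live weights
t.join()
print(saved[0])  # (100, {'layer1': [0.1, 0.2]}) — snapshot unaffected
```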

On node failure: the job controller detects pod failure, terminates all other pods (gang semantics), and relaunches the entire job, loading from the last checkpoint. Spot/preemptible instances (3× cheaper than on-demand) are viable with frequent checkpointing and automated relaunch.

4. Shared Storage: Training Data and Checkpoints

Training data must be delivered to GPUs at least as fast as they consume it. Per-GPU token throughput varies enormously with model size — tens of tokens/second per GPU when training 100B+-parameter models, thousands for small ones — and the aggregate read rate across thousands of GPUs must be sustained for the entire run (e.g., streaming a 300B-token dataset over weeks). Solutions:

  • Lustre / GPFS: parallel distributed filesystem with 100+ GB/s aggregate throughput. Standard in on-premise GPU clusters (Summit, Frontier HPC clusters).
  • AWS FSx for Lustre + S3 backend: cache frequently accessed training data from S3 on fast NVMe SSDs. Latency-sensitive random reads come from SSD; sequential reads stream from S3.
  • Local NVMe striping: for repetitive training (multiple epochs), pre-shard the dataset onto each node’s local SSD and have each node only read its shard.
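
Shard assignment in the last option is typically a simple rank-stride slice (a minimal sketch; real loaders such as WebDataset also handle shuffling and resumption):

```python
def shards_for_rank(files, rank: int, world_size: int):
    """Each node reads only every world_size-th file, offset by its
    rank, partitioning the dataset with no overlap and full coverage."""
    return files[rank::world_size]

files = [f"shard-{i:04d}.tar" for i in range(8)]
print(shards_for_rank(files, rank=1, world_size=4))
# ['shard-0001.tar', 'shard-0005.tar']
```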

5. Experiment Tracking

Researchers run hundreds of experiments varying hyperparameters (learning rate, batch size, architecture), datasets, and regularization. Without tracking, results are lost in log files and experiments are not reproducible. MLflow and Weights & Biases (W&B) track:

  • Hyperparameters (logged at job start)
  • Metrics per step (loss, gradient norm, learning rate, tokens/second — logged every N steps)
  • Artifacts (model checkpoints, evaluation results)
  • Git commit hash and code diff (reproducibility)
  • System metrics (GPU utilization, memory usage — detect underutilization)
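
What trackers capture can be illustrated with a minimal run record (a plain-Python sketch of the shape of the data, not the MLflow or W&B API):

```python
import json
import time

def start_run(params: dict, git_commit: str) -> dict:
    """Open a run record: hyperparameters and code version up front."""
    return {"params": params, "git_commit": git_commit,
            "start_time": time.time(), "metrics": [], "artifacts": []}

def log_metric(run: dict, step: int, **metrics):
    """Append per-step metrics (loss, throughput, ...) to the run."""
    run["metrics"].append({"step": step, **metrics})

run = start_run({"lr": 3e-4, "batch_size": 256}, git_commit="abc1234")
log_metric(run, step=100, loss=2.31, tokens_per_sec=41000)
log_metric(run, step=200, loss=2.07, tokens_per_sec=40800)
print(json.dumps(run["metrics"][-1]))
# {"step": 200, "loss": 2.07, "tokens_per_sec": 40800}
```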

W&B stores metrics in a time-series database; the UI shows training curves, hyperparameter sweep visualizations (parallel coordinates plot), and model version comparisons. API integrations allow automated hyperparameter search (Optuna, Ray Tune) to query past results and propose next trials.

6. GPU Utilization Monitoring

Poorly written training code often leaves GPUs at 20-30% utilization (I/O-bound data loading, small batch sizes, excessive CPU-GPU transfers). NVIDIA’s dcgm-exporter exposes GPU metrics (sm_active, mem_bw_utilized, nvlink_bandwidth) to Prometheus. Dashboards alert when utilization drops below 80% for more than 5 minutes. Common fixes: increase DataLoader workers (num_workers), use pinned memory (pin_memory=True), prefetch the next batch while the GPU processes the current one, and use TF32/BF16 instead of FP32.
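
The prefetch fix overlaps data loading with compute; the pattern can be sketched as a background-thread producer (plain Python standing in for DataLoader’s worker processes):

```python
import queue
import threading

def prefetching_loader(batches, buffer_size: int = 2):
    """Yield batches while a background thread loads ahead, so the
    consumer (the 'GPU') never waits if loading keeps up with compute."""
    q = queue.Queue(maxsize=buffer_size)
    SENTINEL = object()

    def producer():
        for b in batches:  # stand-in for disk reads + preprocessing
            q.put(b)
        q.put(SENTINEL)

    threading.Thread(target=producer, daemon=True).start()
    while (item := q.get()) is not SENTINEL:
        yield item

total = sum(sum(batch) for batch in prefetching_loader([[1, 2], [3, 4], [5, 6]]))
print(total)  # 21
```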

Key Interview Design Decisions

  • Use gang scheduling — distributed training requires all-or-nothing GPU allocation
  • Choose parallelism strategy based on model size vs GPU memory: DDP → tensor → pipeline → 3D
  • Checkpoint every 30-60 minutes asynchronously; use spot instances for 3× cost savings
  • Use Lustre or FSx for training data I/O — S3 alone is too slow for GPU-scale reads
  • Track all experiments in W&B or MLflow for reproducibility and hyperparameter search
  • Monitor GPU SM utilization; anything below 70% warrants profiling

Frequently Asked Questions

What is the difference between data parallelism, model parallelism, and pipeline parallelism in distributed ML training?

Data parallelism (DDP): the full model is copied to each GPU; each GPU processes a different mini-batch of training data. After each backward pass, gradients are averaged across all GPUs using all-reduce (NCCL ring-allreduce). All GPUs update identical copies of the model. Works when the model fits in a single GPU. Scales throughput near-linearly with the number of GPUs. Model (tensor) parallelism: a single layer is split across multiple GPUs. For a matrix multiplication Y = XW, the weight matrix W is partitioned column-wise across N GPUs, each computing X × W_partition; the partial outputs are concatenated via all-gather (a row-wise split instead produces partial sums that are summed via all-reduce). Used when a single layer is too large for one GPU. Requires dense inter-GPU communication at each layer. Pipeline parallelism: different layers are assigned to different GPUs in a sequential chain. GPU 0 processes layers 1-4, GPU 1 processes layers 5-8, etc. Activations are passed forward between GPUs. Multiple micro-batches are in-flight simultaneously to keep all GPUs busy (otherwise they sit idle waiting for the previous GPU). Introduces pipeline bubbles (inefficiency) at the start and end of each batch. Real training at scale (GPT-4, PaLM) uses 3D parallelism combining all three: data × tensor × pipeline, each mapped to a different level of the GPU interconnect hierarchy.

How do you handle GPU training job failures and spot instance preemptions?

Distributed training jobs can run for days or weeks, making failure recovery essential. The primary tool is periodic checkpointing: save model weights, optimizer state (Adam moment buffers), learning rate schedule, and current step number to persistent storage (S3, GCS, Lustre). Checkpoint every 15-60 minutes depending on the job duration and storage cost. On failure: the job controller (Kubernetes, Slurm) detects the pod failure, terminates all worker pods (gang semantics — a partial cluster cannot resume training), and relaunches the full job. On relaunch, each worker loads the latest checkpoint and resumes from the saved step. For spot/preemptible instances (3× cheaper than on-demand): configure automated checkpoint-on-preemption. AWS spot instances provide a 2-minute warning via instance metadata before termination. A signal handler saves an emergency checkpoint when the warning is detected. The job is relaunched on new spot or on-demand instances. With checkpointing every 15 minutes and automated relaunch, a long training run can use 90%+ spot instances while maintaining forward progress even with frequent preemptions. PyTorch FSDP (Fully Sharded Data Parallel) supports asynchronous checkpointing — saving happens in a background thread without pausing training.
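
The checkpoint-on-preemption hook is a small signal handler (a sketch; in practice the warning arrives via the instance metadata endpoint or, on Kubernetes, as SIGTERM to the pod):

```python
import signal

state = {"step": 12345, "checkpointed_at": None}

def save_emergency_checkpoint(signum, frame):
    # Stand-in for torch.save(...) + upload to object storage; must
    # finish well inside the ~2-minute spot termination warning window.
    state["checkpointed_at"] = state["step"]

signal.signal(signal.SIGTERM, save_emergency_checkpoint)

# Simulate a preemption warning being delivered:
signal.raise_signal(signal.SIGTERM)
print(state["checkpointed_at"])  # 12345
```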

What is GPU utilization and how do you diagnose low GPU utilization in training?

GPU SM (Streaming Multiprocessor) utilization measures what percentage of time the GPU's compute units are actually executing kernels (matrix multiplications, activations). Healthy training runs achieve 70-90% SM utilization. Common causes of low utilization: (1) CPU data loading bottleneck — the GPU is waiting for the next batch from disk/CPU. Fix: increase DataLoader workers (num_workers=8+), use pin_memory=True to accelerate host-to-device transfers, prefetch next batch while GPU processes current one using prefetch_factor. (2) Small batch size — the GPU has fewer operations to parallelize. Fix: increase batch size (gradient accumulation to simulate larger batches when memory is limited). (3) Python GIL overhead from CPU-side preprocessing — move preprocessing to GPU using DALI or cuCIM. (4) Excessive CPU-GPU synchronization — avoid .item() calls inside training loops (forces CPU-GPU sync); use async data logging. Diagnosis tools: NVIDIA Nsight Systems (nsys) profiles the full training timeline showing compute vs I/O vs CPU gaps; nvitop and dcgm-exporter expose real-time GPU metrics; PyTorch Profiler integrates directly into training code and generates Chrome trace files showing where time is spent.

