Transformers are the architecture behind GPT, BERT, Claude, and every other major language model. Understanding how they work — especially the attention mechanism — is now a core expectation in ML interviews at any company doing AI work. This post explains the mechanism from first principles, with the math made concrete.
What the Interviewer Is Testing
At the junior level: can you explain what self-attention does? At the senior level: can you explain why transformers outperform RNNs, what multi-head attention adds, the difference between encoder-only and decoder-only architectures, and the key limitations (quadratic attention complexity, context length)?
The Problem Transformers Solve
Before transformers, RNNs (LSTMs, GRUs) processed sequences token by token. To understand “The trophy didn’t fit in the suitcase because it was too big” — specifically, what “it” refers to — an RNN had to carry information about “trophy” through every intermediate hidden state between the two words. Long-range dependencies degraded as sequence length increased. Training was sequential, not parallelizable.
Transformers solve both problems: attention directly connects any two tokens regardless of distance, and all tokens are processed in parallel (during training).
Self-Attention: The Core Mechanism
Given a sequence of token embeddings (vectors), self-attention computes a new representation of each token by attending to all other tokens in the sequence.
For each token, compute three vectors:
- Query (Q): “What am I looking for?”
- Key (K): “What do I contain?”
- Value (V): “What do I broadcast if selected?”
```python
import torch
import torch.nn.functional as F
import math


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q, K, V: shape (batch, seq_len, d_k)
    Returns: attended output, attention weights
    """
    d_k = Q.size(-1)

    # Step 1: Compute attention scores — how much does each query attend to each key?
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Shape: (batch, seq_len, seq_len)

    # Step 2: Apply causal mask for decoder (prevent attending to future tokens)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax — convert scores to probabilities (attention weights)
    attn_weights = F.softmax(scores, dim=-1)
    # Each row sums to 1.0 — these are the "how much to attend to each token" weights

    # Step 4: Weighted sum of values
    output = torch.matmul(attn_weights, V)
    # Shape: (batch, seq_len, d_k) — new representation for each token
    return output, attn_weights
```
The division by √d_k prevents the dot products from growing too large (which would push softmax into regions with near-zero gradients). This is the “scaled” in “scaled dot-product attention.”
Intuition for “The trophy…it”: when computing the representation of “it,” the attention mechanism learns high weights toward “trophy” and “suitcase” — the model attends to both and uses context to disambiguate. The attention weight matrix is the model’s learned “where to look” for each token.
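As a sanity check, the function above can be exercised on random tensors. PyTorch 2.x also ships this exact operation as a fused built-in, `torch.nn.functional.scaled_dot_product_attention`, so a hand-rolled version can be verified against it (shapes below are illustrative):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.randn(2, 5, 16)   # (batch, seq_len, d_k)
K = torch.randn(2, 5, 16)
V = torch.randn(2, 5, 16)

# Hand-rolled scaled dot-product attention, as in the function above
scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
weights = F.softmax(scores, dim=-1)
manual = weights @ V

# Every row of the attention weights is a probability distribution
assert torch.allclose(weights.sum(dim=-1), torch.ones(2, 5), atol=1e-6)

# PyTorch 2.x built-in computes the same thing (fused, memory-efficient)
builtin = F.scaled_dot_product_attention(Q, K, V)
assert torch.allclose(manual, builtin, atol=1e-5)
print(manual.shape)  # torch.Size([2, 5, 16])
```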
Multi-Head Attention
A single attention head looks for one type of relationship. Multi-head attention runs H independent attention heads in parallel, each with its own Q, K, V projection matrices, then concatenates and projects the results.
```python
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch, seq_len, d_model = x.shape

        # Project to Q, K, V and reshape for multi-head
        Q = self.W_q(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Shape: (batch, num_heads, seq_len, d_k)

        # Attention per head (runs in parallel)
        attn_output, _ = scaled_dot_product_attention(Q, K, V, mask)
        # Shape: (batch, num_heads, seq_len, d_k)

        # Concatenate heads and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.W_o(attn_output)
```
In BERT-base: 12 heads × 64 dimensions each = 768 total dimensions. Each head can learn a different type of relationship: one head might focus on syntactic dependencies (subject-verb), another on coreference (“it” → “trophy”), another on local context.
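For comparison, PyTorch ships a built-in `torch.nn.MultiheadAttention` that implements the same computation (it differs from the sketch above only in weight layout). With BERT-base hyperparameters, the shapes check out:

```python
import torch

mha = torch.nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
x = torch.randn(2, 10, 768)        # (batch, seq_len, d_model)

out, attn_weights = mha(x, x, x)   # self-attention: query = key = value = x
print(out.shape)           # torch.Size([2, 10, 768]) — same shape in, same shape out
print(attn_weights.shape)  # torch.Size([2, 10, 10]) — averaged over the 12 heads
```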
Positional Encoding
Attention is permutation-invariant — “the cat sat on the mat” and “mat the on sat cat the” produce the same attention scores without positional information. Transformers add positional encodings to token embeddings to inject order information.
Original transformer (Vaswani et al., 2017) used sinusoidal encoding:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
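The formulas translate directly into code — a minimal sketch that fills even dimensions with sines and odd dimensions with cosines:

```python
import torch


def sinusoidal_encoding(max_len, d_model):
    pos = torch.arange(max_len).unsqueeze(1).float()   # (max_len, 1)
    i = torch.arange(0, d_model, 2).float()            # even dimension indices 2i
    angle = pos / torch.pow(torch.tensor(10000.0), i / d_model)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(angle)   # PE(pos, 2i)
    pe[:, 1::2] = torch.cos(angle)   # PE(pos, 2i+1)
    return pe


pe = sinusoidal_encoding(max_len=128, d_model=64)
# Position 0: sin(0) = 0 in even slots, cos(0) = 1 in odd slots
print(pe[0, :4])  # tensor([0., 1., 0., 1.])
```

Each dimension oscillates at a different frequency, so every position gets a unique pattern, and the encoding needs no training and extends to any position.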
Modern LLMs use RoPE (Rotary Position Embedding) — positional information is incorporated directly into the attention dot product by rotating Q and K vectors, so attention scores depend only on relative distance between tokens. RoPE tends to extend more gracefully beyond the training context length than learned absolute embeddings, especially when combined with context-extension tricks like position interpolation. GPT-NeoX, LLaMA, and most modern open-source models use RoPE.
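A minimal sketch of the rotation, using the “rotate-half” formulation common in open-source implementations (the tensor layout here is illustrative, not any particular model’s):

```python
import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(x, pos, base=10000.0):
    """x: (seq_len, d); pos: (seq_len,) integer positions."""
    d = x.size(-1)
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))  # (d/2,)
    angles = pos.float().unsqueeze(-1) * inv_freq                   # (seq_len, d/2)
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    return x * cos + rotate_half(x) * sin


torch.manual_seed(0)
q, k = torch.randn(1, 64), torch.randn(1, 64)

# Key property: the q·k dot product depends only on relative distance
a = apply_rope(q, torch.tensor([3])) @ apply_rope(k, torch.tensor([7])).T
b = apply_rope(q, torch.tensor([10])) @ apply_rope(k, torch.tensor([14])).T
print(torch.allclose(a, b, atol=1e-5))  # True — distance is 4 in both cases
```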
The Full Transformer Block
```python
def transformer_block(x, mask=None):
    # Schematic (Post-LN, as in the original paper); helpers elided
    # 1. Multi-head self-attention (with residual connection + LayerNorm)
    attn_out = multi_head_attention(x, mask)
    x = layer_norm(x + attn_out)  # residual connection prevents vanishing gradients

    # 2. Feed-forward network (position-wise MLP)
    ff_out = feed_forward(x)  # Linear → ReLU/GELU → Linear
    x = layer_norm(x + ff_out)
    return x
```
The feed-forward sublayer applies the same 2-layer MLP independently to each token position. Typically 4× wider than d_model (BERT-base: 768 → 3072 → 768). This is where most of the model’s parameter count lives and where factual knowledge is thought to be “stored.”
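The sublayer is small to write down, and the BERT-base parameter count can be checked directly (roughly two-thirds of each block’s parameters sit here, versus ~2.36M in the attention projections):

```python
import torch

d_model, d_ff = 768, 3072   # BERT-base: 4x expansion

ffn = torch.nn.Sequential(
    torch.nn.Linear(d_model, d_ff),   # 768 -> 3072
    torch.nn.GELU(),
    torch.nn.Linear(d_ff, d_model),   # 3072 -> 768
)

# Applied independently at every position: (batch, seq_len, d_model) in and out
x = torch.randn(2, 10, d_model)
print(ffn(x).shape)  # torch.Size([2, 10, 768])

n_params = sum(p.numel() for p in ffn.parameters())
print(n_params)  # 4722432
```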
Encoder vs Decoder vs Encoder-Decoder
| Architecture | Attention type | Examples | Best for |
|---|---|---|---|
| Encoder-only | Bidirectional (each token sees all) | BERT, RoBERTa | Classification, NER, embeddings |
| Decoder-only | Causal (each token sees only past) | GPT-4, Claude, LLaMA | Text generation, completion |
| Encoder-Decoder | Bidirectional enc + causal dec + cross-attention | T5, BART, original Transformer (Vaswani et al.) | Translation, summarization, seq2seq |
The causal mask in decoder-only models makes each token attend only to preceding tokens — this is what enables autoregressive generation: generate one token at a time, each conditioned on all previous tokens.
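The mask itself is just a lower-triangular matrix; applied to the scores before the softmax (as in the attention function earlier), it zeroes out every “future” weight:

```python
import math
import torch
import torch.nn.functional as F

seq_len = 4
mask = torch.tril(torch.ones(seq_len, seq_len))  # 1 = may attend, 0 = masked
print(mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])

torch.manual_seed(0)
Q = K = V = torch.randn(1, seq_len, 8)
scores = Q @ K.transpose(-2, -1) / math.sqrt(8)
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

# Token 0 can only attend to itself; no token attends to the future
assert weights[0, 0, 0] == 1.0
assert (weights[0].triu(1) == 0).all()
```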
Why Transformers Beat RNNs
- Parallelism: All tokens processed simultaneously during training (with teacher forcing). RNNs are inherently sequential. Transformers train orders of magnitude faster on modern GPU/TPU hardware.
- Long-range dependencies: Direct connection between any two tokens in O(1) operations vs O(n) for RNNs. The attention pattern in the trophy sentence is learned in one layer.
- Gradient flow: Residual connections provide direct gradient paths to every layer. RNNs struggle with gradients vanishing over hundreds of steps.
The Quadratic Problem
Self-attention computes a (seq_len × seq_len) matrix. For a 4K context window: 4,096 × 4,096 ≈ 16.8M scores per head. Memory: O(n²). Time: O(n²). This is why early GPT models had 512–2048 token contexts. Scaling to 100K+ tokens requires algorithmic improvements:
- FlashAttention: Reorders attention computation to use SRAM more efficiently. Same outputs, 3–5× faster, O(n) memory (not O(n²)) on GPU. Standard in production since 2022.
- Sparse attention: Only attend to nearby tokens + a few global tokens (Longformer, BigBird). O(n) complexity.
- Linear attention: Reformulate as a kernel to avoid the full matrix. Theoretical O(n), but often slower in practice than FlashAttention for typical context lengths.
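The arithmetic is easy to make concrete. A back-of-envelope for the naively materialized score matrix alone, in fp16, per head and per batch element (this is exactly the allocation FlashAttention avoids):

```python
def attn_matrix_bytes(seq_len, bytes_per_elem=2):  # fp16 = 2 bytes
    return seq_len * seq_len * bytes_per_elem


for n in (2048, 8192, 32768, 131072):
    mib = attn_matrix_bytes(n) / 2**20
    print(f"{n:>7} tokens -> {mib:>8.0f} MiB per head")
# Doubling the context quadruples the memory: 2048 -> 8 MiB, 131072 -> 32768 MiB
```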
Interview Follow-ups
- What happens if you remove positional encoding? Can the model still function, and for what tasks?
- Explain the difference between self-attention and cross-attention (in encoder-decoder models).
- Why is LayerNorm applied before attention in modern models (Pre-LN) rather than after (Post-LN as in the original paper)?
- How does KV-cache enable fast inference? What’s the memory tradeoff?
- A model has 12 attention heads. What might each head be learning?
Related ML Topics
- How Backpropagation Works — the chain rule that trains transformer weights; residual connections make backprop through 96 layers stable
- Fine-tuning LLMs vs Training from Scratch — how to adapt a pretrained transformer for your task using LoRA, QLoRA, or SFT
- Gradient Descent Explained — the optimizer (AdamW) consuming transformer gradients; why learning rate warmup matters for transformers
- Overfitting and Regularization — dropout, weight decay, and label smoothing in transformer training
See also: What is RAG? — how transformer-generated embeddings power the retrieval component of RAG pipelines, and Embeddings and Vector Databases — the vector representations that transformers produce.
See also: Computer Vision Interview Questions — how Vision Transformers (ViT) apply self-attention to image patches, and CLIP for image-text alignment.
See also: NLP Interview Questions: Tokenization, Embeddings, and BERT — how to apply transformer encoders to classification, NER, and QA; BERT fine-tuning walkthrough.