Minimum Spanning Tree problems test whether you can connect all nodes in a weighted graph at minimum total cost, a common interview topic at Amazon, Google, and Microsoft. Two standard algorithms solve it: Kruskal’s (sort the edges, join components with Union-Find) and Prim’s (grow a tree outward from a start vertex). Know both.
What is a Minimum Spanning Tree?
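A spanning tree of a connected, undirected graph is a subgraph that connects all V vertices using exactly V-1 edges and no cycles. A Minimum Spanning Tree (MST) is a spanning tree with the minimum total edge weight. It is unique when all edge weights are distinct; otherwise several spanning trees may tie for the minimum.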
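The definition can be checked mechanically. The following is a minimal sketch, assuming vertices are labeled 0..n-1 and edges are given as (u, v) pairs; the helper name `is_spanning_tree` is illustrative and not part of any library.

```python
def is_spanning_tree(n: int, tree_edges: list[tuple[int, int]]) -> bool:
    """Return True if tree_edges forms a spanning tree on vertices 0..n-1."""
    if len(tree_edges) != n - 1:
        return False  # A spanning tree has exactly V-1 edges

    parent = list(range(n))

    def find(x: int) -> int:
        # Walk up to the root, shortening the path along the way
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for u, v in tree_edges:
        ru, rv = find(u), find(v)
        if ru == rv:
            return False  # This edge closes a cycle
        parent[ru] = rv

    # V-1 edges and no cycle on V vertices means one connected component
    return True


print(is_spanning_tree(4, [(0, 1), (1, 2), (0, 3)]))  # True
print(is_spanning_tree(4, [(0, 1), (1, 2), (0, 2)]))  # False: cycle, vertex 3 unreached
```

Checking the edge count first means the cycle test alone is enough: an acyclic graph with V-1 edges on V vertices is necessarily connected.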
Applications: network design (minimum cable to connect all offices), clustering algorithms, approximation algorithms for TSP.
Kruskal’s Algorithm
Sort edges by weight. Add each edge if it doesn’t create a cycle (checked via Union-Find).
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # Path compression
        return self.parent[x]

    def union(self, x, y) -> bool:
        px, py = self.find(x), self.find(y)
        if px == py:
            return False  # Same component: adding this edge creates a cycle
        if self.rank[px] < self.rank[py]:
            px, py = py, px  # Keep px as the root of the deeper tree
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
        return True


def kruskal(n: int, edges: list[tuple]) -> tuple[int, list]:
    """
    Kruskal's MST algorithm.
    edges: list of (weight, u, v)
    Returns: (total_weight, mst_edges)
    Time: O(E log E) for sorting, O(E * alpha(V)) for Union-Find
    """
    edges = sorted(edges)  # Sort by weight
    uf = UnionFind(n)
    mst_weight = 0
    mst_edges = []
    for weight, u, v in edges:
        if uf.union(u, v):
            mst_weight += weight
            mst_edges.append((u, v, weight))
            if len(mst_edges) == n - 1:
                break  # MST is complete: V-1 edges
    if len(mst_edges) != n - 1:
        return -1, []  # Graph is disconnected
    return mst_weight, mst_edges
# Example: 4 nodes, edges as (weight, u, v)
n = 4
edges = [(1, 0, 1), (3, 0, 2), (4, 0, 3), (2, 1, 2), (5, 1, 3), (6, 2, 3)]
weight, mst = kruskal(n, edges)
print(f"MST weight: {weight}")  # MST weight: 7
# MST edges: (0, 1, 1) + (1, 2, 2) + (0, 3, 4) = 7; edge 0-2 is skipped because it would close a cycle
Prim’s Algorithm
Grow the MST from a starting vertex. At each step, add the minimum-weight edge crossing the cut between the MST and the rest of the graph.
import heapq


def prim(n: int, graph: list[list[tuple]], start: int = 0) -> tuple[int, list]:
    """
    Prim's MST algorithm with min-heap.
    graph: adjacency list, graph[u] = [(weight, v), ...]
    Returns: (total_weight, mst_edges)
    Time: O((V + E) log V), Space: O(V + E)
    """
    in_mst = [False] * n
    min_heap = [(0, start, -1)]  # (edge_weight, to_node, from_node)
    mst_weight = 0
    mst_edges = []
    # Stop once V-1 edges are chosen; remaining heap entries are stale
    while min_heap and len(mst_edges) < n - 1:
        weight, u, from_node = heapq.heappop(min_heap)
        if in_mst[u]:
            continue  # Already reached through a cheaper edge
        in_mst[u] = True
        mst_weight += weight
        if from_node != -1:
            mst_edges.append((from_node, u, weight))
        for w, v in graph[u]:
            if not in_mst[v]:
                heapq.heappush(min_heap, (w, v, u))
    if len(mst_edges) != n - 1:
        return -1, []  # Disconnected graph
    return mst_weight, mst_edges
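The Kruskal example uses an edge list, while `prim` expects an adjacency list. A small conversion helper bridges the two; this is a sketch, assuming an undirected graph and the same (weight, u, v) edge format, and `build_adjacency` is a name chosen here for illustration.

```python
def build_adjacency(n: int, edges: list[tuple[int, int, int]]) -> list[list[tuple[int, int]]]:
    """Convert (weight, u, v) edges into graph[u] = [(weight, v), ...]."""
    graph = [[] for _ in range(n)]
    for weight, u, v in edges:
        graph[u].append((weight, v))
        graph[v].append((weight, u))  # Undirected: store both directions
    return graph


edges = [(1, 0, 1), (3, 0, 2), (4, 0, 3), (2, 1, 2), (5, 1, 3), (6, 2, 3)]
graph = build_adjacency(4, edges)
print(graph[0])  # [(1, 1), (3, 2), (4, 3)]
print(graph[3])  # [(4, 0), (5, 1), (6, 2)]
```

Passing this `graph` to `prim(4, graph)` should select edges 0-1, 1-2, and 0-3 for a total weight of 7, matching the Kruskal result on the same input.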
Kruskal’s vs Prim’s
| Property | Kruskal’s | Prim’s |
|---|---|---|
| Time | O(E log E) | O((V + E) log V) with heap |
| Best for | Sparse graphs (E ≈ V) | Dense graphs (E ≈ V²), especially the O(V²) array version |
| Approach | Edge-centric: sort all edges | Vertex-centric: grow from seed |
| Data structure | Union-Find | Min-heap (priority queue) |
| Disconnected graphs | Finds MSF (forest) naturally | Need to restart for each component |
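The last row of the table can be illustrated with a short sketch. Dropping the final connectivity check from Kruskal’s leaves a minimum spanning forest, one tree per component. The function name `kruskal_forest` and the extra component count in the return value are assumptions made for this example.

```python
def kruskal_forest(n: int, edges: list[tuple[int, int, int]]) -> tuple[int, list, int]:
    """
    Kruskal's algorithm without the connectivity check.
    edges: list of (weight, u, v)
    Returns: (total_weight, forest_edges, component_count)
    """
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # Shorten the path while walking up
            x = parent[x]
        return x

    total = 0
    forest = []
    components = n  # Every vertex starts as its own component
    for weight, u, v in sorted(edges):
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            total += weight
            forest.append((u, v, weight))
            components -= 1  # Each accepted edge merges two components
    return total, forest, components


# Two components: {0, 1, 2} and {3, 4}
edges = [(1, 0, 1), (2, 1, 2), (3, 0, 2), (5, 3, 4)]
print(kruskal_forest(5, edges))  # (8, [(0, 1, 1), (1, 2, 2), (3, 4, 5)], 2)
```

Prim’s would need an outer loop that restarts from any vertex not yet visited to produce the same forest.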
LeetCode Problems
- LeetCode 1584: Min Cost to Connect All Points (Prim’s or Kruskal’s on complete graph)
- LeetCode 1135: Connecting Cities With Minimum Cost (Kruskal’s)
- LeetCode 1168: Optimize Water Distribution (add virtual node for wells)
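The virtual-node trick for LeetCode 1168 can be sketched as follows. Building a well at house i is modeled as an edge from an extra node 0 to house i, after which the problem becomes a plain MST. The sketch assumes the input format as commonly stated for that problem (houses numbered 1..n, `wells[i - 1]` as the well cost for house i, `pipes` as `[house1, house2, cost]`); the function name is illustrative.

```python
def min_cost_to_supply_water(n: int, wells: list[int], pipes: list[list[int]]) -> int:
    """MST over houses 1..n plus virtual source node 0 (assumed input format)."""
    # Edge 0-i with cost wells[i - 1] means "build a well at house i"
    edges = [(cost, 0, house) for house, cost in enumerate(wells, start=1)]
    edges += [(cost, a, b) for a, b, cost in pipes]
    edges.sort()

    parent = list(range(n + 1))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total = 0
    used = 0
    for cost, a, b in edges:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[ra] = rb
            total += cost
            used += 1
            if used == n:  # n + 1 nodes are spanned by n edges
                break
    return total


print(min_cost_to_supply_water(3, [1, 2, 2], [[1, 2, 1], [2, 3, 1]]))  # 3
```

Because every house has an edge to node 0, the augmented graph is always connected, so no disconnected-graph check is needed.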
Min Cost to Connect All Points
def min_cost_connect_points(points: list[list[int]]) -> int:
    """
    LeetCode 1584: MST where edge weight = Manhattan distance.
    The graph is complete (E = V(V-1)/2), so array-based Prim's runs in O(V^2),
    while heap-based Prim's costs O(E log V) = O(V^2 log V).
    """
    n = len(points)
    in_mst = [False] * n
    min_dist = [float('inf')] * n  # Cheapest known edge from the tree to each vertex
    min_dist[0] = 0
    total = 0
    for _ in range(n):
        # Find vertex with minimum distance not yet in MST
        u = min((i for i in range(n) if not in_mst[i]), key=lambda i: min_dist[i])
        in_mst[u] = True
        total += min_dist[u]
        # Update distances to remaining vertices
        for v in range(n):
            if not in_mst[v]:
                dist = abs(points[u][0] - points[v][0]) + abs(points[u][1] - points[v][1])
                if dist < min_dist[v]:
                    min_dist[v] = dist
    return total
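For comparison, here is a heap-based version of the same idea. It is a sketch of lazy Prim’s on the implicit complete graph, not a recommended replacement: it does the same work with an extra log factor, and the name `min_cost_connect_points_heap` is introduced here only for illustration.

```python
import heapq


def min_cost_connect_points_heap(points: list[list[int]]) -> int:
    """Lazy Prim's on the implicit complete graph: O(V^2 log V) time."""
    n = len(points)
    if n == 0:
        return 0
    in_mst = [False] * n
    heap = [(0, 0)]  # (distance to the tree, vertex)
    total = 0
    added = 0
    while heap and added < n:
        dist, u = heapq.heappop(heap)
        if in_mst[u]:
            continue  # Stale entry: u was already reached by a cheaper edge
        in_mst[u] = True
        total += dist
        added += 1
        for v in range(n):
            if not in_mst[v]:
                d = abs(points[u][0] - points[v][0]) + abs(points[u][1] - points[v][1])
                heapq.heappush(heap, (d, v))
    return total


print(min_cost_connect_points_heap([[0, 0], [2, 2], [3, 10], [5, 2], [7, 0]]))  # 20
```

The heap can hold up to O(V²) entries here, which is why the array version is usually the better fit when every pair of points is an edge.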
Related Graph Topics
- Bellman-Ford: Shortest Path. Bellman-Ford finds the minimum distance from a source to every other node, while an MST minimizes the total edge weight needed to connect all nodes; the path between two nodes inside an MST is not necessarily the shortest one
- Detect a Cycle in a Graph. Kruskal’s uses Union-Find for cycle detection: an edge whose two endpoints are already in the same component would create a cycle, so it is skipped