Dynamic Programming Patterns: Tree DP, Digit DP, and Bitmask DP

Beyond Basic DP

Once you’ve mastered 1D/2D tabulation and memoization (Fibonacci, LCS, Knapsack, Coin Change), advanced interview problems require recognizing three specialist patterns: tree DP (state on tree nodes), digit DP (count integers satisfying a property), and bitmask DP (enumerate subsets efficiently). These appear at FAANG+ companies in harder coding rounds and competitive programming contests.

Tree DP

Tree DP computes a value for each subtree and combines the children's values at their parent. The key is choosing the right state: dp[node][state], where state captures exactly what must be propagated upward to the parent.


# Problem: Maximum Independent Set on a Tree
# Select the maximum subset of nodes such that no two adjacent nodes are selected
# dp[node][0] = max nodes if node is NOT selected
# dp[node][1] = max nodes if node IS selected

from functools import lru_cache

def max_independent_set(n: int, edges: list[tuple[int,int]]) -> int:
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    @lru_cache(None)
    def dp(node: int, parent: int) -> tuple[int, int]:
        # Returns (max_if_not_selected, max_if_selected)
        not_sel, sel = 0, 1  # base: leaf node
        for child in adj[node]:
            if child == parent:
                continue
            child_not, child_sel = dp(child, node)
            # If node is not selected: child can be either
            not_sel += max(child_not, child_sel)
            # If node is selected: child must NOT be selected
            sel += child_not
        return not_sel, sel

    not_sel, sel = dp(0, -1)
    return max(not_sel, sel)

# Time: O(N), Space: O(N)
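
A recursive DFS like the one above can exceed Python's default recursion limit on deep, path-like trees. The following is a minimal iterative sketch of the same recurrence, assuming nodes are labeled 0..n-1 and the edges form a valid tree. The name `max_independent_set_iterative` is illustrative and not part of the original code.

```python
def max_independent_set_iterative(n: int, edges: list[tuple[int, int]]) -> int:
    if n == 0:
        return 0
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Build a DFS preorder from node 0 with an explicit stack.
    # Reversing it visits every child before its parent.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for child in adj[node]:
            if child != parent[node]:
                parent[child] = node
                stack.append(child)

    not_sel = [0] * n  # best count in the subtree if node is NOT selected
    sel = [1] * n      # best count in the subtree if node IS selected
    for node in reversed(order):
        p = parent[node]
        if p != -1:
            # Same transitions as the recursive version, pushed into the parent.
            not_sel[p] += max(not_sel[node], sel[node])
            sel[p] += not_sel[node]
    return max(not_sel[0], sel[0])
```

For a path 0-1-2-3 this returns 2, and for a star with four leaves it returns 4, since all leaves can be selected once the center is left out.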

Rerooting Technique (All-Roots DP)

When you need the answer for every node as root (e.g., “sum of distances to all other nodes for each node”), computing DP once per root is O(N²). Rerooting computes the answer for root 0 in O(N), then adjusts for each other root in O(1) per node, achieving O(N) total.


# Problem: Sum of distances in tree (LeetCode 834)
# For each node, compute sum of distances to all other nodes

def sum_of_distances_in_tree(n: int, edges: list) -> list[int]:
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    subtree_size = [1] * n
    down_dist = [0] * n  # sum of distances to all nodes in subtree

    # Phase 1: bottom-up DFS to compute subtree sizes and down_dist
    def dfs1(node, parent):
        for child in adj[node]:
            if child == parent:
                continue
            dfs1(child, node)
            subtree_size[node] += subtree_size[child]
            # Distance from node to child subtree = sum of distances + size of subtree
            down_dist[node] += down_dist[child] + subtree_size[child]
    dfs1(0, -1)

    # Phase 2: top-down to compute answer for each node (rerooting)
    ans = [0] * n
    ans[0] = down_dist[0]

    def dfs2(node, parent):
        for child in adj[node]:
            if child == parent:
                continue
            # When we reroot from node to child:
            # Nodes in child's subtree get 1 closer, nodes outside get 1 farther
            ans[child] = ans[node] - subtree_size[child] + (n - subtree_size[child])
            dfs2(child, node)
    dfs2(0, -1)

    return ans
# Time: O(N), Space: O(N)
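
The O(N²) baseline mentioned above is worth keeping around as a cross-check on small inputs. A minimal sketch, assuming a connected tree with nodes 0..n-1, runs one BFS per node. The name `sum_of_distances_bruteforce` is hypothetical.

```python
from collections import deque

def sum_of_distances_bruteforce(n: int, edges: list[tuple[int, int]]) -> list[int]:
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    ans = []
    for root in range(n):
        # Plain BFS from this root; dist[x] is the edge count from root to x.
        dist = [-1] * n
        dist[root] = 0
        queue = deque([root])
        while queue:
            node = queue.popleft()
            for nxt in adj[node]:
                if dist[nxt] == -1:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        ans.append(sum(dist))  # assumes every node was reached (connected tree)
    return ans

example = sum_of_distances_bruteforce(6, [(0, 1), (0, 2), (2, 3), (2, 4), (2, 5)])
# example == [8, 12, 6, 10, 10, 10]
```

The rerooting formula can be checked against this output by hand: node 2 has a subtree of size 4 when the tree is rooted at 0, so ans[2] = ans[0] - 4 + (6 - 4) = 8 - 4 + 2 = 6.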

Digit DP

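Digit DP counts integers in a range [L, R] satisfying some property (no two adjacent digits the same, digits sum to K, no digit appears more than K times, etc.). Checking every number individually is too slow for large ranges, and the valid numbers are not contiguous, so there is no closed-form shortcut. The usual approach is to write a function count(N) for the range [0, N] or [1, N] and answer the query as count(R) - count(L - 1).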


# Problem: Count numbers from 1 to N with no two consecutive equal digits
# dp(pos, prev_digit, tight)
# pos: current digit position (0 = leftmost)
# prev_digit: digit placed at pos-1 (-1 = none)
# tight: whether current prefix equals N's prefix (constrains next digit)

def count_no_consecutive(n: int) -> int:
    if n < 1:
        return 0
    digits = [int(d) for d in str(n)]
    L = len(digits)

    from functools import lru_cache

    @lru_cache(None)
    def dp(pos: int, prev: int, tight: bool) -> int:
        # prev == -1 means the number has not started yet (only leading zeros so far)
        if pos == L:
            return 1 if prev != -1 else 0  # exclude the all-zero string (the number 0)

        limit = digits[pos] if tight else 9
        result = 0

        for d in range(0, limit + 1):
            if prev == -1 and d == 0:
                # Still a leading zero: this is how numbers shorter than L digits get counted
                result += dp(pos + 1, -1, tight and d == limit)
                continue
            if d == prev:
                continue  # skip: same as previous digit
            # tight propagates only if we chose the maximum allowed digit
            result += dp(pos + 1, d, tight and d == limit)

        return result

    # Counts every integer in [1, n], including those with fewer digits than n
    return dp(0, -1, True)

# General digit DP template:
# State: (position, accumulated_property, tight_constraint, [leading_zeros])
# tight: are all chosen digits so far equal to the corresponding digits of the upper bound?
# When tight=False: any digit 0-9 is valid for remaining positions
# When tight=True: the current digit cannot exceed digits[pos]

Bitmask DP

Bitmask DP uses an integer as a set representation (bit i = 1 means element i is in the set). This makes it cheap to store a subset as a DP index and to iterate over all 2^N subsets. It is applicable when N ≤ 20 or so (2^20 ≈ 1M subsets).


# Problem: Traveling Salesman Problem (TSP) — visit all cities exactly once
# State: dp[mask][i] = min cost to visit exactly the cities in mask, ending at city i
# mask = bitmask of visited cities (bit j = 1 means city j visited)

def tsp(cost: list[list[int]]) -> int:
    n = len(cost)
    INF = float("inf")
    full_mask = (1 << n) - 1  # all cities visited

    # dp[mask][i] = min cost to visit cities in mask, currently at city i
    dp = [[INF] * n for _ in range(1 << n)]
    dp[1][0] = 0  # start at city 0 (bitmask: 0b0001)

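    for mask in range(1 << n):
        for u in range(n):
            if not (mask >> u & 1):
                continue  # u not in current path
            if dp[mask][u] == INF:
                continue  # state unreachable from the start city
            for v in range(n):
                if mask >> v & 1:
                    continue  # v already visited
                new_mask = mask | (1 << v)
                dp[new_mask][v] = min(dp[new_mask][v], dp[mask][u] + cost[u][v])

    # Close the tour: return from the last city to city 0
    return min(dp[full_mask][u] + cost[u][0] for u in range(n))
# Time: O(2^N * N^2), Space: O(2^N * N)

# Enumerate all non-empty subsets of a mask:
mask = 0b1011  # example set {0, 1, 3}
sub = mask
while sub > 0:
    # process sub (a subset of mask)
    sub = (sub - 1) & mask  # next subset (in descending order)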

# Count bits set in mask:
bits = bin(mask).count("1")  # or use: mask.bit_count() in Python 3.10+

# Bitmask DP applications:
# - Assignment problems (assign N tasks to N workers: min cost matching)
# - Hamiltonian path/cycle (TSP variant)
# - Set cover problems (cover all elements with minimum number of subsets)
# - Broken profile DP (count tilings of a grid — next pattern)
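
The assignment problem from the list above makes a compact second example. This is a minimal sketch, assuming a square matrix where cost[w][t] is the cost of giving task t to worker w. The function name is illustrative. Workers are assigned in index order, so the number of set bits in mask tells us which worker is next.

```python
def min_cost_assignment(cost: list[list[int]]) -> int:
    n = len(cost)
    INF = float("inf")
    # dp[mask] = min cost to give the tasks in mask to workers 0..popcount(mask)-1
    dp = [INF] * (1 << n)
    dp[0] = 0
    for mask in range(1 << n):
        if dp[mask] == INF:
            continue
        worker = bin(mask).count("1")  # index of the next unassigned worker
        if worker == n:
            continue  # every worker already has a task
        for task in range(n):
            if mask >> task & 1:
                continue  # task already taken
            new_mask = mask | (1 << task)
            dp[new_mask] = min(dp[new_mask], dp[mask] + cost[worker][task])
    return dp[(1 << n) - 1]

example_cost = [
    [9, 2, 7, 8],
    [6, 4, 3, 7],
    [5, 8, 1, 8],
    [7, 6, 9, 4],
]
best = min_cost_assignment(example_cost)  # 13: tasks 1, 0, 2, 3 for workers 0..3
```

Because the worker index is implied by the mask, the state is one-dimensional and the running time is O(2^N × N) rather than the O(2^N × N²) of TSP.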

Profile DP (Grid Tiling)


# Problem: Count ways to tile a rows×cols grid (for example 4×N) with 1×2 dominoes
# State: dp[col][profile] where profile = bitmask of which rows in current column
# are "already filled" by a horizontal domino extending from the previous column

def count_tilings(rows: int, cols: int) -> int:
    from functools import lru_cache

    @lru_cache(None)
    def place(col: int, row: int, cur_profile: int, next_profile: int) -> int:
        if row == rows:
            # Column finished: accept only if nothing spills past the last column
            if col + 1 == cols:
                return 1 if next_profile == 0 else 0
            return place(col + 1, 0, next_profile, 0)
        if cur_profile >> row & 1:
            # This cell already filled (by horizontal domino from left)
            return place(col, row + 1, cur_profile, next_profile)

        result = 0
        # Option 1: place vertical domino (fills rows row and row+1 in this column)
        if row + 1 < rows and not (cur_profile >> (row + 1) & 1):
            result += place(col, row + 2, cur_profile, next_profile)
        # Option 2: place horizontal domino (fills this cell + same row in next column)
        result += place(col, row + 1, cur_profile, next_profile | (1 << row))

        return result

    return place(0, 0, 0, 0)
# Time: O(cols * 2^rows * rows), feasible for small rows (≤ 12)
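
The same idea can be written column by column, which keeps the DP table keyed by profile alone. The sketch below is one possible formulation rather than a canonical one, and the names `count_domino_tilings` and `next_profiles` are hypothetical. For each incoming profile it enumerates every way to complete the column and records the profile handed to the next column.

```python
from collections import defaultdict

def count_domino_tilings(rows: int, cols: int) -> int:
    def next_profiles(cur: int) -> list[int]:
        # Every way to fill a column whose pre-filled cells are given by `cur`.
        # Each entry is the profile passed on to the following column.
        results = []

        def fill(row: int, nxt: int) -> None:
            if row == rows:
                results.append(nxt)
            elif cur >> row & 1:
                fill(row + 1, nxt)  # cell already covered from the left
            else:
                fill(row + 1, nxt | (1 << row))  # horizontal domino into next column
                if row + 1 < rows and not (cur >> (row + 1) & 1):
                    fill(row + 2, nxt)  # vertical domino inside this column

        fill(0, 0)
        return results

    ways = {0: 1}  # before the first column nothing is pre-filled
    for _ in range(cols):
        new_ways = defaultdict(int)
        for cur, count in ways.items():
            for nxt in next_profiles(cur):
                new_ways[nxt] += count
        ways = new_ways
    # A valid tiling leaves nothing sticking out past the last column.
    return ways.get(0, 0)
```

For a 2×N grid the counts follow the Fibonacci numbers (2×3 gives 3), and for 4×N the first values are 1, 5, 11, 36.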

Key Patterns and Recognition

  • Tree DP triggers: “for each node, compute…” on a tree; diameter, maximum path, independent set, subtree queries. If asked for all roots, consider rerooting.
  • Digit DP triggers: “count integers in [1..N] with property P” where P depends on the digits. Template: memoize on (position, accumulated_state, tight_flag).
  • Bitmask DP triggers: N ≤ 20 items, need to track subsets exactly, assignment/permutation problems, “visit all” problems. If 2^N fits in memory (~1M for N=20), bitmask DP is viable.
  • Profile DP triggers: grid tiling, “fill a grid row by row” where state is the boundary between processed and unprocessed rows.

Frequently Asked Questions

What is tree DP and when should you use the rerooting technique?

Tree DP computes a dynamic programming state for each node in a tree, typically by doing a depth-first traversal and combining children's results. The state captures what information must propagate upward to the parent. A common pattern: dp[node][0] = answer if node is excluded; dp[node][1] = answer if node is included. After the bottom-up DFS, dp[root] gives the global answer.

The rerooting technique (also called tree DP with rerooting or all-roots DP) computes the answer for every node as if it were the root, in O(N) total time instead of O(N²). Use it when the problem asks "for each node U, compute [some function of the entire tree when rooted at U]." Examples: sum of distances from each node to all others, maximum path through each node, "re-root to maximize some tree property."

The two-pass algorithm: (1) Down-pass (bottom-up): root the tree at node 0 and compute dp[node] for each node based on its subtree. (2) Up-pass (top-down): starting from node 0, push the "contribution from above" down to each child. For a child C of node P, the answer for C as root = the answer for P as root, adjusted by removing C's subtree contribution and adding the contribution of the rest of the tree.

The sum of distances problem (LeetCode 834) is the canonical example: after computing down_dist[0] (sum of distances when rooted at 0), rerooting shows that ans[child] = ans[parent] - subtree_size[child] + (N - subtree_size[child]), a simple O(1) adjustment per edge.

How does digit DP work and what problems can it solve?

Digit DP counts integers in a range [0, N] (or [L, R], computed as count(R) - count(L - 1)) satisfying some property that depends on the digits of the number. The key insight: rather than checking each number individually (too slow for large N), we build numbers digit by digit and track accumulated state.

State: dp(position, accumulated_state, tight), where position = current digit being placed (0 = most significant), accumulated_state = whatever property we're tracking (digit sum, frequency of each digit, whether any digit has appeared twice, etc.), and tight = whether all previous digits equal N's corresponding digits. If tight=False, remaining digits can be 0-9 freely; if tight=True, the current digit cannot exceed N's digit at this position. When leading zeros would affect the property, add a flag recording whether the number has started.

Transition: at each position, try all valid digits d (0 to limit, where limit=9 if not tight, else N's digit at this position). For each choice, update accumulated_state and propagate tight. Base case: when position = len(digits), return 1 if the accumulated_state represents a valid number, otherwise 0.

Examples: count numbers with no two consecutive equal digits (state = last digit placed); count numbers with digit sum divisible by K (state = digit_sum % K); count "stepping numbers" where adjacent digits differ by exactly 1 (state = last digit); count numbers with all distinct digits (state = 10-bit mask of used digits). The tight flag is what makes digit DP correct: it ensures we only count numbers ≤ N by constraining the choice at each position when all previous choices matched N exactly.
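
As an illustration of the distinct-digits case, here is a small sketch that counts integers in [1, n] with no repeated digit. The function name and the `started` flag are assumptions added for this example. The flag is needed because a leading zero must not be recorded as a used digit.

```python
from functools import lru_cache

def count_distinct_digit_numbers(n: int) -> int:
    """Count integers in [1, n] whose decimal digits are all different."""
    if n < 1:
        return 0
    digits = [int(d) for d in str(n)]

    @lru_cache(None)
    def dp(pos: int, used: int, tight: bool, started: bool) -> int:
        if pos == len(digits):
            return 1 if started else 0  # the all-zero string is the number 0
        limit = digits[pos] if tight else 9
        total = 0
        for d in range(limit + 1):
            next_tight = tight and d == limit
            if not started and d == 0:
                # Leading zero: the number has not begun, nothing is marked as used
                total += dp(pos + 1, used, next_tight, False)
            elif not (used >> d & 1):
                total += dp(pos + 1, used | (1 << d), next_tight, True)
        return total

    return dp(0, 0, True, False)
```

For n = 20 the result is 19, since 11 is the only number in range with a repeated digit.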

How does bitmask DP work and what is its time complexity?

Bitmask DP represents a set of N elements as an integer where bit i = 1 if element i is in the set. This allows iterating over all 2^N subsets of N elements, with set operations in O(1): union = mask_a | mask_b, intersection = mask_a & mask_b, add element i = mask | (1 << i), remove element i = mask & ~(1 << i), check element i = (mask >> i) & 1.

The DP state is dp[mask][…] where mask encodes which elements have been used/processed so far. Transition: dp[new_mask] = f(dp[mask], element i) where new_mask = mask | (1 << i).

Time complexity: O(2^N × N) for most bitmask DP problems, because there are 2^N possible masks and for each mask you try adding each of N elements. Space: O(2^N × state_size). Feasibility: 2^20 = ~1M states (fine), 2^25 = ~33M (borderline), 2^30 = ~1B (too slow). N ≤ 20 is the practical limit for bitmask DP in interview settings (N ≤ 15 is more comfortable).

Traveling Salesman Problem (TSP): dp[mask][city] = minimum cost to visit exactly the cities in mask ending at city. Classic O(2^N × N²) bitmask DP. Assignment problem: dp[mask] = minimum cost to assign the first popcount(mask) workers to the tasks represented by mask. O(2^N × N).

Subset enumeration trick: to iterate over all subsets of a given mask, use: sub = mask; while sub > 0: process(sub); sub = (sub - 1) & mask. This visits all non-empty subsets of mask in O(3^N) total time across all masks (each element is either not in mask, in mask but not sub, or in both).
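
The set operations and the subset enumeration trick can be tried out directly. The snippet below is a small illustrative sketch; the helper name `submasks` and the sample values are made up for the example.

```python
def submasks(mask: int) -> list[int]:
    """All non-empty submasks of mask, in descending order."""
    result = []
    sub = mask
    while sub > 0:
        result.append(sub)
        sub = (sub - 1) & mask  # drop to the next smaller submask
    return result

a, b = 0b0110, 0b0011
union = a | b                 # 0b0111
intersection = a & b          # 0b0010
with_bit3 = a | (1 << 3)      # 0b1110
without_bit1 = a & ~(1 << 1)  # 0b0100
has_bit2 = bool(a >> 2 & 1)   # True

subs = submasks(0b1011)  # [11, 10, 9, 8, 3, 2, 1]

# Total submasks visited over every mask of n bits: 3^n - 2^n (non-empty ones only)
n = 4
total_visited = sum(len(submasks(m)) for m in range(1 << n))  # 65 == 3**4 - 2**4
```

The total of 65 for n = 4 matches the 3^N bound: each element is outside the mask, in the mask only, or in both mask and submask, and the 2^N empty submasks are skipped by the loop condition.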
