Segment Tree

interview_workbook/data_structures /app/src/interview_workbook/data_structures/segment_tree.py
View Source

Algorithm Notes

Summary: Balanced tree over ranges for queries/updates.
Build: O(n), Query/Update: O(log n)
Space: O(n)
Use: Range sum/min/max, lazy propagation for range updates.

Big-O Guide

Source

class SegmentTree:
    """
    Segment Tree with lazy propagation for range add and range sum queries.

    Supports:
    - range_add(l, r, delta): add delta to every element in [l, r]
    - range_sum(l, r): sum of elements in [l, r]
    - point_set(i, val): set arr[i] = val
    - point_add(i, delta): arr[i] += delta

    Indexing: 0-indexed externally, inclusive ranges.

    Time:
    - Build: O(n)
    - Query: O(log n)
    - Update: O(log n)
    Space: O(n)

    Interview follow-ups:
    - How to support range assign (set all in [l,r] to x)? (Use separate lazy tag for assign overriding add)
    - How to support min/max queries? (Change combine func; lazy add works for min/max as well)
    - Iterative implementation vs recursive? (Iterative is cache-friendly but trickier with lazy)
    - Difference from Fenwick Tree? (Fenwick simpler; segment tree supports more general ops)
    """

    def __init__(self, data: list[int]):
        self.n = len(data)
        self.size = 1
        while self.size < self.n:
            self.size <<= 1
        self.tree = [0] * (2 * self.size)
        self.lazy = [0] * (2 * self.size)
        # Build leaves
        for i in range(self.n):
            self.tree[self.size + i] = data[i]
        # Build internal nodes
        for i in range(self.size - 1, 0, -1):
            self.tree[i] = self.tree[i << 1] + self.tree[i << 1 | 1]

    def _apply(self, idx: int, add: int, length: int) -> None:
        """Apply lazy addition to node idx representing 'length' elements."""
        self.tree[idx] += add * length
        if idx < self.size:  # has children
            self.lazy[idx] += add

    def _push(self, idx: int, left_len: int, right_len: int) -> None:
        """Push lazy value at idx to its children."""
        if self.lazy[idx] != 0:
            add = self.lazy[idx]
            self._apply(idx << 1, add, left_len)
            self._apply(idx << 1 | 1, add, right_len)
            self.lazy[idx] = 0

    def _range_add(self, idx: int, left: int, right: int, ql: int, qr: int, add: int) -> None:
        if ql > right or qr < left:
            return
        if ql <= left and right <= qr:
            self._apply(idx, add, right - left + 1)
            return
        mid = (left + right) // 2
        left_len = mid - left + 1
        right_len = right - mid
        self._push(idx, left_len, right_len)
        self._range_add(idx << 1, left, mid, ql, qr, add)
        self._range_add(idx << 1 | 1, mid + 1, right, ql, qr, add)
        self.tree[idx] = self.tree[idx << 1] + self.tree[idx << 1 | 1]

    def range_add(self, ql: int, qr: int, add: int) -> None:
        """Add 'add' to all elements in [ql, qr]."""
        if self.n == 0 or ql > qr:
            return
        ql = max(0, ql)
        qr = min(self.n - 1, qr)
        self._range_add(1, 0, self.size - 1, ql, qr, add)

    def _range_sum(self, idx: int, left: int, right: int, ql: int, qr: int) -> int:
        if ql > right or qr < left:
            return 0
        if ql <= left and right <= qr:
            return self.tree[idx]
        mid = (left + right) // 2
        left_len = mid - left + 1
        right_len = right - mid
        self._push(idx, left_len, right_len)
        return self._range_sum(idx << 1, left, mid, ql, qr) + self._range_sum(
            idx << 1 | 1, mid + 1, right, ql, qr
        )

    def range_sum(self, ql: int, qr: int) -> int:
        """Return sum of elements in [ql, qr]."""
        if self.n == 0 or ql > qr:
            return 0
        ql = max(0, ql)
        qr = min(self.n - 1, qr)
        return self._range_sum(1, 0, self.size - 1, ql, qr)

    def _point_set(self, idx: int, left: int, right: int, pos: int, val: int) -> None:
        if left == right:
            self.tree[idx] = val
            self.lazy[idx] = 0
            return
        mid = (left + right) // 2
        left_len = mid - left + 1
        right_len = right - mid
        self._push(idx, left_len, right_len)
        if pos <= mid:
            self._point_set(idx << 1, left, mid, pos, val)
        else:
            self._point_set(idx << 1 | 1, mid + 1, right, pos, val)
        self.tree[idx] = self.tree[idx << 1] + self.tree[idx << 1 | 1]

    def point_set(self, pos: int, val: int) -> None:
        """Set arr[pos] = val."""
        if pos < 0 or pos >= self.n:
            return
        self._point_set(1, 0, self.size - 1, pos, val)

    def point_add(self, pos: int, delta: int) -> None:
        """Add delta to arr[pos]."""
        self.range_add(pos, pos, delta)


class RangeMinSegmentTree:
    """
    Segment Tree for range minimum query (RMQ) with point updates.

    Supports:
    - range_min(l, r): min over [l, r]
    - point_set(i, val): set arr[i] = val

    Time:
    - Build: O(n)
    - Query/Update: O(log n)
    """

    def __init__(self, data: list[int]):
        self.n = len(data)
        self.size = 1
        while self.size < self.n:
            self.size <<= 1
        INF = 10**18
        self.INF = INF
        self.tree = [INF] * (2 * self.size)
        # Build leaves
        for i in range(self.n):
            self.tree[self.size + i] = data[i]
        # Build internal nodes
        for i in range(self.size - 1, 0, -1):
            self.tree[i] = min(self.tree[i << 1], self.tree[i << 1 | 1])

    def range_min(self, ql: int, qr: int) -> int:
        if self.n == 0 or ql > qr:
            return self.INF
        ql = max(0, ql) + self.size
        qr = min(self.n - 1, qr) + self.size
        res = self.INF
        while ql <= qr:
            if ql & 1:
                res = min(res, self.tree[ql])
                ql += 1
            if not (qr & 1):
                res = min(res, self.tree[qr])
                qr -= 1
            ql >>= 1
            qr >>= 1
        return res

    def point_set(self, pos: int, val: int) -> None:
        if pos < 0 or pos >= self.n:
            return
        idx = self.size + pos
        self.tree[idx] = val
        idx >>= 1
        while idx:
            self.tree[idx] = min(self.tree[idx << 1], self.tree[idx << 1 | 1])
            idx >>= 1


def demo():
    print("Segment Tree Demo")
    print("=" * 40)

    arr = [2, 1, 3, 4, 5, 2, 3, 1]
    print(f"Array: {arr}")

    # Range Sum with Lazy Add
    print("\nRange Sum with Lazy Propagation:")
    st = SegmentTree(arr)
    print(f"range_sum(0, 7): {st.range_sum(0, 7)}  (expected {sum(arr)})")
    print("range_add(2, 5, +3)")
    st.range_add(2, 5, 3)
    arr2 = arr[:]
    for i in range(2, 6):
        arr2[i] += 3
    print(f"range_sum(2, 5): {st.range_sum(2, 5)}  (expected {sum(arr2[2:6])})")
    print("point_set(0, 10)")
    st.point_set(0, 10)
    arr2[0] = 10
    print(f"range_sum(0, 3): {st.range_sum(0, 3)}  (expected {sum(arr2[0:4])})")

    # Range Min
    print("\nRange Minimum Query (RMQ):")
    rmq = RangeMinSegmentTree(arr)
    print(f"range_min(0, 7): {rmq.range_min(0, 7)} (expected {min(arr)})")
    print("point_set(3, 0)")
    rmq.point_set(3, 0)
    print(f"range_min(2, 4): {rmq.range_min(2, 4)} (expected {min([arr[2], 0, arr[4]])})")

    print("\nComplexity:")
    print("  Build: O(n)")
    print("  Range Query: O(log n)")
    print("  Range Update (with lazy): O(log n)")

    print("\nInterview tips:")
    print("  - Segment Trees are more general than Fenwick Trees")
    print("  - Use lazy propagation for range updates")
    print(
        "  - For range assign + range sum, maintain both 'assign' and 'add' lazy tags with priority"
    )
    print(
        "  - Iterative segment trees are compact and fast, but recursion is easier to reason about"
    )


if __name__ == "__main__":
    demo()