Skip to main content

Segment Tree

Master the Segment Tree data structure for efficient range queries and updates.

Segment Tree Fundamentals

What is a Segment Tree?

A Segment Tree is a data structure that allows efficient range queries and updates on an array. It's particularly useful for problems involving range sum, range minimum/maximum, and range updates.

Basic Structure

// Sum segment tree over a fixed-size array, stored in the usual 1-indexed heap
// layout: node i covers some segment [start, end] and its children are 2i and 2i+1.
class SegmentTree {
private:
    vector<int> tree;  // tree[node] = sum of arr over node's segment
    vector<int> arr;   // copy of the input values
    int n;             // number of elements

public:
    // Builds the tree over `input`. 4 * n slots always suffice for the heap layout.
    // An empty input produces an empty tree (nothing is built).
    SegmentTree(const vector<int>& input)
        : tree(4 * input.size()), arr(input), n(static_cast<int>(input.size())) {
        if (n > 0) {
            build(1, 0, n - 1);
        }
    }

    // Fills tree[node] (and its subtree) for the segment [start, end].
    // An empty segment (start > end) is a no-op; without this guard the
    // recursion never reaches start == end and does not terminate.
    void build(int node, int start, int end) {
        if (start > end) {
            return;
        }
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        int mid = (start + end) / 2;
        build(2 * node, start, mid);
        build(2 * node + 1, mid + 1, end);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }
};

Range Queries

Range Sum Query

// Query sum in range [l, r]
int query(int node, int start, int end, int l, int r) {
if (r < start || end < l) {
return 0; // No overlap
}
if (l <= start && end <= r) {
return tree[node]; // Complete overlap
}

int mid = (start + end) / 2;
int leftSum = query(2 * node, start, mid, l, r);
int rightSum = query(2 * node + 1, mid + 1, end, l, r);
return leftSum + rightSum;
}

Range Minimum Query

// Query minimum in range [l, r]
int queryMin(int node, int start, int end, int l, int r) {
if (r < start || end < l) {
return INT_MAX; // No overlap
}
if (l <= start && end <= r) {
return tree[node]; // Complete overlap
}

int mid = (start + end) / 2;
int leftMin = queryMin(2 * node, start, mid, l, r);
int rightMin = queryMin(2 * node + 1, mid + 1, end, l, r);
return min(leftMin, rightMin);
}

Range Maximum Query

// Query maximum in range [l, r]
int queryMax(int node, int start, int end, int l, int r) {
if (r < start || end < l) {
return INT_MIN; // No overlap
}
if (l <= start && end <= r) {
return tree[node]; // Complete overlap
}

int mid = (start + end) / 2;
int leftMax = queryMax(2 * node, start, mid, l, r);
int rightMax = queryMax(2 * node + 1, mid + 1, end, l, r);
return max(leftMax, rightMax);
}

Range Updates

Point Update

// Update value at index idx
void update(int node, int start, int end, int idx, int val) {
if (start == end) {
arr[idx] = val;
tree[node] = val;
} else {
int mid = (start + end) / 2;
if (idx <= mid) {
update(2 * node, start, mid, idx, val);
} else {
update(2 * node + 1, mid + 1, end, idx, val);
}
tree[node] = tree[2 * node] + tree[2 * node + 1];
}
}

Range Update with Lazy Propagation

// Sum segment tree supporting "add val to every element in [l, r]" and range-sum
// queries, both O(log n) via lazy propagation. 1-indexed heap layout: node i's
// children are 2i and 2i+1.
class LazySegmentTree {
private:
    vector<int> tree;  // tree[node]: sum of node's segment, excluding lazy[node]
    vector<int> lazy;  // lazy[node]: per-element add owed to node's segment, not yet in tree[node]
    vector<int> arr;   // copy of the input values
    int n;             // number of elements

    // Folds node's pending add into tree[node] and defers it to the children.
    // Shared by updateRange and queryRange so both see an up-to-date tree[node].
    void push(int node, int start, int end) {
        if (lazy[node] == 0) {
            return;
        }
        tree[node] += (end - start + 1) * lazy[node];
        if (start != end) {
            lazy[2 * node] += lazy[node];
            lazy[2 * node + 1] += lazy[node];
        }
        lazy[node] = 0;
    }

public:
    // Builds the tree over `input` with no pending adds.
    // An empty input produces an empty tree.
    LazySegmentTree(const vector<int>& input)
        : tree(4 * input.size()),
          lazy(4 * input.size()),
          arr(input),
          n(static_cast<int>(input.size())) {
        if (n > 0) {
            build(1, 0, n - 1);
        }
    }

    // Fills tree[node] for [start, end]. An empty segment is a no-op; without
    // this guard the recursion would never terminate.
    void build(int node, int start, int end) {
        if (start > end) {
            return;
        }
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        int mid = (start + end) / 2;
        build(2 * node, start, mid);
        build(2 * node + 1, mid + 1, end);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // Adds `val` to every element in [l, r] within node's segment [start, end].
    void updateRange(int node, int start, int end, int l, int r, int val) {
        if (start > end) {
            return;  // empty tree / empty segment
        }
        push(node, start, end);

        if (r < start || end < l) {
            return;  // No overlap
        }

        if (l <= start && end <= r) {
            // Complete overlap: apply here, defer the children.
            tree[node] += (end - start + 1) * val;
            if (start != end) {
                lazy[2 * node] += val;
                lazy[2 * node + 1] += val;
            }
            return;
        }

        // Partial overlap: recurse, then rebuild from the (now settled) children.
        int mid = (start + end) / 2;
        updateRange(2 * node, start, mid, l, r, val);
        updateRange(2 * node + 1, mid + 1, end, l, r, val);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // Returns the sum of the elements in [l, r] within node's segment [start, end].
    int queryRange(int node, int start, int end, int l, int r) {
        if (start > end) {
            return 0;  // empty tree / empty segment
        }
        push(node, start, end);

        if (r < start || end < l) {
            return 0;  // No overlap
        }

        if (l <= start && end <= r) {
            return tree[node];  // Complete overlap
        }

        int mid = (start + end) / 2;
        int leftSum = queryRange(2 * node, start, mid, l, r);
        int rightSum = queryRange(2 * node + 1, mid + 1, end, l, r);
        return leftSum + rightSum;
    }
};

Lazy Propagation

Lazy Propagation for Range Updates

// Range update: add val to all elements in [l, r]
void lazyUpdate(int node, int start, int end, int l, int r, int val) {
// Propagate lazy value if exists
if (lazy[node] != 0) {
tree[node] += (end - start + 1) * lazy[node];
if (start != end) {
lazy[2 * node] += lazy[node];
lazy[2 * node + 1] += lazy[node];
}
lazy[node] = 0;
}

// No overlap
if (r < start || end < l) {
return;
}

// Complete overlap
if (l <= start && end <= r) {
tree[node] += (end - start + 1) * val;
if (start != end) {
lazy[2 * node] += val;
lazy[2 * node + 1] += val;
}
return;
}

// Partial overlap
int mid = (start + end) / 2;
lazyUpdate(2 * node, start, mid, l, r, val);
lazyUpdate(2 * node + 1, mid + 1, end, l, r, val);
tree[node] = tree[2 * node] + tree[2 * node + 1];
}

Lazy Propagation for Range Set

// Range set: set all elements in [l, r] to val
void lazySet(int node, int start, int end, int l, int r, int val) {
if (lazy[node] != -1) {
tree[node] = (end - start + 1) * lazy[node];
if (start != end) {
lazy[2 * node] = lazy[node];
lazy[2 * node + 1] = lazy[node];
}
lazy[node] = -1;
}

if (r < start || end < l) {
return;
}

if (l <= start && end <= r) {
tree[node] = (end - start + 1) * val;
if (start != end) {
lazy[2 * node] = val;
lazy[2 * node + 1] = val;
}
return;
}

int mid = (start + end) / 2;
lazySet(2 * node, start, mid, l, r, val);
lazySet(2 * node + 1, mid + 1, end, l, r, val);
tree[node] = tree[2 * node] + tree[2 * node + 1];
}

2D Segment Trees

2D Range Sum Query

// 2D range-sum segment tree ("segment tree of segment trees") over a
// rectangular n x m matrix. Every node of the row (X) tree owns a complete
// column (Y) tree. A rectangle query therefore visits O(log n) X-nodes and
// O(log m) Y-nodes inside each: O(log n * log m) overall.
class SegmentTree2D {
private:
    // tree[nodeX][nodeY] = sum over rows in X-node nodeX's segment and
    // columns in Y-node nodeY's segment. Uses 1-indexed heap layout in both axes.
    vector<vector<int>> tree;
    vector<vector<int>> arr;  // copy of the input matrix
    int n, m;                 // rows, columns (assumes every row has m entries)

    // Builds the Y tree belonging to X-node nodeX (rows [startX, endX]) over
    // columns [startY, endY]. Y-leaves of a leaf X-node come from arr; Y-leaves
    // of an internal X-node merge the same Y-leaf of its two X-children, which
    // must already be built.
    void buildY(int nodeX, int startX, int endX, int nodeY, int startY, int endY) {
        if (startY == endY) {
            tree[nodeX][nodeY] = (startX == endX)
                ? arr[startX][startY]
                : tree[2 * nodeX][nodeY] + tree[2 * nodeX + 1][nodeY];
            return;
        }
        int midY = (startY + endY) / 2;
        buildY(nodeX, startX, endX, 2 * nodeY, startY, midY);
        buildY(nodeX, startX, endX, 2 * nodeY + 1, midY + 1, endY);
        tree[nodeX][nodeY] = tree[nodeX][2 * nodeY] + tree[nodeX][2 * nodeY + 1];
    }

    // Sums columns [y1, y2] within the Y tree of X-node nodeX.
    int queryY(int nodeX, int nodeY, int startY, int endY, int y1, int y2) {
        if (y2 < startY || endY < y1) {
            return 0;
        }
        if (y1 <= startY && endY <= y2) {
            return tree[nodeX][nodeY];
        }
        int midY = (startY + endY) / 2;
        return queryY(nodeX, 2 * nodeY, startY, midY, y1, y2) +
               queryY(nodeX, 2 * nodeY + 1, midY + 1, endY, y1, y2);
    }

public:
    // Builds the tree over `input`. An empty matrix (no rows, or empty rows)
    // produces an empty tree whose queries all return 0.
    SegmentTree2D(const vector<vector<int>>& input)
        : arr(input),
          n(static_cast<int>(input.size())),
          m(input.empty() ? 0 : static_cast<int>(input[0].size())) {
        tree.assign(4 * n, vector<int>(4 * m));
        if (n > 0 && m > 0) {
            build(1, 0, n - 1, 1, 0, m - 1);
        }
    }

    // Builds the X tree below nodeX (rows [startX, endX]) children-first, then
    // this node's own Y tree over columns [startY, endY] rooted at nodeY.
    void build(int nodeX, int startX, int endX, int nodeY, int startY, int endY) {
        if (startX > endX || startY > endY) {
            return;  // empty matrix
        }
        if (startX != endX) {
            int midX = (startX + endX) / 2;
            build(2 * nodeX, startX, midX, nodeY, startY, endY);
            build(2 * nodeX + 1, midX + 1, endX, nodeY, startY, endY);
        }
        buildY(nodeX, startX, endX, nodeY, startY, endY);
    }

    // Sums the rectangle rows [x1, x2] x columns [y1, y2]. Call with the root
    // arguments: query(1, 0, n - 1, 1, 0, m - 1, x1, y1, x2, y2).
    int query(int nodeX, int startX, int endX, int nodeY, int startY, int endY,
              int x1, int y1, int x2, int y2) {
        if (startX > endX || startY > endY) {
            return 0;  // empty matrix
        }
        if (x2 < startX || endX < x1 || y2 < startY || endY < y1) {
            return 0;  // No overlap
        }
        if (x1 <= startX && endX <= x2) {
            // Every row of this X-node is inside the query: hand the column
            // range to its own Y tree instead of descending to single rows.
            return queryY(nodeX, nodeY, startY, endY, y1, y2);
        }
        int midX = (startX + endX) / 2;
        int leftSum = query(2 * nodeX, startX, midX, nodeY, startY, endY, x1, y1, x2, y2);
        int rightSum = query(2 * nodeX + 1, midX + 1, endX, nodeY, startY, endY, x1, y1, x2, y2);
        return leftSum + rightSum;
    }
};

Persistent Segment Trees

Persistent Segment Tree for Range Queries

#include <memory>

// One node of a persistent sum segment tree. Child pointers are non-owning:
// nodes are owned by the PersistentSegmentTree that allocated them, and a
// node may be shared by several versions.
struct PersistentNode {
    int value = 0;                   // sum of this node's segment
    PersistentNode* left = nullptr;  // covers [start, mid]
    PersistentNode* right = nullptr; // covers [mid + 1, end]

    PersistentNode(int val = 0) : value(val) {}
};

// Persistent sum segment tree: every point update creates a new version and
// copies only the O(log n) nodes on the updated root-to-leaf path; all other
// nodes are shared with the previous version. Version 0 is the initial array,
// and each update(version, ...) appends the next version index.
class PersistentSegmentTree {
private:
    vector<PersistentNode*> versions;         // versions[v] = root of version v
    vector<shared_ptr<PersistentNode>> pool;  // owns every node (shared so copies of the tree stay valid)
    int n;                                    // number of elements

    // Allocates a fresh node owned by this tree.
    PersistentNode* makeNode() {
        pool.push_back(make_shared<PersistentNode>());
        return pool.back().get();
    }

public:
    // Builds version 0 from `arr`. An empty array yields an empty version 0
    // (null root) whose queries return 0.
    PersistentSegmentTree(const vector<int>& arr) : n(static_cast<int>(arr.size())) {
        versions.push_back(build(0, n - 1, arr));
    }

    // Builds a fresh subtree for [start, end]; returns nullptr for an empty range.
    PersistentNode* build(int start, int end, const vector<int>& arr) {
        if (start > end) {
            return nullptr;
        }
        PersistentNode* node = makeNode();
        if (start == end) {
            node->value = arr[start];
        } else {
            int mid = (start + end) / 2;
            node->left = build(start, mid, arr);
            node->right = build(mid + 1, end, arr);
            node->value = node->left->value + node->right->value;
        }
        return node;
    }

    // Creates a new version equal to `version` except that element idx is val.
    // The new version is appended to the version list (its index is the number
    // of versions that existed before this call), and its root is returned.
    // idx is expected to lie in [0, n).
    PersistentNode* update(int version, int idx, int val) {
        PersistentNode* root = update(versions[version], 0, n - 1, idx, val);
        versions.push_back(root);
        return root;
    }

    // Path-copying update below `node` (covering [start, end]); untouched
    // subtrees are shared with the old version. Does not record a version.
    PersistentNode* update(PersistentNode* node, int start, int end, int idx, int val) {
        if (start > end) {
            return node;  // empty tree: nothing to update
        }
        PersistentNode* newNode = makeNode();
        if (start == end) {
            newNode->value = val;
            return newNode;
        }
        int mid = (start + end) / 2;
        if (idx <= mid) {
            newNode->left = update(node->left, start, mid, idx, val);
            newNode->right = node->right;
        } else {
            newNode->left = node->left;
            newNode->right = update(node->right, mid + 1, end, idx, val);
        }
        newNode->value = newNode->left->value + newNode->right->value;
        return newNode;
    }

    // Sum of elements [l, r] as they were in `version`.
    int query(int version, int l, int r) {
        return query(versions[version], 0, n - 1, l, r);
    }

    // Sum of elements [l, r] below `node` (covering [start, end]).
    int query(PersistentNode* node, int start, int end, int l, int r) {
        if (node == nullptr || r < start || end < l) {
            return 0;
        }
        if (l <= start && end <= r) {
            return node->value;
        }
        int mid = (start + end) / 2;
        int leftSum = query(node->left, start, mid, l, r);
        int rightSum = query(node->right, mid + 1, end, l, r);
        return leftSum + rightSum;
    }
};

Performance Analysis

Time Complexity

  • Build: O(n)
  • Range query: O(log n)
  • Point update: O(log n)
  • Range update: O(log n) with lazy propagation (O(n) if each element in the range is updated individually)
  • Space: O(n)

Space Complexity

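  • Basic Segment Tree: O(n) — an array of 4n entries
  • Lazy Segment Tree: O(n) — two arrays of 4n entries (node values and pending updates)
  • 2D Segment Tree: O(n·m) — a 4n × 4m table
  • Persistent Segment Tree: O(n) for the initial version plus O(log n) new nodes per update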

Common Patterns

  1. Range sum queries with point updates
  2. Range minimum/maximum queries
  3. Range updates with lazy propagation
  4. 2D range queries for matrix problems
  5. Persistent queries for historical data

Applications

  • Range queries: Sum, min, max in ranges
  • Range updates: Add, set values in ranges
  • 2D problems: Matrix range queries
  • Historical queries: Persistent segment trees
  • Competitive programming: Efficient range operations