Skip to main content

Segment Tree

Range queries, range updates, lazy propagation, and advanced segment tree techniques

Master the Segment Tree data structure for efficient range queries and updates.

Segment Tree Fundamentals

What is a Segment Tree?

A Segment Tree is a data structure that allows efficient range queries and updates on an array. It's particularly useful for problems involving range sum, range minimum/maximum, and range updates.

Basic Structure

// Sum segment tree over a fixed array, stored in 1-based heap layout:
// node k covers some segment [start, end]; its children are 2k and 2k+1.
class SegmentTree {
private:
    vector<int> tree;  // tree[node] = sum of arr over the node's segment
    vector<int> arr;   // copy of the input values
    int n;             // number of input elements

public:
    // Copies `input` and builds the tree over indices [0, n - 1].
    // An empty input produces an empty tree; no build is attempted, because
    // build(1, 0, -1) would index past the zero-sized tree.
    SegmentTree(const vector<int>& input) {
        arr = input;
        n = static_cast<int>(input.size());
        tree.resize(4 * n);  // 4n slots always suffice for a recursive tree over n leaves
        if (n > 0) {
            build(1, 0, n - 1);
        }
    }

    // Fills tree[node] with the sum of arr[start..end] by building both halves first.
    void build(int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
        } else {
            int mid = (start + end) / 2;
            build(2 * node, start, mid);
            build(2 * node + 1, mid + 1, end);
            tree[node] = tree[2 * node] + tree[2 * node + 1];
        }
    }
};

Range Queries

Range Sum Query

// Sum of the elements at indices [l, r] (inclusive), restricted to the segment
// [start, end] covered by `node`. Top-level call: query(1, 0, n - 1, l, r).
// A segment disjoint from [l, r] contributes 0, the identity for +.
int query(int node, int start, int end, int l, int r) {
    // Segment lies entirely outside the requested range.
    if (end < l || start > r) {
        return 0;
    }
    // Segment lies entirely inside: its stored sum is this piece of the answer.
    if (start >= l && end <= r) {
        return tree[node];
    }
    // Segment straddles a boundary: combine the answers from both halves.
    const int half = (start + end) / 2;
    const int fromLeft = query(2 * node, start, half, l, r);
    const int fromRight = query(2 * node + 1, half + 1, end, l, r);
    return fromLeft + fromRight;
}

Range Minimum Query

// Minimum of the elements at indices [l, r] within node's segment [start, end].
// Assumes tree[] stores per-segment minimums (built with min as the combine step,
// not +). Disjoint segments yield INT_MAX, the identity for min.
int queryMin(int node, int start, int end, int l, int r) {
    const bool outside = end < l || start > r;
    const bool inside = start >= l && end <= r;
    if (outside) {
        return INT_MAX;
    }
    if (inside) {
        return tree[node];
    }
    const int half = (start + end) / 2;
    return min(queryMin(2 * node, start, half, l, r),
               queryMin(2 * node + 1, half + 1, end, l, r));
}

Range Maximum Query

// Maximum of the elements at indices [l, r] within node's segment [start, end].
// Assumes tree[] stores per-segment maximums (built with max as the combine step,
// not +). Disjoint segments yield INT_MIN, the identity for max.
int queryMax(int node, int start, int end, int l, int r) {
    const bool outside = end < l || start > r;
    const bool inside = start >= l && end <= r;
    if (outside) {
        return INT_MIN;
    }
    if (inside) {
        return tree[node];
    }
    const int half = (start + end) / 2;
    return max(queryMax(2 * node, start, half, l, r),
               queryMax(2 * node + 1, half + 1, end, l, r));
}

Range Updates

Point Update

// Point assignment: sets element idx to val, then recomputes the sums on the
// path from the leaf back up to `node`. Top-level call: update(1, 0, n - 1, idx, val).
void update(int node, int start, int end, int idx, int val) {
    if (start != end) {
        const int half = (start + end) / 2;
        const bool goLeft = idx <= half;
        const int child = goLeft ? 2 * node : 2 * node + 1;
        const int childStart = goLeft ? start : half + 1;
        const int childEnd = goLeft ? half : end;
        update(child, childStart, childEnd, idx, val);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
        return;
    }
    // Leaf: store the new value in both the raw array and the tree.
    arr[idx] = val;
    tree[node] = val;
}

Range Update with Lazy Propagation

// Sum segment tree supporting "add val to every element of [l, r]" and
// "sum of [l, r]" using lazy propagation (1-based heap layout, children 2k / 2k+1).
class LazySegmentTree {
private:
    vector<int> tree;  // tree[node]: sum of node's segment, excluding any add still pending in lazy[node]
    vector<int> lazy;  // lazy[node]: add owed to every element of node's segment, not yet in tree[node]
    vector<int> arr;   // copy of the original input (range adds do not modify it)
    int n;             // number of elements

    // Folds node's pending add into tree[node] and defers it to both children.
    // NOTE(review): (end - start + 1) * lazy[node] is int arithmetic and can
    // overflow for large segments or values.
    void push(int node, int start, int end) {
        if (lazy[node] == 0) {
            return;
        }
        tree[node] += (end - start + 1) * lazy[node];
        if (start != end) {
            lazy[2 * node] += lazy[node];
            lazy[2 * node + 1] += lazy[node];
        }
        lazy[node] = 0;
    }

public:
    // Copies `input` and builds the tree. An empty input yields an empty tree
    // (the [l, r] convenience overloads then act as no-ops / return 0).
    LazySegmentTree(const vector<int>& input) {
        arr = input;
        n = static_cast<int>(input.size());
        tree.resize(4 * n);
        lazy.resize(4 * n);
        if (n > 0) {
            build(1, 0, n - 1);
        }
    }

    // Fills tree[node] with the sum of arr[start..end].
    void build(int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
        } else {
            int mid = (start + end) / 2;
            build(2 * node, start, mid);
            build(2 * node + 1, mid + 1, end);
            tree[node] = tree[2 * node] + tree[2 * node + 1];
        }
    }

    // Adds val to every element in [l, r] (0-based, inclusive).
    void updateRange(int l, int r, int val) {
        if (n == 0) {
            return;
        }
        updateRange(1, 0, n - 1, l, r, val);
    }

    // Node-level range add over node's segment [start, end].
    void updateRange(int node, int start, int end, int l, int r, int val) {
        push(node, start, end);  // settle older adds before touching this node

        if (r < start || end < l) {
            return;  // No overlap
        }

        if (l <= start && end <= r) {
            // Complete overlap: apply here, defer the rest to the children.
            tree[node] += (end - start + 1) * val;
            if (start != end) {
                lazy[2 * node] += val;
                lazy[2 * node + 1] += val;
            }
            return;
        }

        int mid = (start + end) / 2;
        updateRange(2 * node, start, mid, l, r, val);
        updateRange(2 * node + 1, mid + 1, end, l, r, val);
        // Both children were pushed during recursion, so their sums are current.
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // Sum of the elements in [l, r] (0-based, inclusive); 0 for an empty tree.
    int queryRange(int l, int r) {
        if (n == 0) {
            return 0;
        }
        return queryRange(1, 0, n - 1, l, r);
    }

    // Node-level range sum over node's segment [start, end].
    int queryRange(int node, int start, int end, int l, int r) {
        push(node, start, end);

        if (r < start || end < l) {
            return 0;  // No overlap
        }

        if (l <= start && end <= r) {
            return tree[node];  // Complete overlap
        }

        int mid = (start + end) / 2;
        int leftSum = queryRange(2 * node, start, mid, l, r);
        int rightSum = queryRange(2 * node + 1, mid + 1, end, l, r);
        return leftSum + rightSum;
    }
};

Lazy Propagation

Lazy Propagation for Range Updates

// Range add: increase every element in [l, r] by val.
// lazy[node] is an increment owed to every element of node's segment that has not
// yet been folded into tree[node]. Top-level call: lazyUpdate(1, 0, n - 1, l, r, val).
void lazyUpdate(int node, int start, int end, int l, int r, int val) {
    const int width = end - start + 1;
    const bool isLeaf = (start == end);
    const int leftChild = 2 * node;
    const int rightChild = 2 * node + 1;

    // Settle any increment deferred to this node by an earlier call.
    if (lazy[node] != 0) {
        tree[node] += width * lazy[node];
        if (!isLeaf) {
            lazy[leftChild] += lazy[node];
            lazy[rightChild] += lazy[node];
        }
        lazy[node] = 0;
    }

    if (end < l || start > r) {
        return;  // nothing of this segment is in [l, r]
    }

    if (start >= l && end <= r) {
        // Whole segment is covered: update here, defer the children.
        tree[node] += width * val;
        if (!isLeaf) {
            lazy[leftChild] += val;
            lazy[rightChild] += val;
        }
        return;
    }

    // Partial cover: recurse, then rebuild this node from its settled children.
    const int mid = (start + end) / 2;
    lazyUpdate(leftChild, start, mid, l, r, val);
    lazyUpdate(rightChild, mid + 1, end, l, r, val);
    tree[node] = tree[leftChild] + tree[rightChild];
}

Lazy Propagation for Range Set

// Range assignment: set every element in [l, r] to val (sum tree).
// lazy[node] == -1 means "no pending assignment"; any other value v means every
// element of node's segment must become v, but the children have not been told yet.
// NOTE(review): lazy[] must therefore start out filled with -1. Because -1 is the
// sentinel, assigning val == -1 updates the covering nodes but never reaches
// their children.
void lazySet(int node, int start, int end, int l, int r, int val) {
    const int width = end - start + 1;
    const bool isLeaf = (start == end);
    const int leftChild = 2 * node;
    const int rightChild = 2 * node + 1;

    // Apply an assignment deferred to this node by an earlier call.
    if (lazy[node] != -1) {
        tree[node] = width * lazy[node];
        if (!isLeaf) {
            lazy[leftChild] = lazy[node];
            lazy[rightChild] = lazy[node];
        }
        lazy[node] = -1;
    }

    if (end < l || start > r) {
        return;  // disjoint from [l, r]
    }

    if (start >= l && end <= r) {
        // Fully covered: overwrite here and defer the assignment to the children.
        tree[node] = width * val;
        if (!isLeaf) {
            lazy[leftChild] = val;
            lazy[rightChild] = val;
        }
        return;
    }

    const int mid = (start + end) / 2;
    lazySet(leftChild, start, mid, l, r, val);
    lazySet(rightChild, mid + 1, end, l, r, val);
    tree[node] = tree[leftChild] + tree[rightChild];
}

2D Segment Trees

2D Range Sum Query

// 2D sum segment tree ("segment tree of segment trees") over an n x m grid.
// Row segments form the outer tree (nodeX); for every row node there is a full
// column tree (nodeY). tree[nodeX][nodeY] is the sum of the rectangle
// rows(nodeX) x cols(nodeY), for every pair of nodes. A rectangle query
// therefore costs O(log n * log m).
// Assumes a rectangular grid: every row has input[0].size() columns.
class SegmentTree2D {
private:
    vector<vector<int>> tree;
    vector<vector<int>> arr;
    int n, m;

    // Fills the column tree of row node nodeX under column node nodeY.
    // Leaf rows read the grid; internal rows add the two child rows, which
    // build() has already filled for the same column nodes.
    void buildY(int nodeX, int startX, int endX, int nodeY, int startY, int endY) {
        if (startY != endY) {
            int midY = (startY + endY) / 2;
            buildY(nodeX, startX, endX, 2 * nodeY, startY, midY);
            buildY(nodeX, startX, endX, 2 * nodeY + 1, midY + 1, endY);
            tree[nodeX][nodeY] = tree[nodeX][2 * nodeY] + tree[nodeX][2 * nodeY + 1];
        } else if (startX != endX) {
            tree[nodeX][nodeY] = tree[2 * nodeX][nodeY] + tree[2 * nodeX + 1][nodeY];
        } else {
            tree[nodeX][nodeY] = arr[startX][startY];
        }
    }

    // Sums columns [y1, y2] inside the column tree of row node nodeX.
    int queryY(int nodeX, int nodeY, int startY, int endY, int y1, int y2) {
        if (y2 < startY || endY < y1) {
            return 0;
        }
        if (y1 <= startY && endY <= y2) {
            return tree[nodeX][nodeY];
        }
        int midY = (startY + endY) / 2;
        return queryY(nodeX, 2 * nodeY, startY, midY, y1, y2) +
               queryY(nodeX, 2 * nodeY + 1, midY + 1, endY, y1, y2);
    }

public:
    // Copies the grid and builds all (row node, column node) sums.
    // An empty grid (no rows or no columns) is accepted; queries then return 0.
    SegmentTree2D(const vector<vector<int>>& input) {
        arr = input;
        n = static_cast<int>(input.size());
        m = input.empty() ? 0 : static_cast<int>(input[0].size());
        if (n == 0 || m == 0) {
            return;
        }
        tree.resize(4 * n, vector<int>(4 * m));
        build(1, 0, n - 1, 1, 0, m - 1);
    }

    // Builds the row subtree rooted at nodeX, and for each of its row nodes the
    // column subtree rooted at nodeY. Child rows are built first so internal
    // rows can be formed by addition.
    void build(int nodeX, int startX, int endX, int nodeY, int startY, int endY) {
        if (startX != endX) {
            int midX = (startX + endX) / 2;
            build(2 * nodeX, startX, midX, nodeY, startY, endY);
            build(2 * nodeX + 1, midX + 1, endX, nodeY, startY, endY);
        }
        buildY(nodeX, startX, endX, nodeY, startY, endY);
    }

    // Sum of the rectangle rows [x1, x2] x columns [y1, y2] (inclusive).
    int query(int x1, int y1, int x2, int y2) {
        if (n == 0 || m == 0) {
            return 0;
        }
        return query(1, 0, n - 1, 1, 0, m - 1, x1, y1, x2, y2);
    }

    // Node-level rectangle sum: descends the row tree and, at each row node fully
    // inside [x1, x2], answers the column range from that node's column tree.
    int query(int nodeX, int startX, int endX, int nodeY, int startY, int endY,
              int x1, int y1, int x2, int y2) {
        if (x2 < startX || endX < x1 || y2 < startY || endY < y1) {
            return 0;  // No overlap
        }

        if (x1 <= startX && endX <= x2) {
            return queryY(nodeX, nodeY, startY, endY, y1, y2);  // rows fully covered
        }

        int midX = (startX + endX) / 2;
        int leftSum = query(2 * nodeX, startX, midX, nodeY, startY, endY, x1, y1, x2, y2);
        int rightSum = query(2 * nodeX + 1, midX + 1, endX, nodeY, startY, endY, x1, y1, x2, y2);
        return leftSum + rightSum;
    }
};

Persistent Segment Trees

Persistent Segment Tree for Range Queries

// Node of a persistent sum segment tree. A node's fields are written only when
// it is created, so later versions can safely share it.
struct PersistentNode {
    int value;
    PersistentNode* left;
    PersistentNode* right;

    PersistentNode(int val = 0) : value(val), left(nullptr), right(nullptr) {}
};

// Persistent sum segment tree. Version 0 is the initial array. Each
// update(version, idx, val) path-copies O(log n) nodes, appends the result as
// a new version (index versionCount() - 1) and returns its root.
// NOTE(review): nodes are allocated with `new` and never freed. Nodes are
// shared between versions, so no single version owns them, and the tree leaks
// its nodes when destroyed.
class PersistentSegmentTree {
private:
    vector<PersistentNode*> versions;  // versions[k] = root of version k (nullptr when n == 0)
    int n;

public:
    // Builds version 0 from `arr`. An empty array gives a single empty version.
    PersistentSegmentTree(const vector<int>& arr) {
        n = static_cast<int>(arr.size());
        versions.push_back(n > 0 ? build(0, n - 1, arr) : nullptr);
    }

    // Number of versions recorded so far; the newest is versionCount() - 1.
    int versionCount() const { return static_cast<int>(versions.size()); }

    // Allocates the subtree for arr[start..end] and returns its root.
    PersistentNode* build(int start, int end, const vector<int>& arr) {
        PersistentNode* node = new PersistentNode();

        if (start == end) {
            node->value = arr[start];
        } else {
            int mid = (start + end) / 2;
            node->left = build(start, mid, arr);
            node->right = build(mid + 1, end, arr);
            node->value = node->left->value + node->right->value;
        }

        return node;
    }

    // Creates a new version equal to `version` except that element idx is val.
    // Records it in `versions` (so query(versionCount() - 1, ...) sees it) and
    // returns its root. Throws std::out_of_range (from vector::at) for an
    // unknown version. On an empty tree the new version is an unchanged copy.
    PersistentNode* update(int version, int idx, int val) {
        PersistentNode* base = versions.at(version);  // copy the pointer before push_back
        PersistentNode* root = (n > 0) ? update(base, 0, n - 1, idx, val) : base;
        versions.push_back(root);
        return root;
    }

    // Path-copying update: returns a new root that shares every subtree not
    // containing idx with `node`. Does not record a version.
    PersistentNode* update(PersistentNode* node, int start, int end, int idx, int val) {
        PersistentNode* newNode = new PersistentNode();

        if (start == end) {
            newNode->value = val;
        } else {
            int mid = (start + end) / 2;
            if (idx <= mid) {
                newNode->left = update(node->left, start, mid, idx, val);
                newNode->right = node->right;
            } else {
                newNode->left = node->left;
                newNode->right = update(node->right, mid + 1, end, idx, val);
            }
            newNode->value = newNode->left->value + newNode->right->value;
        }

        return newNode;
    }

    // Sum of [l, r] in the given version. Throws std::out_of_range (from
    // vector::at) for an unknown version; returns 0 for an empty tree.
    int query(int version, int l, int r) {
        PersistentNode* root = versions.at(version);
        if (root == nullptr) {
            return 0;
        }
        return query(root, 0, n - 1, l, r);
    }

    // Node-level range sum over node's segment [start, end].
    int query(PersistentNode* node, int start, int end, int l, int r) {
        if (r < start || end < l) {
            return 0;
        }

        if (l <= start && end <= r) {
            return node->value;
        }

        int mid = (start + end) / 2;
        int leftSum = query(node->left, start, mid, l, r);
        int rightSum = query(node->right, mid + 1, end, l, r);
        return leftSum + rightSum;
    }
};

Performance Analysis

Time Complexity

  • Build: O(n)
  • Query: O(log n)
  • Update: O(log n)
  • Range Update: O(log n) with lazy propagation
  • Space: O(n)

Space Complexity

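  • Basic Segment Tree: O(n) (the array layout allocates 4n slots)
  • Lazy Segment Tree: O(n) (4n slots each for the tree and lazy arrays)
  • 2D Segment Tree: O(n · m) (4n × 4m slots)
  • Persistent Segment Tree: O(n) for the initial version plus O(log n) new nodes per point update, i.e. O(n + q log n) after q updates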

Common Patterns

  1. Range sum queries with point updates
  2. Range minimum/maximum queries
  3. Range updates with lazy propagation
  4. 2D range queries for matrix problems
  5. Persistent queries for historical data

Applications

  • Range queries: Sum, min, max in ranges
  • Range updates: Add, set values in ranges
  • 2D problems: Matrix range queries
  • Historical queries: Persistent segment trees
  • Competitive programming: Efficient range operations