A segment tree maintains a set of intervals whose endpoints all lie on a predetermined set of points P, so there are O(|P|^2) possible intervals. Its main use is answering range queries; in that sense it is similar to a binary interval tree, but built over sparse (compressed) coordinates.
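For each node, this implementation maintains the total length of the union of all active intervals, restricted to the range that node covers. This is handy for computing the area of the union of rectangles on a 2D plane. Sweep a line along one axis, and add (delta = +1) or remove (delta = -1) each rectangle's extent on the other axis as the sweep enters or leaves it. The covered length reported by the root, multiplied by the distance to the next sweep event, gives that strip's area.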
C#
lang:c#
/// <summary>
/// Segment tree over a fixed set of coordinates that maintains the total length
/// covered by the union of the half-open intervals [xs[lo], xs[hi]) added so far.
/// Elementary segment i is [xs[i], xs[i+1]) for 0 &lt;= i &lt; n-1.
/// </summary>
private class SegmentSumTree
{
    private int n;
    // xs:  private copy of the coordinates (expected to be sorted ascending;
    //      lengths are computed as xs[right] - xs[left]).
    // has: per node, the net delta of updates that exactly cover that node's
    //      range. Updates are recorded at the highest fitting nodes and are not
    //      pushed down to the children.
    // sum: per node, the covered length inside that node's range.
    private int[] xs, has, sum;

    /// <summary>Builds an empty tree (nothing covered) over the given coordinates.</summary>
    /// <param name="x">Coordinates; copied, and expected to be sorted ascending.</param>
    public SegmentSumTree(int[] x)
    {
        xs = (int[])x.Clone();
        n = xs.Length;
        has = new int[(n + 1) * 4 + 10];
        sum = new int[(n + 1) * 4 + 10];
    }

    /// <summary>
    /// Changes the coverage count of [x[lo], x[hi]) by <paramref name="delta"/>.
    /// Typical use is +1 to add an interval and -1 to remove one added earlier.
    /// Counts are expected to stay non-negative, because coverage is computed with Math.Min(1, count).
    /// Indices outside [0, n-1] are clipped to the valid coordinate range.
    /// </summary>
    public void Change(int lo, int hi, int delta)
    {
        // The root spans coordinate indices [0, n-1]. Starting at n would create a
        // phantom rightmost leaf whose right end reads xs[n] (out of bounds)
        // whenever hi >= n.
        Change(lo, hi, delta, 0, 0, n - 1);
    }

    // Recursive worker: `node` covers coordinate indices [nodeLeft, nodeRight].
    private void Change(int lo, int hi, int delta, int node, int nodeLeft, int nodeRight)
    {
        // Clip the update to this node's range; stop if they do not overlap.
        lo = Math.Max(lo, nodeLeft);
        hi = Math.Min(hi, nodeRight);
        if (lo >= hi) return;
        // terminal node: a single elementary segment, always fully covered here
        if (nodeLeft + 1 == nodeRight)
        {
            has[node] += delta;
            sum[node] = (xs[nodeRight] - xs[nodeLeft]) * Math.Min(1, has[node]);
        }
        // found a fit: record the update at this node instead of pushing it down
        else if (nodeLeft == lo && nodeRight == hi)
        {
            has[node] += delta;
            // Covered outright -> the whole length. Otherwise fall back to the
            // children, which still hold their own (untouched) coverage.
            if (has[node] > 0)
                sum[node] = xs[nodeRight] - xs[nodeLeft];
            else
                sum[node] = sum[node * 2 + 1] + sum[node * 2 + 2];
        }
        // split away...
        else
        {
            int half = (nodeRight - nodeLeft) / 2;
            Change(lo, hi, delta, node * 2 + 1, nodeLeft, nodeLeft + half);
            Change(lo, hi, delta, node * 2 + 2, nodeLeft + half, nodeRight);
            // If this node carries its own coverage (has > 0 in normal use), its
            // sum was already set to the full length and stays that way.
            if (has[node] == 0)
                sum[node] = sum[node * 2 + 1] + sum[node * 2 + 2];
        }
    }

    /// <summary>Total length covered by the union of the currently active intervals.</summary>
    public int All()
    {
        return sum[0];
    }
};
C++
lang:cpp
// Segment tree over a fixed list of coordinates (expected to be sorted ascending;
// not checked). It maintains the total length covered by the union of the
// half-open intervals [coords[lo], coords[hi]) currently added. Elementary
// segment i is [coords[i], coords[i+1]).
struct SegmentSumTree
{
    int n;                    // number of coordinates
    std::vector<int> has;     // per node: net delta of updates that exactly cover its range
    std::vector<int> sum;     // per node: covered length inside its range
    std::vector<int> coords;  // private copy of the coordinates

    // Builds an empty tree (nothing covered) over a copy of `points`.
    SegmentSumTree(const std::vector<int>& points)
        : n(static_cast<int>(points.size())),
          has((n + 1) * 4),
          sum((n + 1) * 4),
          coords(points)
    {
    }

    // Adds `delta` to the coverage count of [coords[lo], coords[hi]).
    // Typical use: +1 to add an interval, -1 to remove one added earlier. Counts
    // are expected to stay non-negative, since coverage is std::min(1, count).
    // lo/hi are clipped to [0, n-1]. The trailing parameters drive the
    // recursion; callers leave them at their defaults.
    void change(int lo, int hi, int delta, int node = 0, int nodeLeft = 0, int nodeRight = -1)
    {
        // -1 is the "start at the root" sentinel; the root spans indices [0, n-1].
        if (nodeRight == -1)
            nodeRight = n - 1;

        // Restrict the update to this node; bail out when they do not overlap.
        const int from = std::max(lo, nodeLeft);
        const int to = std::min(hi, nodeRight);
        if (to <= from)
            return;

        const int leftChild = 2 * node + 1;
        const int rightChild = leftChild + 1;
        const int span = coords[nodeRight] - coords[nodeLeft];
        const bool isLeaf = nodeRight - nodeLeft == 1;
        const bool coversNode = from == nodeLeft && to == nodeRight;

        if (isLeaf || coversNode)
        {
            // The update spans this whole node: record it here, not in the children.
            has[node] += delta;
        }
        else
        {
            // Partial overlap: push the clipped update into both halves.
            const int mid = (nodeLeft + nodeRight) / 2;
            change(from, to, delta, leftChild, nodeLeft, mid);
            change(from, to, delta, rightChild, mid, nodeRight);
        }

        // A node with a positive count is covered outright. Otherwise its covered
        // length comes from its children (leaves have none). With sorted coords,
        // max() selects the full span exactly when the node itself is covered.
        const int ownCover = std::min(1, has[node]) * span;
        sum[node] = isLeaf ? ownCover : std::max(sum[leftChild] + sum[rightChild], ownCover);
    }

    // Total length covered by the union of all currently active intervals.
    int sumAll()
    {
        const int root = 0;
        return sum[root];
    }
};


