线段树(segment tree)

  线段树是一种二叉搜索树,它的每一个结点对应着一个区间[L, R],叶子结点对应的区间就是一个单位区间,即L == R。对于一个非叶子结点[L, R],它的左儿子所表示的区间是[L, (L +R)/2],右儿子所代表的的区间是[(L + R) / 2 +1, R]。

  拿一个简单的例子来说,我们需要维护一个数列,每次进行以下两种操作:

  • 修改一个元素
  • 查询一段区间的最大值

这是一道经典的RMQ(range minimum/maximum query,区间最值查询问题)问题,用线段树怎么解决呢?更新是点更新,查询是区间查询。

具体操作如下:

  建树的时候,始终遵循每个结点维护结点所代表的左右端点和该区间的最值,建树的时候如果到叶子结点,那么这个结点的最值就是对应位置的数列的值,否则递归的建立左子树和右子树,然后将当前结点的区间最值设置为自己左子树和右子树最值的较大值。

  先定义线段树的结点:

const int maxn = 10010;

struct Node {
    int l, r, mx;//左右区间端点和最大值 
}tr[maxn<<2];

建树:

 1 void build(int d, int l, int r) {//递归建树 
 2     tr[d].l = l, tr[d].r = r;
 3     if(l == r) {                 //叶子结点 
 4         tr[d].mx = b[l];
 5         return;
 6     }
 7     int mid = (l + r) / 2;
 8     int lc = d * 2;
 9     int rc = d * 2 + 1;
10     build(lc, l ,mid);             //递归构建左子树 
11     build(rc, mid + 1, r);         //递归构建右子树 
12     tr[d].mx = max(tr[lc].mx, tr[rc].mx);//该区间的最值是左子结点和右子结点的最大值 
13 }

  如果是查询操作,从根结点开始查询:如果查询区间在该结点的左子内,则查询左子;如果查询区间在该结点的右子树内,则查询右子树;否则,查询左子树相应区间和右子树相应区间,并将两者的返回值的较大值返回。代码如下:

 1 int query(int d, int l, int r) {
 2     if(tr[d].l == l && tr[d].r == r) {//待查询区间等于当前结点的区间范围 
 3         return tr[d].mx;
 4     }
 5     int mid = (tr[d].l + tr[d].r) / 2;//取当前结点的中点 
 6     int lc = d * 2;
 7     int rc = d * 2 + 1;
 8     if(r <= mid)    return query(lc, l, mid);//查询区间属于当前结点的左子树就查询左子树 
 9     else if(l > mid)    return query(rc, mid + 1, r);//查询区间属于当前结点的右子树就查询右子树
10     else return max(query(lc, l, mid), query(rc, mid + 1, r));//查询区间分布在两侧 
11 }

  如果是修改操作,则从根结点开始修改,一直修改到叶子结点,同时对路径上相应结点的最值进行更新。

 1 void modify(int d, int pos, int v) {//将位置为pos的元素改成v 
 2     if(tr[d].l == tr[d].r && tr[d].l == pos) {//如果当前结点是叶子结点且是该结点 
 3         tr[d].mx = v;
 4         return;
 5     }
 6     int mid = (tr[d].l + tr[d].r) / 2;
 7     int lc = d * 2;
 8     int rc = d * 2 + 1;
 9     if(pos <= mid) modify(lc, pos, v);//如果要修改的位置在当前结点的左子树 
10     else modify(rc, pos, v);          //                        右子树 
11     tr[d].mx = max(tr[lc].mx, tr[rc].mx); 
12 } 

  以上是点更新加上区间查询,运用的时候将元素存入数组b中,建树,直接修改、查询即可。

  明白了基本的原理,下面介绍一种实现起来更简短,使用更方便的写法。

 1 const int INF = 99999999;
 2 int ql, qr;//查询区间 
 3 int query(int o, int L, int R) {
 4     int M = L + (R - L) / 2;
 5     int ans = -INF;
 6     if(ql <= L && R <= qr)    return maxv[o];            //当前结点完全包含在查询区间内 
 7     if(ql <= M)    ans = max(ans, query(o * 2, L, M));    //往左走 
 8     if(M < qr)    ans = max(ans, query(o * 2 + 1, M + 1, R));//往右走 
 9     return ans;
10 } 
11 
12 int p, v;//修改A[p] = v 
13 void update(int o, int L, int R) {
14     int M = L + (R - L) / 2;
15     if(L == R)    maxv[o] = v;
16     else {
17         if(p <= M)    update(o * 2, L, M);
18         else        update(o * 2 + 1, M + 1, R);
19         maxv[o] = max(maxv[o * 2], maxv[o * 2 + 1]);
20     }
21 } 

  使用的时候建树的过程是每次读入一个数,使用update函数更新A[i] = x。然后直接查询、修改即可。

  以上是点更新加上区间查询,如果没有点更新,只是查询某个区间的最值,则直接使用ST算法(简单不易写错)。

  但是通常在题目中会遇到对区间进行更新的操作,比如给出一个n个元素的数组A1,A2,A3...An,你的任务是设计一个数据结构,支持一下两种操作。

  • Add(L,R,v):把AL,AL+1,...,AR的值全部增加v。
  • Query(L,R):计算子序列AL,AL,...AR的元素和、最小值和最大值。

  我们需要在线段树中维护3个信息sum,min,max,分别对应三个查询值。其中如果还是使用sum[o]表示“结点o对应区间中所有数之和”,则add操作最坏情况下会修改所有的sum。解决的办法是把sum[o]的定义改成“如果只执行结点o及其子孙结点中的add操作,结点o对应区间中所有数之和”。信息维护的代码如下:

 1 //维护结点o,对应区间[L,R]
 2 void maintain(int o, int L, int R) {
 3     int lc = o * 2;
 4     int rc = O * 2 + 1;
 5     sumv[o] = minv[o] = maxv[o] = 0;
 6     if(R > L) {//考虑左右子树 
 7         sumv[o] = sumv[lc] + sumv[rc];
 8         minv[o] = min(minv[lc], minv[rc]);
 9         maxv[o] = max(maxv[lc], maxv[rc]);
10     }
11     minv[o] += addv[o];
12     maxv[o] += addv[o];
13     sumv[o] += addv[o] * (R - L + 1);
14 }

  上述维护结点o的maintain函数在递归访问到的结点都需要调用,并且在递归返回后调用。代码如下:

 1 //其中y1,y2表示修改和查询的区间 
 2 void update(int o, int L, int R) {
 3     int lc = o * 2;
 4     int rc = o * 2 + 1;
 5     if(y1 <= L && y2 >= R) {//递归边界 
 6         addv[o] += v; 
 7     } else {
 8         int M = L + (R - L) / 2;
 9         if(y1 <= M)    update(lc, L, M);
10         if(y2 > M)    update(rc, M + 1, R);
11     }
12     maintain(o, L, R);//递归结束后重新计算本结点附加信息 
13 }

  接下来就是查询操作了,基本思路仍然是把查询区间递归分解为若干不相交子区间,把各个子区间的查询结果加以合并,但是需要注意的是每个边界区间的结果不能直接使用,还得考虑祖先结点对它的影响。为了方便,我们在递归查询函数中增加了一个参数,表示当前区间的所有祖先结点add值之和。代码如下:

 1 int _min, _max, _sum;//对应查询结果
 2 void query(int o, int L, int R, int add) {
 3     if(y1 <= L && y2 >= R) {
 4         _sum += sumv[o] + add * (R - L + 1);
 5         _min = min(_min, minv[o] + add);
 6         _max = max(_max, maxv[o] + add);
 7     } else {//递归统计累加参数add 
 8         int M = L + (R - L) / 2;
 9         if(y1 <= M)    query(o * 2, L, M, add + addv[o]);
10         if(y2 > M)  query(o * 2 + 1, M + 1, R, add + addv[o]);
11     }
12 } 

  上述讲解的是区间增减,还有一种情况是区间赋值。即给出一个有n个元素的数组,A1,A2,...,An,你的任务是设计一个数据结构,支持一下两种操作:

  • Set(L, R, v):把AL,AL+1,...AR的值全部修改成v(v>=0)
  • Query(L,R):计算子序列AL,AL,...AR的元素和、最小值和最大值。

  同理我们将set操作也进行分解,记录在结点中,但是出现了一个新的问题,即add操作没有先后的时效性,但是set操作是有的。

  解决的办法是设计一个向下传递函数,用来做一个标记。

  新的修改操作代码如下:

 1 void update(int o, int L, int R) {
 2     int lc = o * 2;
 3     int rc = o * 2 + 1;
 4     if(y1 <= L && y2 >= R) {//递归边界,将set标记修改 
 5         setv[o] = v; 
 6     } else {
 7         pushdown(o);
 8         int M = L + (R - L) / 2;
 9         if(y1 <= M)    update(lc, L, M);    else maintain(lc, L, M);
10         if(y2 > M)    update(rc, M + 1, R);    else maintain(rc, M + 1, R);
11     }
12     maintain(o, L, R);//递归结束后重新计算本结点附加信息 
13 }

其中需要注意的有两个地方,首先是pushdown函数,它的作用就是把set值往下传递。

1 void pushdown(int o) {
2     int lc = o * 2;
3     int rc = o * 2 + 1;
4     if(setv[o] >= 0) {//由于赋的值是大于等于0的,所以>= 0表示有标记 
5         setv[lc] = setv[rc] = setv[o];
6         setv[o] = -1; //清除标记 
7     }
8 }

  另一个值得注意的地方是代码出多了两处maintain的调用。对于本来就要递归访问的子树,递归访问结束之后自然会调用maintain,因此只需要针对不进行递归访问的子树调用maintain即可。

  接下来就是关键的查询问题了,怎么解决任意两个set操作不会存在祖先-后代关系的问题。

  其实我们只需规定在这种情况下,以祖先结点上的操作为准即可,在递归查询的时候,碰到到一个set操作就立即停止即可。代码如下:

 1 void query(int o, int L, int R) {
 2     if(setv[o] >= 0) {               //递归边界1:有set标记 
 3         _sum += setv[o] * (min(R, y2) - max(L, y1) + 1);
 4         _min = min(_min, setv[o]);
 5         _max = max(_max, setv[o]);
 6     } else if(y1 <= L && y2 >= R) {//递归边界2:边界区间 
 7         _sum += sumv[o];           //此区间没有被任何set操作影响 
 8         _min = min(_min, minv[o]);
 9         _max = max(_max, maxv[o]);
10     } else {                       //递归统计 
11         int M = L + (R - L) / 2;
12         if(y1 <= M)    query(o * 2, L, M);
13         if(y2 > M)     query(o * 2 + 1, M + 1, R);
14     }
15 } 

  暂时线段树的讲解就到这里,理解的还不是太透彻,之后会补上几道例题。

  

posted @ 2018-10-06 19:35  Reqaw  阅读(304)  评论(0编辑  收藏  举报