Fork me on GitHub

【数据结构】线段树详解

原文链接:https://www.cnblogs.com/ctjcalc/p/post2.html

线段树是一种高效的数据结构,可以在\(O(nlog_{2}n)\)的时间内查询区间最值或区间和,解决动态的RMQ问题,并且可以为一些算法进行优化,如Dijkstra最短路、扫描线等。

实现原理

线段树是一种基于分治算法的二叉树。每个结点维护一个区间,以及在该区间内的数据信息。在当前结点$x$,它的区间是$[left,right]$,则它的两个子结点的区间分别为$[left,mid],[mid+1,right]$,$mid=\frac{left+right}{2}$。由于采用了分治的思想,在进行操作时每一层至多访问两个结点,极大优化了效率。

基本实现

线段树结点

根据定义,我们可以得出线段树结点的结构:

template <typename T, int Lim>
class SegmentTree 
{
public:
    // ...
private:
    struct Node 
    {
        T Sum;
    } Tree[Lim];
    int Size;
	// ...
};
这里使用了类体和模板,可以适用于更多类型。

Node结构体就是线段树的结点,其中的Sum是该结点维护的区间和。

单点修改

线段树的基本操作就是单点修改。根据分治思想,我们访问结点时,如果已经到了要修改的结点(区间长度为1且已经到达目标位置),就将其值修改,否则,向它的子结点递归,在回溯时随便更新自己结点的值。

首先看一个pushup函数,用于更新当前结点的值。

void pushup(int root) { // 更新区间和
    Tree[root].Sum = Tree[root << 1].Sum + Tree[root << 1 | 1].Sum; 
}

这个函数把子结点的值累加到当前结点上。由于使用了数组,子结点可以不用进行特判是否存在,直接累加即可。

递归过程中,区间会不断缩小,最后长度会变为\(1\),即为一个点,如果这个点和要修改的点重合,就可以修改了。时间复杂度为\(O(log_{2}n)\)

void modify(int root, int left, int right, int pos, T key){
	if(right < pos || left > pos) // 与当前区间无交集,返回
		return;
	if(left == right && left == pos){ // 区间变为点且与查询位置重合,直接修改
		Tree[left] += key;
		return;
	}
	int mid = left + right >> 1;
	modify(root << 1, left, mid, pos, key);
	modify(root << 1 | 1, mid + 1, right, pos, key);
	pushup(root); // 更新区间和
}

构建

现在有了单点修改,但对于一个初始好的序列,单点修改显然不太方便,并且时间复杂度会过大(然而有时还是可以过)。所以我们也有一种方法快速构建线段树,并用一个序列初始化。

void build(int root, int left, int right, T arr[]) {
    if (left == right) { // 区间变为点,保存值
        Tree[root].Sum = arr[left];
        return;
    }
    int mid = left + right >> 1;
    build(root << 1, left, mid, arr);
    build(root << 1 | 1, mid + 1, right, arr);
    pushup(root); // 更新区间和
}
如果区间长度为$1$,则直接初始化,否则进行递归,回溯时`pushup`。时间复杂度为$O(n)$。

单点查询

只要在递归是进行区间的判断,如果查询的位置\(pos \leq mid\),则向左子树递归,否则向右子树递归。

T query(int root, int left, int right, int pos) {
    if(left == right && left == pos) { // 区间变为点且与查询位置重合,直接返回
        return Tree[root].Sum;
    }
    int mid = left + right >> 1;
    if(pos <= mid) return query(root << 1, left, mid, pos); // 在左边
    else return query(root << 1 | 1, mid + 1, right, pos); // 在右边
}

扩展实现

区间修改

如果使用单点修改来实现区间修改,时间复杂度就会达到\(O((right-left+1)nlog_{2}n)\)。所以,我们需要引入懒标记。在结点上加上懒标记就可以快速修改区间。

举个例子:

我们有一个维护了区间为$[1,8]$的线段树,我们要给这个区间加上数$v$,有两种方法:
  • 迭代区间\([1,8]\),单点修改每个结点。
  • 在区间\([1,8]\)的结点上加上\(v\),修改区间\([1,8]\)

很显然,第二种效率更高。这就是懒标记

懒标记可以是加的标记,也可以是乘的标记,只要满足分配律,就可以使用懒标记。但是如果是开根号等等的操作就不可以,因为\(\sqrt{a+b} \neq \sqrt{a} + \sqrt{b}\)

对于目前的线段树,我们需要修改结点的结构体:

template <typename T, int Lim>
class SegmentTree 
{
public:
    // ...
private:
    struct Node 
    {
        T Add, Sum;
    } Tree[Lim];
    int Size;
	// ...
};

修改可以分为这几个情况:

  • 如果修改区间完全覆盖当前区间,就在当前区间的懒标记上加上值,并计算出区间和,然后返回。
  • 如果修改区间左端点在当前区间中点的左边,则递归修改左区间。
  • 右边类似。

然后更新当前结点。

但这里有一个问题,如果我们修改时经过一个已经打过懒标记的结点,就无法正确更新区间和。这时候就需要标记下传了。在修改区间之前,首先标记下传,使子结点的区间和也更新,这样就能得到正确答案了。

void update(int root, int left, int right, T key) { // 添加懒标记
    Tree[root].Add += key; // 把懒标记加上key
    Tree[root].Sum += key * (right - left + 1); // 更新区间和
}

void pushdown(int root, int left, int right) {
    if (Tree[root].Add == 0)
        return;
    int mid = left + right >> 1;
    update(root << 1, left, mid, Tree[root].Add); // 下传标记到左右子结点
    update(root << 1 | 1, mid + 1, right, Tree[root].Add);
    Tree[root].Add = 0; // 清楚懒标记
}
再根据上面讲解,我们可以得出以下区间修改代码。
void modify(int root, int left, int right, int mleft, int mright, T key) {
    if (mleft <= left && right <= mright) { // 完全覆盖当前区间
        update(root, left, right, key); // 直接添加懒标记
        return;
    }
    int mid = left + right >> 1;
    pushdown(root, left, right);
    if (mleft <= mid) // 如果修改区间被左边包含,则往左边区间递归
        modify(root << 1, left, mid, mleft, mright, key);
    if (mid < mright) // 如果修改区间被右边包含,则往右边区间递归
        modify(root << 1 | 1, mid + 1, right, mleft, mright, key);
    pushup(root); // 更新
}

区间查询

与区间查询类似,也需要进行标记下传。

查询可以也可以分为这三个情况:

  • 如果查询区间完全覆盖当前区间,直接返回区间和。
  • 如果修改区间左端点在当前区间中点的左边,则递归查询左区间和。
  • 右边类似。

最后累加左右子树区间和。

T query(int root, int left, int right, int qleft, int qright) {
    if (qleft <= left && right <= qright)
        return Tree[root].Sum;
    int mid = left + right >> 1;
    pushdown(root, left, right);
    T res = 0;
    if (qleft <= mid)
        res += query(root << 1, left, mid, qleft, qright);
    if (mid < qright)
        res += query(root << 1 | 1, mid + 1, right, qleft, qright);
    return res;
}

完整实现

```cpp template class SegmentTree { public: void init(int n) { Size = n; }
void build(T arr[]) { build(1, 1, Size, arr); }

void modify(int left, int right, T key) { modify(1, 1, Size, left, right, key); }

T query(int left, int right) { return query(1, 1, Size, left, right); }

private:
struct Node
{
T Add, Sum;
} Tree[Lim];
int Size;

void pushup(int root) { Tree[root].Sum = Tree[root << 1].Sum + Tree[root << 1 | 1].Sum; }

void update(int root, int left, int right, T key) {
    Tree[root].Add += key;
    Tree[root].Sum += key * (right - left + 1);
}

void pushdown(int root, int left, int right) {
    if (Tree[root].Add == 0)
        return;
    int mid = left + right >> 1;
    update(root << 1, left, mid, Tree[root].Add);
    update(root << 1 | 1, mid + 1, right, Tree[root].Add);
    Tree[root].Add = 0;
}

void build(int root, int left, int right, T arr[]) {
    if (left == right) {
        Tree[root].Sum = arr[left];
        return;
    }
    int mid = left + right >> 1;
    build(root << 1, left, mid, arr);
    build(root << 1 | 1, mid + 1, right, arr);
    pushup(root);
}

void modify(int root, int left, int right, int mleft, int mright, T key) {
    if (mleft <= left && right <= mright) {
        update(root, left, right, key);
        return;
    }
    int mid = left + right >> 1;
    pushdown(root, left, right);
    if (mleft <= mid)
        modify(root << 1, left, mid, mleft, mright, key);
    if (mid > mright)
        modify(root << 1 | 1, mid + 1, right, mleft, mright, key);
    pushup(root);
}

T query(int root, int left, int right, int qleft, int qright) {
    if (qleft <= left && right <= qright)
        return Tree[root].Sum;
    int mid = left + right >> 1;
    pushdown(root, left, right);
    T res = 0;
    if (qleft <= mid)
        res += query(root << 1, left, mid, qleft, qright);
    if (mid< qright)
        res += query(root << 1 | 1, mid + 1, right, qleft, qright);
    return res;
}

};

<div id="copyright">
Copyright © 2019 ctjcalc,转载请注明URL,并给出原文链接,谢谢。
</div>
posted @ 2019-10-25 13:04  ctjcalc  阅读(550)  评论(0编辑  收藏  举报