『学习笔记』线段树

线段树和树状数组都是用来优化序列操作的数据结构。

线段树理解容易,常数大,解决问题范围广;树状数组理解比较困难,常数非常小,能解决的问题就没有线段树多了,可以说树状数组能解决的问题是线段树能解决的问题的子集。

线段树基本概念

线段树是一个二叉树,每个节点表示一个区间。

对于任意节点,要么是叶子节点,要么两个儿子都存在。

它可以快速在序列上修改及查询元素,可以是区间修改或查询。每次修改或查询的时间复杂度为 \(\mathcal{O}(\log n)\)。在使用之前,还需花费 \(\mathcal{O}(n)\) 的时间建树。

那么每个节点存什么?

  • 如果是叶子节点,就存需要执行操作的序列的对应项。具体是哪项下面再说。
  • 否则,就存这个节点的左右儿子之和或最小值、最大值、乘积等等。以求和为例,计算公式为 \(t_x=t_{\operatorname{left\_son}(x)}+t_{\operatorname{right\_son}(x)}\)这个节点的值就是这个节点表示的区间之和

为方便表示,本文中的 ls(x) 均代表 \(x\) 节点的左儿子节点,rs(x) 同理。

每个节点的儿子表示的区间都是当前节点区间的一半,左儿子表示的是 \(\left[l,\left\lfloor \dfrac{l+r}{2} \right\rfloor\right]\),右儿子表示的是 \(\left[\left\lfloor \dfrac{l+r}{2} \right\rfloor+1,r\right]\)

例如,要使用一个长度为 \(8\) 的序列 \(a=[1,1,4,5,1,4,1,9]\) 构造一棵线段树,那么这棵线段树长这样:

各个叶子节点的值都是根据 \(t_x=t_{\operatorname{ls}(x)}+t_{\operatorname{rs}(x)}\) 来计算的。

图中每个节点上面写的是这个节点包含的区间,下面写的是各个节点的值。

非叶子节点的值的计算过程也写了上去。

可以发现,表示第 \(i\) 个数的叶子节点的值就是 \(a_i\)

线段树差不多就长这样子,下面来看详细的操作过程。

普通线段树(不支持区间修改)

如何存储

我们可以使用二叉堆的方式存储:根节点的位置为 \(1\),每个节点的左右儿子的位置分别为 \(i \times 2\)\(i \times 2+1\)

也就是说,你遍历存储这棵树的数组,和层次遍历这棵树一样。

若有空缺的位置,需要留着。

因为存储的节点除了最后一层还有许多个节点,所以数组长度要比 \(n\) 大。

有人计算过,存树的数组需要开到 \(4n\) 才行。

树的节点结构体定义如下:

struct node{
    int l,r; // 表示区间
    T v; // 当前值
}t[N<<2]; // 线段树存储数组

其中的 T 表示线段树需要维护的值的类型,下文也一样,就不多说了。

别问为啥这么写,我写数据结构都喜欢封装成一个类用,用起来爽,里面就比不用类要多一点东西了。

位运算优化

应该很容易猜到,在操作过程中对 \(i \times 2\)\(i \times 2+1\) 的计算非常多。然而直接用乘号就有点慢了。

所以就要使用我们的卡常神器位运算了!

首先看 \(i \times 2\),没什么好说的。

众所周知,想让一个数乘 \(2^{x}\),只需使其左移 \(x\) 位即可。所以 \(i \times 2\) 就是 \(i\) 左移 \(1\) 位。

右儿子的运算又多了个加一,咋整?直接加?那是绝对不可能的!

一个数左移一位后,最后一位一定为 \(0\)。要加一,就是让它变成 \(1\)

于是就可以——将这个数或上 \(1\)

这样就让计算速度快起来了一点点嘛!

于是,可以定义两个函数:

inline int ls(int rt){return rt<<1;} // 左儿子
inline int rs(int rt){return rt<<1|1;} // 右儿子

前面最好加上 inline,防止你算一下儿子在哪都要再丢一个东西到栈里。

还有几个小优化,例如计算一个区间的中间分界点 \(mid\) 时,也可以用位运算来实现除二操作。

还有一个微不足道的,就是计算 \(4n\) 时可以直接左移 \(2\) 位。

废话多不多?

建树

呵呵,正文终于开始了。

上面提到过,建树要用 \(\mathcal{O}(n)\) 的时间复杂度进行。因为有 \(n\) 个元素。

使用深度优先搜索的方式来遍历整棵树。遍历的中间,为各个节点的 \(l\)\(r\) 赋值。

遍历到叶子节点(当前的 \(l=r\))时,这个叶子节点的值就应该是 \(a_l\)

两个儿子都遍历过后,需要通过已经处理好的儿子实时计算当前节点的值。

流程大概是这样的:

为方便起见,我们定义一个函数 pushup 用来计算当前节点的值。

inline void pushup(int rt){
    t[rt].v=t[ls(rt)].v+t[rs(rt)].v; // 计算当前节点的值
}

通过修改 pushup 函数,可以直接修改线段树维护内容。

例如,将其修改为维护最大值的线段树:

inline void pushup(int rt){
    t[rt].v=max(t[ls(rt)].v,t[rs(rt)].v);
}

看代码吧!还是代码形象一点:

void build(int rt,int l,int r){ // rt 表示当前节点,l 和 r 表示当前节点表示的区间
    t[rt].l=l,t[rt].r=r; // 首先的一步就是指定当前节点表示的区间范围
    if(l==r){ // 叶子节点情况
        t[rt].v=a[l]; // 为叶子节点赋值
        return; // 碰到叶子节点了就要回溯了
    }
    int mid=l+r>>1; // 计算区间分界点
    build(ls(rt),l,mid); // 递归遍历左儿子
    build(rs(rt),mid+1,r); // 递归遍历右儿子
    pushup(rt); // 计算当前节点值
}

应该很好理解,就是常数...

单点修改

废话了那么多,终于开始说操作了...

单点修改,就是要修改数列中的一个数。

那么在线段树中,就是修改其中一个叶子节点,我们需要修改某个叶子节点后维护整棵线段树,使其还是保持原来的特性(非叶子节点等于两个儿子的和等特性)。

例如,要修改下标为 \(6\) 的数为 \(8\)

从根节点一直向下找,查找要修改的叶子节点。

若当前搜索的的节点不是叶子节点,那么就需要判断需要修改的叶子节点在左儿子里还是右儿子里:

  1. \(mid \gets \left\lfloor \dfrac{l+r}{2} \right\rfloor\)
  2. 若下标 \(idx \leq mid\),则说明在左儿子里,向左儿子中搜索。
  3. 否则,就去右儿子。

代码如下:

int mid=t[rt].l+t[rt].r>>1; // 找中间点
if(idx<=mid) update(ls(rt),idx,v); // 进左儿子
else update(rs(rt),idx,v); // 右儿子

最终一定会找到一个叶子节点,它就是我们需要修改的。

单纯修改叶子节点会破坏整棵线段树的平衡,所以回溯时需要更新查找需要更改的叶子节点时经过的节点。

在函数末尾加上一句 pushup(rt) 即可。

完整代码:

// rt 是当前节点,idx 是需要修改的数的下标,v 是要替换的数(或累加的数)
void update(int rt,int idx,T v){
    if(t[rt].l==t[rt].r){ // 找到叶子节点的情况
        t[rt].v=v; // 修改
        return; // 回溯
    }
    int mid=t[rt].l+t[rt].r>>1;
    if(idx<=mid) update(ls(rt),idx,v);
    else update(rs(rt),idx,v);
    pushup(rt); // 找到叶子节点后需要将路径上的所有节点都更新一下,从下向上更新
}

很容易看出来,时间复杂度是 \(\mathcal{O}(\log n)\)。别看比暴力还差,区间查询可是 \(\mathcal{O}(\log n)\) 的。

单点查询

没什么好说的,就是从一棵树上找到叶子节点,return 就是了。

这个应该看代码就够了。

T query(int rt,int idx){ // 参数就不多说了
    if(t[rt].l==t[rt].r){ // 找到目标
        return t[rt].v;
    }
    int mid=t[rt].l+t[rt].r>>1;
    if(idx<=mid) return query(ls(rt),idx); // 在左儿子中
    else return query(rs(rt),idx); // 右儿子
    // 这里不需要 pushup,因为没有任何修改
}

区间查询

我们之所以维护整棵线段树就是为了使这个操作的时间复杂度变为 \(\mathcal{O}(\log n)\)

暴力查询时是一个一个累加,但有了线段树就不一样了。

线段树的节点除了叶子节点都存储的是一个区间的和,若某个节点表示的区间在查询区间之内,那么就可以 \(\mathcal{O}(1)\) 地累加出这个节点表示的区间的和。

那如果当前节点表示的区间和查询区间有交集,但并不是查询区间的子集,咋办?

直接看看左右儿子表示的区间是否与查询区间有交集,如果有,则进入相应的儿子查询(两个儿子随便去,但没有都不去的情况,那样当前节点表示的区间要么是查询区间的子集,要么就与查询区间没关系)。

好像有点不好理解...看图吧。

应该步骤写的很清楚了,可以通过代码进一步理解。

T query(int rt,int l,int r){ // l 和 r 表示查询区间!不是当前节点表示区间!
    if(l<=t[rt].l && t[rt].r<=r){ // 刚好是查询区间的子集
        return t[rt].v; // 直接返回
    }
    T res=0; // 因为左右儿子都可能去,所以定义一个变量累加
    int mid=t[rt].l+t[rt].r>>1;
    if(l<=mid) res+=query(ls(rt),l,r); // 若查询区间左端点在左儿子右端点之前,则表示左儿子包含
    if(r>mid) res+=query(rs(rt),l,r); // 查询区间右端点在右儿子左端点之后,同上
    return res;
}

代码也不长,应该挺好理解吧?

下面看一道例题。

P3374 【模板】树状数组 1

题目大意

给定一个长度为 \(n\) 的序列 \(a\)\(m\) 个操作,每次操作包含 \(3\) 个整数:

  • 1 x k:将第 \(x\) 个数加上 \(k\)
  • 2 x y:查询区间 \([x,y]\) 内每个数的和。

思路

虽然是树状数组题,但拿来写线段树也是不错的选择。

这题要写的模板就是单点查询区间修改的模板,我这给出的代码是我封装好的线段树类。

其实不用看类中那些东西,就看那三个函数和私有的几个函数就行了。

应该不用写注释吧(

代码

#include <iostream>
using namespace std;
template<typename T=int>
inline T read(){
    T X=0; bool flag=1; char ch=getchar();
    while(ch<'0' || ch>'9'){if(ch=='-') flag=0; ch=getchar();}
    while(ch>='0' && ch<='9') X=(X<<1)+(X<<3)+ch-'0',ch=getchar();
    if(flag) return X;
    return ~(X-1);
}

template<typename T=int>
inline void write(T X){
    if(X<0) putchar('-'),X=~(X-1);
    T s[20],top=0;
    while(X) s[++top]=X%10,X/=10;
    if(!top) s[++top]=0;
    while(top) putchar(s[top--]+'0');
    putchar('\n');
}

const int N=5e5+5;
int n,m,a[N],op,x,y;

template<class T=long long>
class SgT{
    public:
        SgT(){
            a_res=new int[N];
            for(int i=0; i<N; i++){
                a_res[i]=0;
            }
            a=a_res;
        }
        SgT(int rt,int l,int r,int *_a=nullptr):a(_a==nullptr ? a_res : _a){
            build(rt,l,r);
        }
        ~SgT(){
            delete[] a_res;
        }
        void build(int rt,int l,int r){
            t[rt].l=l,t[rt].r=r;
            if(l==r){
                t[rt].v=a[l];
                return;
            }
            int mid=l+r>>1;
            build(ls(rt),l,mid);
            build(rs(rt),mid+1,r);
            pushup(rt);
        }
        void update(int rt,int idx,T v){
            if(t[rt].l==t[rt].r){
                t[rt].v+=v;
                return;
            }
            int mid=t[rt].l+t[rt].r>>1;
            if(t[rt].l<=mid) update(ls(rt),idx,v);
            else update(rs(rt),idx,v);
            pushup(rt);
        }
        T query(int rt,int l,int r){
            if(l<=t[rt].l && t[rt].r<=r){
                return t[rt].v;
            }
            T res=0;
            int mid=t[rt].l+t[rt].r>>1;
            if(l<=mid) res+=query(ls(rt),l,r);
            if(r<mid) res+=query(rs(rt),l,r);
            return res;
        }
    private:
        int *a,*a_res;
        struct node{
            int l,r;
            T v;
        }t[N<<2];
        inline int ls(int rt){return rt<<1;}
        inline int rs(int rt){return rt<<1|1;}
        inline void pushup(int rt){
            t[rt].v=t[ls(rt)].v+t[rs(rt)].v;
        }
};

int main(){
    n=read(),m=read();
    for(int i=1; i<=n; i++){
        a[i]=read();
    }
    SgT t(1,1,n,a);
    while(m--){
        op=read(),x=read(),y=read();
        if(op==1){
            t.update(1,x,y);
        }else{
            write(t.query(1,x,y));
        }
    }
    return 0;
}

区间修改

说了那么多,就差你一个区间修改了。

如果你直接用单点修改的方法一个一个改,那么时间复杂度就变成 \(\mathcal{O}(n \log n)\) 了,比暴力还差。

那我们可不可以参考区间查询的思想呢?一次修改一个区间?那就需要一个叫懒标记的东西了。

懒标记

我们给节点的结构体加一个变量,叫 \(tag\),懒标记的意思。它表示这个节点之下的所有节点的 \(v\) 都需要加上这个 \(tag\)

这样的话,一次修改一个区间就能实现了:若需修改区间包含某个节点表示的区间,直接将这个节点的 \(tag\) 加上需要增加的值。

可以这样理解懒标记:放寒假了,老师每过一段时间给你布置一次作业(修改一次),你却只是记住有哪些作业(修改懒标记),在开学时(查询)才写(将标记下传)。

除了查询,修改时也需要下传懒标记,节点后代修改(或查询)时需要。

接下来说说如何下传懒标记:

首先一步,就是将懒标记给左右儿子都加上。

还需要修改两个儿子的值,都是修改成儿子表示的区间长度乘上父亲节点的懒标记。因为儿子包含的每一个数都要加上父节点的懒标记,所以要将懒标记乘上长度。

我们将下传懒标记的函数定义为 pushdown()

inline void pushdown(int rt){
    t[ls(rt)].tag+=t[rt].tag; // 懒标记传下去
    t[ls(rt)].v+=t[rt].tag*(t[ls(rt)].r-t[ls(rt)].l+1); // 修改值
    // 右儿子同上
    t[rs(rt)].tag+=t[rt].tag;
    t[rs(rt)].v+=t[rt].tag*(t[rs(rt)].r-t[rs(rt)].l+1);
    t[rt].tag=0; // 记得将父节点的懒标记置 0
}

P3372 【模板】线段树 1

题目大意

需要维护一个序列,支持区间改查。

思路

没别的,看代码就行了。

顺便熟悉一下区间改查线段树。

代码

#include <iostream>
using namespace std;
template<typename T=int>
inline T read(){
    T X=0; bool flag=1; char ch=getchar();
    while(ch<'0' || ch>'9'){if(ch=='-') flag=0; ch=getchar();}
    while(ch>='0' && ch<='9') X=(X<<1)+(X<<3)+ch-'0',ch=getchar();
    if(flag) return X;
    return ~(X-1);
}

template<typename T=int>
inline void write(T X){
    if(X<0) putchar('-'),X=~(X-1);
    T s[20],top=0;
    while(X) s[++top]=X%10,X/=10;
    if(!top) s[++top]=0;
    while(top) putchar(s[top--]+'0');
    putchar('\n');
}

const int N=1e5+5;
int n,m,a[N],op,x,y,k;

template<class T=long long>
class SgT{
    public:
        SgT(int rt=-1,int l=0,int r=0,int *_a=nullptr):a_res(new int[N]),a(_a==nullptr ? a_res : _a){
            for(int i=0; i<N; i++){
                a_res[i]=0;
            }
            if(rt!=-1) build(rt,l,r);
        }
        ~SgT(){
            delete[] a_res;
        }
        void build(int rt,int l,int r){
            t[rt].l=l,t[rt].r=r;
            t[rt].tag=0;
            if(l==r){
                t[rt].v=a[l];
                return;
            }
            int mid=l+r>>1;
            build(ls(rt),l,mid);
            build(rs(rt),mid+1,r);
            pushup(rt);
        }
        void update(int rt,int l,int r,T v){
            if(l<=t[rt].l && t[rt].r<=r){
                t[rt].v+=v*(t[rt].r-t[rt].l+1);
                t[rt].tag+=v;
                return;
            }
            pushdown(rt);
            int mid=t[rt].l+t[rt].r>>1;
            if(l<=mid) update(ls(rt),l,r,v);
            if(r>mid) update(rs(rt),l,r,v);
            pushup(rt);
        }
        T query(int rt,int l,int r){
            if(l<=t[rt].l && t[rt].r<=r){
                return t[rt].v;
            }
            pushdown(rt);
            T res=0;
            int mid=t[rt].l+t[rt].r>>1;
            if(l<=mid) res+=query(ls(rt),l,r);
            if(r>mid) res+=query(rs(rt),l,r);
            return res;
        }
    private:
        int *a,*a_res;
        struct node{
            int l,r;
            T v,tag;
        }t[N<<2];
        inline int ls(int rt){return rt<<1;}
        inline int rs(int rt){return rt<<1|1;}
        inline void pushup(int rt){
            t[rt].v=t[ls(rt)].v+t[rs(rt)].v;
        }
        inline void pushdown(int rt){
            t[ls(rt)].tag+=t[rt].tag;
            t[ls(rt)].v+=t[rt].tag*(t[ls(rt)].r-t[ls(rt)].l+1);
            t[rs(rt)].tag+=t[rt].tag;
            t[rs(rt)].v+=t[rt].tag*(t[rs(rt)].r-t[rs(rt)].l+1);
            t[rt].tag=0;
        }
};

int main(){
    n=read(),m=read();
    for(int i=1; i<=n; i++){
        a[i]=read();
    }
    SgT t(1,1,n,a);
    while(m--){
        op=read();
        if(op==1){
            x=read(),y=read(),k=read();
            t.update(1,x,y,k);
        }else{
            x=read(),y=read();
            write(t.query(1,x,y));
        }
    }
    return 0;
}

推荐题单

暂时就学了这么点线段树。

从易到难排序。

要是能全刷完,那一定是线段树大神了。反正我是刷不完的。

posted @ 2022-03-09 21:32  仙山有茗  阅读(70)  评论(0编辑  收藏  举报