线段树学习笔记

写在前言之前

很早之前我就开始学最基本的线段树操作,最高的只到区间加法。时隔多日才想起来在深入一下,加之很久没有写过线段树的板子,所以写的时候还是比较恶心的

这次主攻的是区间乘法的操作。

 

前言

对于区间查询这一类问题。如果给定的是一个有序的序列,完全可以使用前缀和求解。求解无序的区间查询是比较常用的有ST表和线段树,今天要说的便是线段树这一数据结构

 

0x00

线段树是一棵二叉搜索树,每一个节点都储存有一些信息,通过对这些信息的修改和维护可以做到$O(nlogn)$的时间内建树$+$修改$+$查询。

可能下面这张图能够更加直观的解释线段树是啥

每个点下的红色的字体表示区间的左右端点,每个点里面的数是这个店所代表的区间的和,最下面黄色的店里面的是序列中的元素

这张图这么好看,怎么可能是我画的呢QAQ

 

线段树版本1

0x01

已知一个数列,你需要进行下面两种操作:

  • 将某区间每一个数加上x
  • 求出某区间每一个数的和

在这种情况下,我们要用到最普通的线段树,支持区间加法和区间查询。

线段树将树上的节点都看做一条线段,每个节点上都维护着一些信息

如果上面的题目的话,就需要维护下列的信息

  1. $l$和$r$表示区间(线段)的左端点和右端点
  2. $sum$表示这个区间的元素的总和
  3. $lazytag$,先记住,这东西叫做懒标记,在后面会作出解释

本篇文章全部使用数组来实现,指针党请谅解

0x02

我们先来看如何建立一棵线段树。通过一个$build$函数来实现,这个函数的时间复杂度是$O(nlogn)$

首先从根节点开始向下扩展,显然根节点存储的$l$和$r$应该是$l=1,r=n$。每次扩展时计算一个$mid$值$=(l+r)/2$

这个从$l$到$mid$这段区间放到左儿子中,$mid+1$到$r$放到右儿子里。

如果搜索到$l=r$时。已经到底了,这个区间的值就可以确定了,是第$l$个元素。

然后就可以回溯。在回溯的时候维护区间和。另根节点的$sum$等于左右儿子的$sum$的和。

代码如下

inline void build(int k, int ll, int rr) {
    //k是节点的编号,ll是该节点的左端点,rr是右端点 
    tree[k].l = ll, tree[k].r = rr;
    //赋值 
    if(tree[k].l == tree[k].r) {
        //如果找到了l等于r的情况证明已经到了最底部。可以直接输入 
        scanf("%lld", &tree[k].w);
        return ;
        //回溯 
    }
    long long int mid = (ll+rr)/2;
    build((k<<1), ll, mid); //建立左儿子 
    build((k<<1)+1, mid+1, rr); //建立右儿子 
    tree[k].w = tree[(k<<1)].w+tree[(k<<1)+1].w;
    //维护区间和 
}

 

为了方便大家理解我还录制了GIF给大家看看

0x03

再来说区间加法,这里引入一个前文提到过的概念-----懒标记

顾名思义,懒标记的作用就是懒。它要怎么懒呢?

在进行区间修改的时候我们要减少多余的操作。将一个区间全都加上一个数时。我们只对要用到的区间进行操作。对于那些之后要用到的但现在没用到的区间我们可以先不修改,用$lazytag$存储一个值,什么时候用到什么时候下传给儿子,在一步步下传到要用到的区间。

这个操作用一个$down$函数来实现

inline void down(int k) {
    tree[(k<<1)].f += tree[k].f;
    tree[(k<<1)+1].f += tree[k].f; 
    //更新左右儿子的懒标记 
    tree[(k<<1)].w += tree[k].f*(tree[(k<<1)].r-tree[(k<<1)].l+1);
    tree[(k<<1)+1].w += tree[k].f*(tree[(k<<1)+1].r-tree[(k<<1)+1].l+1);
    //更新左右儿子的区间和 
    tree[k].f = 0;
    //清除父亲结点的懒标记 
}

 

  

然后这个区间加法就只剩下普通的操作了,看下面的区间修改代码

inline void change_interval(int k) {
    if(tree[k].l >= a&&tree[k].r <= b) {
        tree[k].w += (tree[k].r-tree[k].l+1)*y;
        tree[k].f += y;
        //更新当前区间的和还有当前区间的懒标记 
        return ;
    }
    if(tree[k].f) down(k);
    //如果懒标记不为0的话就下传给自己的儿子 
    int mid = (tree[k].l+tree[k].r)/2;
    if(a <= mid) change_interval((k<<1));
    if(b > mid) change_interval((k<<1)+1);
    tree[k].w = tree[(k<<1)].w+tree[(k<<1)+1].w;
    //维护区间和 
}

 

  

0x04

至于区间查询,和区间修改是差不多的

inline void ask_interval(int k) {
    //如果当前的区间被要查询的区间包含的话,直接加到答案中 
    if(tree[k].l >= a&&tree[k].r <= b) {
        ans += tree[k].w;
        return ;
    }
    if(tree[k].f) down(k);
    int mid = (tree[k].l+tree[k].r)/2;
    //判断左右儿子的区间和要查询的区间是否有交集 
    if(a <= mid) ask_interval((k<<1));
    if(b > mid) ask_interval((k<<1)+1);
}

 

0x05

下面放上我的完整的代码

#include <iostream>
#include <cstdio>

using namespace std;

struct node{
    int l, r;
    long long w, f;                        //(l, r)区间,区间和w,懒标记f;
}tree[400001];
long long int ans, y;
int x, n, m;
int a, b;

inline void build_tree(int k, int ll, int rr) {
    tree[k].l = ll, tree[k].r = rr;
    if(tree[k].l == tree[k].r) {
        scanf("%lld", &tree[k].w);
        return ;
    }
    long long int mid = (ll+rr)/2;
    build_tree((k<<1), ll, mid);
    build_tree((k<<1)+1, mid+1, rr);
    tree[k].w = tree[(k<<1)].w+tree[(k<<1)+1].w;
}

inline void down(int k) {
    tree[(k<<1)].f += tree[k].f;
    tree[(k<<1)+1].f += tree[k].f;
    tree[(k<<1)].w += tree[k].f*(tree[(k<<1)].r-tree[(k<<1)].l+1);
    tree[(k<<1)+1].w += tree[k].f*(tree[(k<<1)+1].r-tree[(k<<1)+1].l+1);
    tree[k].f = 0;
}

inline void ask_point(int k) {
    if(tree[k].l == tree[k].r) {
        ans = tree[k].w;
        return ;
    }
    if(tree[k].f) down(k);
    int mid = (tree[k].l+tree[k].r)/2;
    if(x <= mid) ask_point((k<<1));
    else ask_point((k<<1)+1);
}

inline void change_point(int k) {
    if(tree[k].l == tree[k].r) {
        tree[k].w += y;
        return ;
    }
    if(tree[k].f) down(k);
    int mid = (tree[k].l+tree[k].r)/2;
    if(x <= mid) change_point((k<<1));
    else change_point((k<<1)+1);
    tree[k].w = tree[(k<<1)].w+tree[(k<<1)+1].w;
}

inline void ask_interval(int k) {
    if(tree[k].l >= a&&tree[k].r <= b) {
        ans += tree[k].w;
        return ;
    }
    if(tree[k].f) down(k);
    int mid = (tree[k].l+tree[k].r)/2;
    if(a <= mid) ask_interval((k<<1));
    if(b > mid) ask_interval((k<<1)+1);
    // tree[k].w = tree[(k<<1)].w+tree[(k<<1)+1].w;
}

inline void change_interval(int k) {
    if(tree[k].l >= a&&tree[k].r <= b) {
        tree[k].w += (tree[k].r-tree[k].l+1)*y;
        tree[k].f += y;
        return ;
    }
    if(tree[k].f) down(k);
    int mid = (tree[k].l+tree[k].r)/2;
    if(a <= mid) change_interval((k<<1));
    if(b > mid) change_interval((k<<1)+1);
    tree[k].w = tree[(k<<1)].w+tree[(k<<1)+1].w;
}

int main() {
    scanf("%d%d", &n, &m);
    build_tree(1, 1, n);
    for(int i=1; i<=m; i++) {
        int p;
        ans = 0;
        scanf("%d", &p);
        switch(p) {
            /*ask_point*/case 4: {
                scanf("%d", &x);
                ask_point(1);
                printf("%lld\n", ans);
                break;
            }
            /*change_point*/case 3: {
                scanf("%d%d", &x, &y);
                change_point(1);
                break;
            }
            /*ask_interval*/case 2: {
                scanf("%d%d", &a, &b);
                ask_interval(1);
                printf("%lld\n", ans);
                break;
            }
            /*change_interval*/case 1: {
                scanf("%d%d%lld", &a, &b, &y);
                change_interval(1);
                break;
            }
        }
    }
    return 0;
}

 

 

 

0x06

来几个例题

 

线段树版本2

0x00

这个版本的线段树呢,就是加入了更多的区间操作。比如区间乘法,但这些操作大致相同,这里以区间乘法为例进行讲解

像上面的加法有加法标记一样,乘法也有乘法标记。

 

0x01

建树的过程与版本一的建树过程大致相同

这里不再进行详细讲解,只放上代码。唯一要值得注意的是懒标记的初始化,乘法标记要初始化为$1$。

inline void build(int k, int ll, int rr) {
    tree[k].l = ll, tree[k].r = rr;
    tree[k].addtag = 0, tree[k].multag = 1;
    if(tree[k].l == tree[k].r) {
        tree[k].sum = read();
        tree[k].sum %= Mod;
        return ;
    }
    int mid = (tree[k].l + tree[k].r) >> 1;
    build(Lson, tree[k].l, mid);
    build(Rson, mid + 1, tree[k].r);
    tree[k].sum = tree[Lson].sum + tree[Rson].sum;
    tree[k].sum %= Mod;
}

 

  

0x02

懒标记的下传是整个线段树中最核心的部分,一般情况下如果你写的线段树WA掉了,那肯定是你写的懒标记下传函数出了锅

带有区间加法的线段树的下传函数非常复杂。通常情况下我们先下传乘法标记,在下传加法标记。因为先进行乘法,对之后的加法不会产生什么影响,如果先进行加法的话,对之后的乘法就会产生影响。所以我们选择先进行乘法标记的下传。在下传加法标记的同时直接将乘法标记也下传给加法标记。

下面给出$down$函数的代码

inline void pushdown(int k) {
    tree[Lson].multag = tree[k].multag * tree[Lson].multag % Mod;
    tree[Rson].multag = tree[k].multag * tree[Rson].multag % Mod;
    tree[Lson].addtag = tree[Lson].addtag * tree[k].multag % Mod;
    tree[Rson].addtag = tree[Rson].addtag * tree[k].multag % Mod;
    tree[Lson].sum = tree[Lson].sum * tree[k].multag % Mod;
    tree[Rson].sum = tree[Rson].sum * tree[k].multag % Mod;
    tree[Lson].addtag = (tree[k].addtag + tree[Lson].addtag) % Mod;
    tree[Rson].addtag = (tree[k].addtag + tree[Rson].addtag) % Mod;
    int L = (tree[Lson].r - tree[Lson].l + 1);
    int R = (tree[Rson].r - tree[Rson].l + 1);
    tree[Lson].sum = (tree[Lson].sum + L * tree[k].addtag) % Mod;
    tree[Rson].sum = (tree[Rson].sum + R * tree[k].addtag) % Mod;
    tree[k].addtag = 0, tree[k].multag = 1;
}

 

  

0x03

区间乘法和区间加法的更新都和线段树版本1异曲同工

所以不再进行讲解

直接给出代码

区间加法更新

inline void update_add(int k) {
    if(tree[k].l >= x && tree[k].r <= y) {
        tree[k].sum = (tree[k].sum + (tree[k].r - tree[k].l + 1) * z) % Mod;
        tree[k].addtag = (tree[k].addtag + z) % Mod;
        return ;
    }
    pushdown(k);
    int mid = (tree[k].l + tree[k].r) >> 1;
    if(mid >= x) update_add(Lson);
    if(mid < y) update_add(Rson);
    tree[k].sum = (tree[Lson].sum + tree[Rson].sum) % Mod;
}

 

  

区间乘法更新

inline void update_mul(int k) {
    if(tree[k].l >= x && tree[k].r <= y) {
        tree[k].sum = tree[k].sum * z % Mod;
        tree[k].addtag = tree[k].addtag * z % Mod;
        tree[k].multag = tree[k].multag * z % Mod;
        return ;
    }
    pushdown(k);
    int mid = (tree[k].l + tree[k].r) >> 1;
    if(mid >= x) update_mul(Lson);
    if(mid < y) update_mul(Rson);
    tree[k].sum = (tree[Lson].sum + tree[Rson].sum) % Mod;
}

 

0x04

还是放上完整的代码

#include <iostream>
#include <cstdio>
#define Lson (k << 1)
#define Rson (k << 1) + 1

typedef long long LL;
const int maxn = 4e5+3;
LL n, m, Mod, x, y, z, c;
struct node {
    LL l, r, sum, addtag, multag;
}tree[maxn];
LL xx, f; char ch;
inline LL read() {
    xx = 0, f = 1; ch = getchar();
    while (ch < '0' || ch > '9') {
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while (ch <= '9' && ch >= '0') {
        xx = xx * 10 + ch - '0';
        ch = getchar();
    }
    return xx * f;
}
inline void build(int k, int ll, int rr) {
    tree[k].l = ll, tree[k].r = rr;
    tree[k].addtag = 0, tree[k].multag = 1;
    if(tree[k].l == tree[k].r) {
        tree[k].sum = read();
        tree[k].sum %= Mod;
        return ;
    }
    int mid = (tree[k].l + tree[k].r) >> 1;
    build(Lson, tree[k].l, mid);
    build(Rson, mid + 1, tree[k].r);
    tree[k].sum = tree[Lson].sum + tree[Rson].sum;
    tree[k].sum %= Mod;
}
inline void pushdown(int k) {
    tree[Lson].multag = tree[k].multag * tree[Lson].multag % Mod;
    tree[Rson].multag = tree[k].multag * tree[Rson].multag % Mod;
    tree[Lson].addtag = tree[Lson].addtag * tree[k].multag % Mod;
    tree[Rson].addtag = tree[Rson].addtag * tree[k].multag % Mod;
    tree[Lson].sum = tree[Lson].sum * tree[k].multag % Mod;
    tree[Rson].sum = tree[Rson].sum * tree[k].multag % Mod;
    tree[Lson].addtag = (tree[k].addtag + tree[Lson].addtag) % Mod;
    tree[Rson].addtag = (tree[k].addtag + tree[Rson].addtag) % Mod;
    int L = (tree[Lson].r - tree[Lson].l + 1);
    int R = (tree[Rson].r - tree[Rson].l + 1);
    tree[Lson].sum = (tree[Lson].sum + L * tree[k].addtag) % Mod;
    tree[Rson].sum = (tree[Rson].sum + R * tree[k].addtag) % Mod;
    tree[k].addtag = 0, tree[k].multag = 1;
}
inline void update_mul(int k) {
    if(tree[k].l >= x && tree[k].r <= y) {
        tree[k].sum = tree[k].sum * z % Mod;
        tree[k].addtag = tree[k].addtag * z % Mod;
        tree[k].multag = tree[k].multag * z % Mod;
        return ;
    }
    pushdown(k);
    int mid = (tree[k].l + tree[k].r) >> 1;
    if(mid >= x) update_mul(Lson);
    if(mid < y) update_mul(Rson);
    tree[k].sum = (tree[Lson].sum + tree[Rson].sum) % Mod;
}
inline void update_add(int k) {
    if(tree[k].l >= x && tree[k].r <= y) {
        tree[k].sum = (tree[k].sum + (tree[k].r - tree[k].l + 1) * z) % Mod;
        tree[k].addtag = (tree[k].addtag + z) % Mod;
        return ;
    }
    pushdown(k);
    int mid = (tree[k].l + tree[k].r) >> 1;
    if(mid >= x) update_add(Lson);
    if(mid < y) update_add(Rson);
    tree[k].sum = (tree[Lson].sum + tree[Rson].sum) % Mod;
}
inline LL check(int k) {
    if(tree[k].l > y || tree[k].r < x) return 0;
    if(tree[k].l >= x && tree[k].r <= y) {
        return tree[k].sum % Mod;
    }
    pushdown(k);
    int mid = (tree[k].l + tree[k].r) >> 1;
    return (check(Lson) + check(Rson)) % Mod;
}

int main() {
    n = read(), m = read(), Mod = read();
    build(1, 1, n);
    for(int i=1; i<=m; i++) {
        c = read();
        switch(c) {
            case 1:
                x = read(), y = read(), z = read();
                update_mul(1);
                break;
            case 2:
                x = read(), y = read(), z = read();
                update_add(1);
                break;
            case 3:
                x = read(), y = read();
                printf("%lld\n", check(1));
        }
    }
}

 

0x05

还有很多其他类型的线段树,比如超哥线段树、吉司机线段树、zkw线段树什么的。这里不再深入讲解。

感兴趣的同学可以自行百度

 

posted @ 2018-08-06 17:37  Mystical-W  阅读(201)  评论(0编辑  收藏  举报