线段树初期汇总

浅谈线段树及其应用

———————————————————————————————————————————————————————————————————————————————————————————

概要

本文介绍了线段树的基本原理和实现,举了一些典型的可以用线段树解决的问题并进行分析和解答。

本文代码适用于c++

本文代码中的通用宏定义

#define ll long long
#define ls(p) (p<<1)
#define rs(p) (p<<1|1)
#define mid ((l+r)>>1)

线段树的简单理解及实现

问题的引入:

下面考虑这样一个问题:给定一个数组a[N], 多次查询 [ l , r ] 的区间和。

当然我们可以用树状数组实现快速查询,但是如果我们增加多次对 [ l, r ] 的区间修改呢?树状数组的单点修改会使复杂度过高,从而TLE。

为了解决对区间的修改与查询的操作,我们引入线段树。

线段树的原理:

简单说,线段树就是通过将长度非一的区间二分递归以维护区间特征值的方法。

如下图,对于给定的数组a[1]--a[5] :

  1. 令d [ i ] 存储某一区间的特征值(这里以区间和为例子),d[1]存储的是 [ 1 , 5 ] 的区间和,d[2]存储的是 [ 1 , 3 ] 的区间和,d[3]存储的是 [ 4 , 5 ] 的区间和。

  2. 在图中的树中d[i]的左右子节点分别为 d [ i * 2 ] 和 d [ i * 2 + 1 ] 。

  3. d[i]存储 [ l , r ] 的特征值,那么左右节点分别存储 [ l , mid ] , [ mid+1 ,r ]的特征值( mid=(l+r)/2 )

代码实现

注意:由于线段树的存储方式,d的长度是a的长度的至少四倍

上传
void push_up(int p)//根据左右子节点更新d[p];
{
    d[p] = d[ls(p)] + d[rs(p)];
}
建树
void build(int p, int l, int r)//建立[l,r]的线段树,当前访问的节点为p
{
    t[p] = 0;
    if (l == r) return d[p] = a[l], void();
    build(ls(p), l, mid), build(rs(p), mid + 1, r);
    push_up(p);
}
查询
ll ask(int p, int l, int r, int L, int R)//查询[L,R]的区间和,当前访问节点p,访问区间[l,r];
{
    if (l > R || r < L)
        return 0;
    if (L <= l && r <= R)
        return d[p];
    {
        //push_down(p, l, r);暂时先不管它
        return ask(ls(p), l, mid, L, R) + ask(rs(p), mid + 1, r, L, R);
    }
}    
修改

如果我们每一次对区间的修改都进行到叶节点的话,时间复杂度过高。

我们设法当且仅当查询/访问到某个节点时,我们才对它修改。

//利用t[p]来记录t[p]是否作了修改
void add(int p, int l, int r, ll x)//对节点p进行修改:每个元素加x;当前访问[l,r]区间;
{
    d[p] = (d[p] + x * (r - l + 1));
    t[p] = t[p] + x;
}
void push_down(int p, int l, int r)//下传,对d[p]的左右子节点进行修改;当前访问区间[l,r];
{
    if (!t[p]) return ;//t[p]标记为0,则不对d[p]作修改;
    add(ls(p), l, mid, t[p]), add(rs(p), mid + 1, r, t[p]);//分别修改d[p]的左右子节点;
    t[p] = 0;//清空t[p]标记;
}

void add(int p, int l, int r, int L, int R, ll x)//当前访问节点p,访问区间[l,r];对[L,R]区间中每个元素加x;
{
    if (l > R || r < L)//访问区间[l,r]与修改区间[L,R]没有交集;
        return ;
    if (L <= l && r <= R)//访问区间[l,r]包含于修改区间[L,R];
    {
        add(p, l, r, x);
        return;
    }
    else
    {
        push_down(p, l, r);//先对左右子节点进行修改
        add(ls(p), l, mid, L, R, x), add(rs(p), mid + 1, r, L, R, x);
        push_up(p);
    }
}

现在我们知道为什么在ask函数中设置push_down了。

例题

均来自Luogu

———————————————————————————————————————————————————————————————————————————————————————————

P3372 【模板】线段树 1

———————————————————————————————————————————————————————————————————————————————————————————
分析

直接由上面的代码即可得出。

代码实现
#include<bits/stdc++.h>
#define ll long long
#define mid ((l+r)>>1)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
const int maxn = 1e5+10;
ll a[maxn], d[maxn << 2], t[maxn << 2];
int n, m;
using namespace std;
void push_up(int p)//上传,根据左右子节点更新d[p];
{
    d[p] = d[ls(p)] + d[rs(p)];
}
void build(int p, int l, int r)//建立[l,r]的线段树,当前访问的节点为p
{
    t[p] = 0;
    if (l == r) return d[p] = a[l], void();
    build(ls(p), l, mid), build(rs(p), mid + 1, r);
    push_up(p);
}
//利用t[p]来记录t[p]是否作了修改
void add(int p, int l, int r, ll x)//对节点p进行修改:每个元素加x;当前访问[l,r]区间;
{
    d[p] = (d[p] + x * (r - l + 1));
    t[p] = t[p] + x;
}
void push_down(int p, int l, int r)//下传,对d[p]的左右子节点进行修改;当前访问区间[l,r];
{
    if (!t[p]) return ;//t[p]标记为0,则不对d[p]作修改;
    add(ls(p), l, mid, t[p]), add(rs(p), mid + 1, r, t[p]);//分别修改d[p]的左右子节点;
    t[p] = 0;//清空t[p]标记;
}

void add(int p, int l, int r, int L, int R, ll x)//当前访问节点p,访问区间[l,r];对[L,R]区间中每个元素加x;
{
    if (l > R || r < L)//访问区间[l,r]与修改区间[L,R]没有交集;
        return ;
    if (L <= l && r <= R)//访问区间[l,r]包含于修改区间[L,R];
    {
        add(p, l, r, x);
        return;
    }
    else
    {
        push_down(p, l, r);//先对左右子节点进行修改
        add(ls(p), l, mid, L, R, x), add(rs(p), mid + 1, r, L, R, x);
        push_up(p);
    }
}
ll ask(int p, int l, int r, int L, int R)//查询[L,R]的区间和,当前访问节点p,访问区间[l,r];
{
    if (l > R || r < L)
        return 0;
    if (L <= l && r <= R)
        return d[p];
    {
        push_down(p, l, r);暂时先不管它
        return ask(ls(p), l, mid, L, R) + ask(rs(p), mid + 1, r, L, R);
    }
}
int main()
{
    ios::sync_with_stdio(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    build(1, 1, n);
    while (m--)
    {
        ll op, x, y, k;
        cin >> op >> x >> y;
        if (op == 1)
        {
            cin >> k;
            add(1, 1, n, x, y, k);
        }
        else
            cout << ask(1, 1, n, x, y) << "\n";
    }
    return 0;
}

———————————————————————————————————————————————————————————————————————————————————————————

P3373 【模板】线段树 2

———————————————————————————————————————————————————————————————————————————————————————————

分析
  1. 本题与上题增加了对区间乘法修改的需求

  2. 考虑如何实现区间乘法:显然 [ l , r ] 中每一个元素乘k,那么区间和乘k;

  3. 由于本题需要实现区间加法的修改,我们还需要考虑下传时加法和乘法的顺序

容易想到,如果修改操作中先乘再加,那么代码中先乘再加

如果修改操作中先加再乘,我们考虑和上面代码中的顺序保持一致:不难发现 (d+x)y=dy+xy

因此我们将乘法标记初始化为1,加法标记初始化为0;

先让加法标记和区间和乘上乘法标记,再让区间和与加法标记相加即可。

代码实现

只需综合分析,对上一题代码稍作修改

#include<bits/stdc++.h>
#define ll long long
#define mid ((l+r)>>1)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
const int maxn = 1e5+10;
ll a[maxn], d[maxn << 2], t[maxn << 2], t1[maxn << 2];
int n, m, pp ;
using namespace std;
void push_up(int p)
{
    d[p] = (d[ls(p)] + d[rs(p)]) % pp;
}
void build(int p, int l, int r)
{
    t[p] = 0;
    t1[p] = 1;
    if (l == r) return d[p] = a[l], void();
    build(ls(p), l, mid), build(rs(p), mid + 1, r);
    push_up(p);
}
void add(int p, int l, int r, ll x)
{
    d[p] = (d[p] + x * ((r - l + 1) % pp)) % pp;
    t[p] = (t[p] + x) % pp;
}
void mul(int p, int l, int r, ll x)//区间乘法
{
    d[p] = (d[p] * (x % pp)) % pp;
    t[p] = ( t[p] * (x % pp )) % pp;
    t1[p] = (t1[p] * (x % pp )) % pp;
}
void push_down(int p, int l, int r)
{
    if ((t1[p] == 1) && !t[p]) return ;
    mul(ls(p), l, mid, t1[p]), mul(rs(p), mid + 1, r, t1[p]);//下传乘法标记
    add(ls(p), l, mid, t[p]), add(rs(p), mid + 1, r, t[p]);//下传加法标记
    t[p] = 0;
    t1[p] = 1;
}
void add(int p, int l, int r, int L, int R, ll x)
{
    if (l > R || r < L)
        return ;
    if (L <= l && r <= R)
    {
        add(p, l, r, x);
        return;
    }
    else
    {
        push_down(p, l, r);
        add(ls(p), l, mid, L, R, x), add(rs(p), mid + 1, r, L, R, x);
        push_up(p);
    }
}
void mul(int p, int l, int r, int L, int R, ll x)//区间乘法
{
    if (l > R || r < L)
        return ;
    if (L <= l && r <= R)
    {
        mul(p, l, r, x);
        return;
    }
    else
    {
        push_down(p, l, r);
        mul(ls(p), l, mid, L, R, x), mul(rs(p), mid + 1, r, L, R, x);
        push_up(p);
    }
}
ll ask(int p, int l, int r, int L, int R)
{
    if (l > R || r < L)
        return 0;
    if (L <= l && r <= R)
        return d[p];
    {
        push_down(p, l, r);
        return (ask(ls(p), l, mid, L, R) + ask(rs(p), mid + 1, r, L, R)) % pp;
    }
}
int main()
{
    ios::sync_with_stdio(0);
    cin >> n >> m >> pp;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    build(1, 1, n);
    while (m--)
    {
        ll op, x, y, k;
        cin >> op >> x >> y;
        if (op == 1)
        {
            cin >> k;
            mul(1, 1, n, x, y, k);
        }
        else if (op == 2)
        {
            cin >> k;
            add(1, 1, n, x, y, k);
        }
        else
            cout << ask(1, 1, n, x, y) % pp << "\n";
    }
    return 0;
}

———————————————————————————————————————————————————————————————————————————————————————————

P1471 方差

———————————————————————————————————————————————————————————————————————————————————————————

分析
  1. 需要维护的区间特征值:平均数和方差

  2. 平均数:区间和/区间长度,区间和我们已经在前面的例子中说明了

  3. 方差:直接将方差作为特征值存储到线段树中,对某一区间的修改将变得十分复杂,我们尝试对方差公式进行恒等变形,看看能不能用几个易于维护的特征值表示方差:

由上图,我们只需要再维护区间平方和就可以求得区间的方差

  1. 平方和:上传与区间和一致

区间修改如图

借助区间和即可

注:应先修改平方和,再修改区间和

代码实现
#include<bits/stdc++.h>
#define ls(p) (p<<1)
#define rs(p) (p<<1|1)
#define mid ((l+r)>>1)
using namespace std;
const int N = 1e5+10;
int n, m;
double a[N], d[N << 2], d2[N << 2], t[N << 2];//d1存储区间和,d2存储区间平方和,t作为修改标记
void push_up(int p)
{
    d[p] = d[ls(p)] + d[rs(p)];
    d2[p] = d2[ls(p)] + d2[rs(p)];
}
void build(int p, int l, int r)
{
    t[p] = 0;
    if (l == r)
    {
        d[p] = a[l];
        d2[p] = a[l] * a[l];
        return ;
    }
    build(ls(p), l, mid), build(rs(p), mid + 1, r);
    push_up(p);
}
void add(int p, int l, int r, double x)
{

    d2[p] +=   2 * x * d[p] + x * x * (r - l + 1) ;//修改区间平方和
    d[p] +=  x * (r - l + 1);//修改区间和
    t[p] += x;

}
void push_down(int p, int l, int r)
{
    if (!t[p])
        return ;
    d2[ls(p)] +=  2 * t[p] * d[ls(p)] + t[p] * t[p] * (mid - l + 1) ;
    d2[rs(p)] +=  2 * t[p] * d[rs(p)] + t[p] * t[p] * (r - mid) ;
    d[ls(p)] += t[p] * (mid - l + 1), d[rs(p)] += t[p] * (r - mid);
    t[ls(p)] += t[p], t[rs(p)] += t[p];
    t[p] = 0;
}
void add(int p, int l, int r, double x, int L, int R)
{
    if (l > R || L > r)
        return ;
    if (L <= l && r <= R)
    {
        add(p, l, r, x);
        return ;
    }
    push_down(p, l, r);
    add(ls(p), l, mid, x, L, R), add(rs(p), mid + 1, r, x, L, R);
    push_up(p);


}
double ask1(int p, int l, int r, int L, int R)//查询[l,r]的区间和
{
    if (l > R || L > r)
        return 0;
    if (L <= l && r <= R)
        return d[p];
    push_down(p, l, r);
    return ask1(ls(p), l, mid,  L, R) + ask1(rs(p), mid + 1, r,  L, R);

}
double ask2(int p, int l, int r, int L, int R)//查询[l,r]的区间平方和
{
    if (l > R || L > r)
        return 0;
    if (L <= l && r <= R)
        return d2[p];
    push_down(p, l, r);
    return ask2(ls(p), l, mid,  L, R) + ask2(rs(p), mid + 1, r,  L, R);
}
int main()
{
    ios::sync_with_stdio(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    build(1, 1, n);
    while (m--)
    {
        int op, x, y;
        double  k;
        cin >> op >> x >> y;
        if (op == 1)
        {
            cin >> k;
            add(1, 1, n, k, x, y);
        }
        else if (op == 2)
        {
            printf("%.4lf\n", ask1(1, 1, n, x, y) / (y - x + 1));
        }
        else
        {
            double ans = (ask2(1, 1, n, x, y)) / (y - x + 1) -
            (ask1(1, 1, n, x, y) / (y - x + 1)) * (ask1(1, 1, n, x, y) / (y - x + 1));
            printf("%.4lf\n", ans);
        }
    }
    return 0;
}

———————————————————————————————————————————————————————————————————————————————————————————

P4513 小白逛公园

———————————————————————————————————————————————————————————————————————————————————————————

分析
  1. 查询连续子段和的最大值(不得为空):

d[p]对应区间的最大子段和只有如图的几种情况

  1. 由1.我们知道:至少要维护区间的最大前缀和,最大后缀和,最大连续子段和这几个特征值

  2. 考虑最大前/后缀和的维护:

d[p]对应区间的最大前/后缀和只有如图的两种情况

代码实现

注:查询函数ask与上文有些许不同

查询区间(阴影)与访问区间(d[p])无非如图的三种关系(在阴影包含于d[p]时),如果照搬上面三题中ask函数的判断逻辑(查询区间超出访问区间的部分将不再计算)那么第三种情况中求出的是查询区间分割成两部分后的两个区间最大连续子段和,不能转化为原查询区间的连续子段和,所以我们对ask函数的判断逻辑作出修改(详见代码中ask)

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ls(p) (p << 1)
#define rs(p) (p << 1 | 1)
#define mid ((l + r) >> 1)
const int N = 1e5 + 10;
const ll INF = 1e18;

ll a[N];
//用结构体Node存储区间特征值
struct Node
{
    ll maxl, maxx, maxr, sum;
} d[N << 2];

int n, m;
//合并left的特征值 和 right的特征值 , 得到新的特征值
Node merge(Node left, Node right)
{
    Node res;
    res.sum = left.sum + right.sum;
    res.maxl = max(left.maxl, left.sum + right.maxl);
    res.maxr = max(right.maxr, right.sum + left.maxr);
    res.maxx = max({left.maxx, right.maxx, left.maxr + right.maxl});
    return res;
}

void push_up(int p)
{
    d[p] = merge(d[ls(p)], d[rs(p)]);
}

void build(int p, int l, int r)
{
    if (l == r)
    {
        d[p].sum = a[l];
        d[p].maxl = a[l];
        d[p].maxr = a[l];
        d[p].maxx = a[l];
        return;
    }
    build(ls(p), l, mid);
    build(rs(p), mid + 1, r);
    push_up(p);
}

void update(int p, int l, int r, int pos, ll val)//将a[pos]更改为val
{
    if (l == r)
    {
        a[l] = val;
        d[p].sum = val;
        d[p].maxl = val;
        d[p].maxr = val;
        d[p].maxx = val;
        return;
    }
    if (pos <= mid)//pos在左节点对应区间
    {
        update(ls(p), l, mid, pos, val);
    }
    else//pos在右节点对应区间
    {
        update(rs(p), mid + 1, r, pos, val);
    }
    push_up(p);
}

Node ask(int p, int l, int r, int L, int R)
{
    // 完全包含
    if (L <= l && r <= R)
    {
        return d[p];
    }

    // 只在左子树
    if (R <= mid)
    {
        return ask(ls(p), l, mid, L, R);
    }

    // 只在右子树
    if (L >= mid+1)
    {
        return ask(rs(p), mid + 1, r, L, R);
    }

    // 跨越中点:分别查询[L, mid]和[mid+1, R]
    Node left = ask(ls(p), l, mid, L, mid);
    Node right = ask(rs(p), mid + 1, r, mid + 1, R);
    return merge(left, right);
}

int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);

    cin >> n >> m;
    for (int i = 1; i <= n; i++)
    {
        cin >> a[i];
    }
    build(1, 1, n);

    while (m--)
    {
        int k, x, y;
        cin >> k >> x >> y;

        if (k == 1)
        {
            if (x > y) swap(x, y);
            Node res = ask(1, 1, n, x, y);
            cout << res.maxx << "\n";
        }
        else
        {
            update(1, 1, n, x, y);
        }
    }
    return 0;
}

———————————————————————————————————————————————————————————————————————————————————————————

P7492 [传智杯 #3 决赛] 序列

分析

1. 需要对区间修改,并且需要区间最大连续子段和查询

注:这里子段可以为空,此时子段和为0;

笔者因为没考虑到这一点调试代码浪费好长时间

  1. 区间最大非空子段和的查询同上一道题,若最大非空子段和为负数,那么输出0即可满足本题要求

下面考虑如何快速对区间进行按位或的操作:

  1. 如果k=0,相当于没有修改

  2. 如果a[x]|k=a[x],相当于没有修改

  3. 按位或操作最多执行到每一位都是1,操作次数有上限,综合3,4我们可以跳过许多不必要的线段树的遍历/修改

  4. 放到区间里面来看,如果某区间的所有元素的在k中为1的位数也都为1的话,那么不必进行修改

  5. 第6.的判断我们可以借助按位与来完成:区间特征值一位为1,当且仅当区间中所有元素的该位是1

struct Node
{
    ll max1, max2, maxn, sum;
    ll and_val;  // 区间按位与
} d[N << 2];
void push_up(int p)
{
    ...
    d[p].and_val = d[ls(p)].and_val & d[rs(p)].and_val;  // 维护按位与
}
代码实现

这里ask函数如同上一题

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define ls(p) (p<<1)
#define rs(p) (p<<1|1)
#define mid ((l+r)>>1)
const int N = 1e5+10;
const ll INF = -1e15+10;

struct Node
{
    ll max1, max2, maxn, sum;
    ll and_val;  // 区间按位与
} d[N << 2];

ll a[N];
int n, m;

void push_up(int p)
{
    d[p].sum = d[ls(p)].sum + d[rs(p)].sum;
    d[p].max1 = max(d[ls(p)].max1, d[ls(p)].sum + d[rs(p)].max1);
    d[p].max2 = max(d[rs(p)].max2, d[ls(p)].max2 + d[rs(p)].sum);
    d[p].maxn = max({d[ls(p)].maxn, d[rs(p)].maxn, d[ls(p)].max2 + d[rs(p)].max1});
    d[p].and_val = d[ls(p)].and_val & d[rs(p)].and_val;  // 维护按位与
}

void build(int p, int l, int r)
{
    if (l == r)
    {
        d[p] = {a[l], a[l], a[l], a[l], a[l]};
        return;
    }
    build(ls(p), l, mid);
    build(rs(p), mid + 1, r);
    push_up(p);
}

void update(int p, int l, int r, int L, int R, ll k)
{
    if (r < L || l > R) return;

    if (l == r)    // 叶子节点直接更新
    {
        if ((a[l] | k) == a[l]) return;
        a[l] |= k;
        d[p] = {a[l], a[l], a[l], a[l], a[l]};
        return;
    }


    if (L <= l && r <= R)
    {
        if (k == 0) return;
        if ((k & ~d[p].and_val) == 0) return;    // 剪枝:k的1位在区间内已全为1
    }

    if (R <= mid) update(ls(p), l, mid, L, R, k);
    else if (L > mid) update(rs(p), mid + 1, r, L, R, k);
    else
    {
        update(ls(p), l, mid, L, R, k);
        update(rs(p), mid + 1, r, L, R, k);
    }
    push_up(p);
}

Node ask(int p, int l, int r, int L, int R)
{
    if (L <= l && r <= R) return d[p];

    if (R <= mid) return ask(ls(p), l, mid, L, R);
    if (L > mid) return ask(rs(p), mid + 1, r, L, R);

    Node left = ask(ls(p), l, mid, L, mid);
    Node right = ask(rs(p), mid + 1, r, mid + 1, R);
    Node res;
    res.sum = left.sum + right.sum;
    res.max1 = max(left.max1, left.sum + right.max1);
    res.max2 = max(right.max2, left.max2 + right.sum);
    res.maxn = max({left.maxn, right.maxn, left.max2 + right.max1});
    return res;
}

int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    d[0] = {INF, INF, INF, 0, 0};  // 初始化空节点
    for (int i = 1; i <= n; ++i) cin >> a[i];
    build(1, 1, n);

    while (m--)
    {
        ll op, l, r, k;
        cin >> op;
        if (op == 1)
        {
            cin >> l >> r;
            if (ask(1, 1, n, l, r).maxn >= 0)
                cout << ask(1, 1, n, l, r).maxn << '\n';
            else
                cout << 0 << "\n";
        }
        else
        {
            cin >> l >> r >> k;
            update(1, 1, n, l, r, k);  // 区间更新
        }
    }
    return 0;
}

———————————————————————————————————————————————————————————————————————————————————————————

硬是要总结

线段树十分利好多次查询区间特征值的问题,尤其当特征值满足结合律,可差分时。

当遇到区间特征值的修改与查询时,不妨考虑线段树。

参考文献

  1. OI-WIKI OI WIKI 线段树

  2. 《统计的力量》——张昆玮github.com

  3. 《深入浅出程序设计(进阶篇)》——洛谷

  4. 个人博客树状数组和线段树基础 - Ahui2667d - 博客园

posted @ 2025-12-01 22:07  Ahui2667d  阅读(0)  评论(0)    收藏  举报