【高级·数据结构】线段树
前言
线段树是非常重要的一个数据结构!可以求解区间 $[l,r]$ 的最值问题等区间查询。
而且大部分ST 表,树状数组能解决的问题线段树都能解决。
普通线段树
P3372 【模板】线段树 1
有一个序列 $a$,现在有两个操作:
1 l r k
表示将 $[l,r]$ 区间的元素全部加上 $k$。2 l r
表示查询 $[l,r]$ 区间的和。
暴力思想
暴力的话很好处理,每次暴力枚举 $i$ 从 $l$ 到 $r$ 这段区间,给每个元素加上 $k$ 即可。
然后查询的时候暴力枚举 $i$ 从 $l$ 到 $r$ 这段区间,用 $res$ 来统计元素值的和即可。
但是这样的话对于数据规模大的时候显然会超时。
线段树的实现
那么这个时候我们就可以考虑线段树求解了!首先我们要了解线段树是什么。
线段树其实就是一种二叉树,对于一个线段我们用二叉树来表示,每个节点会有自己的区间。
对于根节点来说,管理区间就是 $[1,n]$ 啦,而它的左孩子和右孩子则管理的是 $[1,mid]$ 和 $[mid+1,n]$ ,这里的 $mid$ 指的就是 $\frac{1+n}{2}$ 可以发现,节点管理的区间其实是二分下去的,这就是线段树的建树过程。
void build(ll x,ll l,ll r){
tr[x].l=l;
tr[x].r=r;
if(l==r){
tr[x].val=arr[l];
return ;
}
ll m=l+r>>1;
build(lc,l,m);
build(rc,m+1,r);
pushUp(x);
return ;
}
我们定义了一个结构体,储存该节点管理的区间 $l,r$ ,当 $l=r$ 的时候表示到达了叶子节点,那么当前管理的值 $val$ 就应该等于数组中的初始值,当然如果没有初始值直接 $\text{return}$ 就好。
之后二分下去即可,但是你会发现有一个 $\text{pushUp}$ 函数,其实这个函数就是合并一个节点的左右两个节点的,本质上来说,线段树建树的过程就是递归,所以会有合并结点的过程,合并什么呢?具体看题目要求什么了,这里以区间求和为例。
void pushUp(ll x){
tr[x].val=tr[lc].val+tr[rc].val;
return ;
}
这个就是 $\text{pushUp}$ 函数了,非常简洁,就是合并左节点和右节点,这里的 $\text{lc}$ 和 $\text{rc}$ 指的是左孩子和右孩子。
那么建树完毕之后就应该有查询了,我们看看主函数:
int main(){
cin>>n>>q;
for(int i=1;i<=n;i++) cin>>arr[i];
build(1,1,n);
while(q--){
cin>>op>>x>>y;
if(op==1){
cin>>k;
update(1,x,y,k);
}
else cout<<ask(1,x,y)<<'\n';
}
return 0;
}
在建完树之后询问,如果当前要更新值,那么就调用 $\text{update}$ 函数,否则输出答案,调用 $\text{ask}$ 函数。
那么 $\text{update}$ 函数怎么写呢?
我们考虑,将一个区间加上 $k$,其实就是把这个区间涂上一个 $len \times k$ 的标记,这里的 $len$ 为该节点管理区间的长度,为什么呢?
该节点是不是管理 $[l,r]$ 的区间,那是不是 $[l,r]$ 的区间中的元素都要加上 $k$?所以管理一个区间的节点就要加上一个区间的 $k$ 值。那么什么时候要加呢?如果这个区间被完全包括在目标区间里面,那么这个节点就要加,这个很容易理解。
但是我要查询的时候是不是还要把这个标记传下去一直累加?所以我们引出了 $\text{lazytag}$ 也就是懒标记来维护,每次加值的时候让值加上 $len \times k$ ,然后让 $\text{lazytag}$ 加上 $k$ 即可,然后我们还要写一个函数让懒标记下传,所以我们新建一个 $\text{pushDown}$ 函数即可。
void update(ll x,ll l,ll r,ll k){
if(l<=tr[x].l && tr[x].r<=r){
tr[x].val+=k*tr[x].len();
tr[x].lazy+=k;
return ;
}
pushDown(x);
ll m=tr[x].l+tr[x].r>>1;
if(l<=m) update(lc,l,r,k);
if(r>m) update(rc,l,r,k);
pushUp(x);
return ;
}
我们先看更新函数,当这个区间被完全包括在目标区间里面,那么这个节点就要加,然后下传懒标记,之后再二分递归即可,记得还要合并!
void pushDown(ll x){
if(!tr[x].lazy) return ;
tr[lc].lazy+=tr[x].lazy;
tr[rc].lazy+=tr[x].lazy;
tr[lc].val+=tr[x].lazy*tr[lc].len();
tr[rc].val+=tr[x].lazy*tr[rc].len();
tr[x].lazy=0;
return ;
}
我们看下传函数,如果现在没有标记就不下传,然后左右节点的懒标记继承父节点的懒标记,因为要继续下传。然后值就加上 $len \times k$ ,我们之前讲过,然后标记传完了,归 $0$ 即可。
好了我们来看询问函数:
ll ask(ll x,ll l,ll r){
if(l<=tr[x].l && tr[x].r<=r) return tr[x].val;
pushDown(x);
ll m=tr[x].l+tr[x].r>>1;
ll ans=0;
if(l<=m) ans+=ask(lc,l,r);
if(r>m) ans+=ask(rc,l,r);
return ans;
}
如果这个区间被完全包括在目标区间里面,那么直接返回值即可,然后下传一遍懒标记。
注意:答案需要在两个区间累加,可能一个区间中会被两个节点分别管理!
之后我们将函数组合到一起即可:
#include<bits/stdc++.h>
#define ll long long
#define lc x<<1
#define rc x<<1|1
using namespace std;
const int N=1e5+5;
struct node{
ll l,r,val,lazy;
ll len(){
return r-l+1;
}
}tr[N<<2];
ll arr[N];
ll n,q,x,y,k,op;
void pushUp(ll x){
tr[x].val=tr[lc].val+tr[rc].val;
return ;
}
void pushDown(ll x){
if(!tr[x].lazy) return ;
tr[lc].lazy+=tr[x].lazy;
tr[rc].lazy+=tr[x].lazy;
tr[lc].val+=tr[x].lazy*tr[lc].len();
tr[rc].val+=tr[x].lazy*tr[rc].len();
tr[x].lazy=0;
return ;
}
void build(ll x,ll l,ll r){
tr[x].l=l;
tr[x].r=r;
if(l==r){
tr[x].val=arr[l];
return ;
}
ll m=l+r>>1;
build(lc,l,m);
build(rc,m+1,r);
pushUp(x);
return ;
}
void update(ll x,ll l,ll r,ll k){
if(l<=tr[x].l && tr[x].r<=r){
tr[x].val+=k*tr[x].len();
tr[x].lazy+=k;
return ;
}
pushDown(x);
ll m=tr[x].l+tr[x].r>>1;
if(l<=m) update(lc,l,r,k);
if(r>m) update(rc,l,r,k);
pushUp(x);
return ;
}
ll ask(ll x,ll l,ll r){
if(l<=tr[x].l && tr[x].r<=r) return tr[x].val;
pushDown(x);
ll m=tr[x].l+tr[x].r>>1;
ll ans=0;
if(l<=m) ans+=ask(lc,l,r);
if(r>m) ans+=ask(rc,l,r);
return ans;
}
int main(){
cin>>n>>q;
for(int i=1;i<=n;i++) cin>>arr[i];
build(1,1,n);
while(q--){
cin>>op>>x>>y;
if(op==1){
cin>>k;
update(1,x,y,k);
}
else cout<<ask(1,x,y)<<'\n';
}
return 0;
}
P3870 [TJOI2009] 开关
问题简述
现有 $n$ 盏灯排成一排,从左到右依次编号为:$1$,$2$,……,$n$。然后依次执行 $m$ 项操作。
操作分为两种:
- 指定一个区间 $[a,b]$,然后改变编号在这个区间内的灯的状态(把开着的灯关上,关着的灯打开);
- 指定一个区间 $[a,b]$,要求你输出这个区间内有多少盏灯是打开的。
灯在初始时都是关着的
思路
我们发现,每次开关灯其实就是对一段区间取反,那么我们用 $val$ 来记录区间内灯开着的数量。
那么每次下传的时候该怎么修改呢?
我们发现,每个灯不是开着就是关着,所以每次去反之后 $val$ 就等于 $len-val$ 了。然后懒标记取反即可。
代码实现
#include<bits/stdc++.h>
#define ll long long
#define lc x<<1
#define rc x<<1|1
using namespace std;
const int N=5e5;
struct node{
ll val,l,r,lazy;
ll len(){
return r-l+1;
}
}tr[N<<2];
ll n,q,op;
void pushUp(ll x){
tr[x].val=tr[lc].val+tr[rc].val;
}
void pushDown(ll x){
if(!tr[x].lazy) return ;
tr[lc].val=tr[lc].len()-tr[lc].val;
tr[rc].val=tr[rc].len()-tr[rc].val;
tr[lc].lazy^=1;
tr[rc].lazy^=1;
tr[x].lazy=0;
}
void buildTree(ll x,ll l,ll r){
tr[x].l=l;
tr[x].r=r;
if(l==r){
tr[x].val=0;
return ;
}
ll m=(l+r)>>1;
buildTree(lc,l,m);
buildTree(rc,m+1,r);
pushUp(x);
}
void update(ll x,ll l,ll r){
if(l<=tr[x].l && tr[x].r<=r){
tr[x].val=tr[x].len()-tr[x].val;
tr[x].lazy^=1;
return ;
}
pushDown(x);
ll m=(tr[x].l+tr[x].r)>>1;
if(l<=m) update(lc,l,r);
if(r>m) update(rc,l,r);
pushUp(x);
}
ll query(ll x,ll l,ll r){
if(l<=tr[x].l && tr[x].r<=r) return tr[x].val;
pushDown(x);
ll m=(tr[x].l+tr[x].r)>>1,sum=0;
if(l<=m) sum+=query(lc,l,r);
if(r>m) sum+=query(rc,l,r);
return sum;
}
int main(){
cin>>n>>q;
buildTree(1,1,n);
while(q--){
ll l,r;
cin>>op>>l>>r;
if(op==0){
update(1,l,r);
}
else{
cout<<query(1,l,r)<<endl;
}
}
return 0;
}
多懒标记下放线段树
P1253 扶苏的问题
问题简述
给定一个长度为 $n$ 的序列 $a$,要求支持如下三个操作:
- 给定区间 $[l, r]$,将区间内每个数都修改为 $x$。
- 给定区间 $[l, r]$,将区间内每个数都加上 $x$。
- 给定区间 $[l, r]$,求区间内的最大值。
思路
其实多懒标记下放,就是有很多个懒标记,在 $\text{pushdown}$ 函数中下方多个懒标记,就比如这道题。
首先我们看第二个操作和第三个操作,都很简单,所以我们直接维护区间最大值即可。
void pushUp(ll x){
tr[x].maxn=max(tr[lc].maxn,tr[rc].maxn);
}
void update_add(ll x,ll l,ll r,ll L,ll R,ll k){
if(l<=L && R<=r){
pushDown_set(x);
tr[x].maxn+=k;
tr[x].lazy_max+=k;
return ;
}
pushDown(x);
ll m=L+R>>1;
if(l<=m) update_add(lc,l,r,L,m,k);
if(r>m) update_add(rc,l,r,m+1,R,k);
pushUp(x);
return ;
}
而第一个操作怎么做呢?我们产尝试维护一个 $\text{lazy_set}$ 的懒标记,然后在 $\text{pushdown}$ 函数的时候需要覆盖 $\text{lazy_max}$ 懒标记,直接设为 $0$ 即可,那么最大值咋办?直接设为 $\text{lazy_set}$ 就好了呀,然后就正常操作即可。
void pushDown_set(ll x){
if(tr[x].lazy_set!=-1e18){
// cout<<"HHH"<<endl;
tr[lc].lazy_max=0;
tr[rc].lazy_max=0;
tr[lc].maxn=tr[x].lazy_set;
tr[rc].maxn=tr[x].lazy_set;
tr[lc].lazy_set=tr[x].lazy_set;
tr[rc].lazy_set=tr[x].lazy_set;
tr[x].lazy_set=-1e18;
}
}
void pushDown(ll x){
pushDown_set(x);
if(tr[x].lazy_max){
pushDown_set(x);
tr[lc].lazy_max+=tr[x].lazy_max;
tr[rc].lazy_max+=tr[x].lazy_max;
tr[lc].maxn+=tr[x].lazy_max;
tr[rc].maxn+=tr[x].lazy_max;
tr[x].lazy_max=0;
return ;
}
return ;
}
最后你可能发现,修改函数也分为 $\text{add}$ 和 $\text{set}$,区别就在懒标记的修改:
if(l<=L && R<=r){
pushDown_set(x);
tr[x].maxn+=k;
tr[x].lazy_max+=k;
return ;
}
if(l<=L && R<=r){
tr[x].maxn=k;
tr[x].lazy_max=0;
tr[x].lazy_set=k;
return ;
}
上面是加函数,下面是修改函数,发现修改的不同,因为修改的时候会覆盖之前的操作,所以会不一样,然后我们就可以写出代码了:
#include<bits/stdc++.h>
#define ll long long
#define lc x<<1
#define rc x<<1|1
using namespace std;
const int N=1e7+5;
struct node{
ll maxn,lazy_max,lazy_set;
}tr[N<<2];
ll arr[N];
ll n,q,op,l,r,k,res;
inline ll read(){
ll x=0,f=1;
char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;
}
void pushUp(ll x){
tr[x].maxn=max(tr[lc].maxn,tr[rc].maxn);
}
void pushDown_set(ll x){
if(tr[x].lazy_set!=-1e18){
// cout<<"HHH"<<endl;
tr[lc].lazy_max=0;
tr[rc].lazy_max=0;
tr[lc].maxn=tr[x].lazy_set;
tr[rc].maxn=tr[x].lazy_set;
tr[lc].lazy_set=tr[x].lazy_set;
tr[rc].lazy_set=tr[x].lazy_set;
tr[x].lazy_set=-1e18;
}
}
void pushDown(ll x){
pushDown_set(x);
if(tr[x].lazy_max){
pushDown_set(x);
tr[lc].lazy_max+=tr[x].lazy_max;
tr[rc].lazy_max+=tr[x].lazy_max;
tr[lc].maxn+=tr[x].lazy_max;
tr[rc].maxn+=tr[x].lazy_max;
tr[x].lazy_max=0;
return ;
}
return ;
}
void build(ll x,ll l,ll r){
if(l==r){
tr[x].maxn=arr[l];
tr[x].lazy_set=-1e18;
tr[x].lazy_max=0;
return ;
}
ll m=l+r>>1;
build(lc,l,m);
build(rc,m+1,r);
pushUp(x);
return ;
}
void update_add(ll x,ll l,ll r,ll L,ll R,ll k){
if(l<=L && R<=r){
pushDown_set(x);
tr[x].maxn+=k;
tr[x].lazy_max+=k;
return ;
}
pushDown(x);
ll m=L+R>>1;
if(l<=m) update_add(lc,l,r,L,m,k);
if(r>m) update_add(rc,l,r,m+1,R,k);
pushUp(x);
return ;
}
void update_set(ll x,ll l,ll r,ll L,ll R,ll k){
if(l<=L && R<=r){
tr[x].maxn=k;
tr[x].lazy_max=0;
tr[x].lazy_set=k;
return ;
}
pushDown(x);
ll m=L+R>>1;
if(l<=m) update_set(lc,l,r,L,m,k);
if(r>m) update_set(rc,l,r,m+1,R,k);
pushUp(x);
return ;
}
ll ask(ll x,ll l,ll r,ll L,ll R){
if(l<=L && R<=r) return tr[x].maxn;
pushDown(x);
ll m=L+R>>1;
ll ans=-1e18;
if(l<=m) ans=max(ans,ask(lc,l,r,L,m));
if(r>m) ans=max(ans,ask(rc,l,r,m+1,R));
return ans;
}
int main(){
n=read();
q=read();
for(int i=1;i<=n;i++) arr[i]=read();
build(1,1,n);
for(int i=1;i<=n*4;i++) tr[i].lazy_set=-1e18;
while(q--){
op=read();
l=read();
r=read();
if(op==1){
k=read();
update_set(1,l,r,1,n,k);
}
else if(op==2){
k=read();
update_add(1,l,r,1,n,k);
}
else{
cout<<ask(1,l,r,1,n)<<endl;
}
}
return 0;
}
多区间合并线段树
P4513 小白逛公园
题目简述
给定一个序列,每次有两次操作:
1 l r
求区间 $[l,r]$ 的最大子段和
2 p s
把第 $p$ 个元素修改为 $s$
思路
嗯,单点修改就不说了,直接秒了,但是 $\text{update}$ 不能直接套模板,要看看最大子段和的求法。
那么最大子段和怎么求呢?我们先看一下基本的结构体:
struct node{
ll l,r,val,lmax,rmax,sum;
}tr[N<<2];
其中 $sum$ 就是区间和,$lmax$ 是左儿子的最大子段和,$rmax$ 相反,最后 $val$ 为区间 $[l,r]$ 的最大子段和。
那么你可能会疑惑:直接求最大子段和不就行了吗,为什么还要求左右两边的呢?
因为我们可以发现,区间 $[l,r]$ 的最大子段和不等于 $[l,mid]$ 的最大子段和加上 $[mid+1,r]$ 的最大子段和,因为还有跨越两个区间的最大子段和,所以我们要创这两个变量。
那么我们要怎么更新呢?我们看下面的函数:
node operator + (node x,node y){
node ans;
ans.l=x.l;
ans.r=y.r;
ans.val=max(x.rmax+y.lmax,max(x.val,y.val));
ans.lmax=max(x.lmax,x.sum+y.lmax);
ans.rmax=max(y.rmax,y.sum+x.rmax);
ans.sum=x.sum+y.sum;
return ans;
}
void pushUp(ll x){
tr[x]=tr[lc]+tr[rc];
}
这里我直接重定义了加号,免得再写到函数里面。
我们定义 $v_{[l,r]}$ 表示区间 $[l,r]$ 的最大子段和,那么这里区间最大子段和的更新就是 $v_{[l,mid]} + v_{[mid+1,r]}$ 了,这里我还取了个最大值,以防万一,但是也可以不要。
然后来到重点,对于 $v_{[l,mid]}$ 的更新,我们发现存在两种情况:第一种就是最大子段和都在左边区间,也就是 $ans$ 的 $v_{[l,mid]}$;第二种则是左边都有,右边有一些,所以是上面那一坨(不是)。
然后对于 $v_{[mid+1,r]}$ 的更新,和上面都差不多,最后更新一下区间和即可。
#include<bits/stdc++.h>
#define ll long long
#define lc x<<1
#define rc x<<1|1
using namespace std;
const int N=5e5+5;
struct node{
ll l,r,val,lmax,rmax,sum;
}tr[N<<2];
ll n,q,op,x,y;
ll arr[N];
node operator + (node x,node y){
node ans;
ans.l=x.l;
ans.r=y.r;
ans.val=max(x.rmax+y.lmax,max(x.val,y.val));
ans.lmax=max(x.lmax,x.sum+y.lmax);
ans.rmax=max(y.rmax,y.sum+x.rmax);
ans.sum=x.sum+y.sum;
return ans;
}
void pushUp(ll x){
tr[x]=tr[lc]+tr[rc];
}
void build(ll x,ll l,ll r){
tr[x].l=l;
tr[x].r=r;
if(l==r){
tr[x].val=tr[x].sum=arr[l];
tr[x].lmax=tr[x].rmax=arr[l];
return ;
}
ll m=l+r>>1;
build(lc,l,m);
build(rc,m+1,r);
pushUp(x);
return ;
}
void modify(ll x,ll p,ll k){
if(tr[x].l==tr[x].r){
tr[x].lmax=tr[x].rmax=k;
tr[x].sum=tr[x].val=k;
return ;
}
ll m=tr[x].l+tr[x].r>>1;
if(p<=m) modify(lc,p,k);
else modify(rc,p,k);
pushUp(x);
return ;
}
node ask(ll x,ll l,ll r){
if(l<=tr[x].l && tr[x].r<=r) return tr[x];
ll m=tr[x].l+tr[x].r>>1;
if(r<=m) return ask(lc,l,r);
if(l>m) return ask(rc,l,r);
return ask(lc,l,r)+ask(rc,l,r);
}
int main(){
cin>>n>>q;
for(int i=1;i<=n;i++) cin>>arr[i];
build(1,1,n);
while(q--){
cin>>op>>x>>y;
if(op==1){
if(x>y) swap(x,y);
cout<<ask(1,x,y).val<<endl;
}
else modify(1,x,y);
}
return 0;
}
推荐题目
P10463 Interval GCD
P2894 [USACO08FEB] Hotel G
本文来自博客园,作者:一只小何,转载请注明原文链接:https://www.cnblogs.com/Cristuff/p/19128432