线段树学习笔记
线段树是一种用来维护区间的数据结构,如果所维护的条件满足加法性质那么就可以使用 STG(线段树)进行维护。
例如区间求和、区间最值、区间 \(\gcd\) 等。当然,也可以维护线段一类,比如扫描线。
线段树是一种完全二叉树,对于一个节点 \(u\),如果它有,那么有 \(ls = u * 2\),\(rs = u * 2 + 1\)。其右儿子维护的区间分别是 \([l,mid]\) 与 \([mid+1,r]\)。
他的操作格式如下:
pushup()将左右子节点的区间合并为一个区间,将其赋给当前节点,并同时维护所有需的信息,格式如下:
代码
void pushup(int u){
t[u]=sum(t[ls],t[rs]);
}
maketag()将区间按一定方式进行修改并为区间打上标记,标记是为了更好、更快的进行修改,格式如下:
代码
void maketag(int u,int l,int r,int k){
t[u]=do_sth(k);
tag[u]=do_sth(k);
}
pushdown将该节点的标记下传给子节点,并取消标记,格式如下:
代码
void pushdown(int u,int l,int r){
if(tag[u]!=null){
maketag(ls,l,mid,tag[u]);
maketag(rs,mid+1,r,tag[u]);
tag[u]=null;
}
}
change()区间修改,并维护需要的值,同时下传所经结点的标记,格式如下:
代码
void change(int u,int l,int r,int L,int R,int k){
if(L<=l&&r<=R){
maketag(u,l,r,k);
return;
}
pushdown(u,l,r);
if(L<=mid) change(ls,l,mid,L,R,k);
if(R>mid) change(rs,mid+1,r,L,R,k)l;
pushup(u);
}
query()区间查询,查询一个区间所维护的值,同时下传所经结点的标记,格式如下:
代码
int query(int u,int l,int r,int L,int R){
if(L<=l&&r<=R) return t[u];
pushdown(u,l,r);
int s=null;
if(L<=mid) s+=query(ls,l,mid,L,R,k);
if(R>mid) s+=query(rs,mid+1,r,L,R,k)l;
return s;
}
于是我们就可以打掉板子题了:
\(P3373\) 线段树2 :
AC代码
#include<iostream>
#define ll long long
#define maxn 100001
using namespace std;
ll tag[maxn<<2],mul[maxn<<2];
ll w[maxn<<2];
ll a[maxn];
ll p;
void pu(int u){
w[u]=(w[u<<1]+w[(u<<1)+1])%p;
}
void maketag(int u,int len,ll x,ll y){
w[u]=(w[u]*y+x*len)%p;
mul[u]=(mul[u]*y)%p;
tag[u]=(tag[u]*y+x)%p;
}
void pd(int u,int l,int r){
int m=(l+r)>>1;
maketag(u<<1,m-l+1,tag[u],mul[u]);
maketag((u<<1)+1,r-m,tag[u],mul[u]);
tag[u]=0;
mul[u]=1;
}
ll query(int u,int L,int R,int l,int r){
if(l<=L&&R<=r)
return w[u];
else if(!(L>r||R<l)){
int m=(L+R)>>1;
pd(u,L,R);
return (query(u<<1,L,m,l,r)+query((u<<1)+1,m+1,R,l,r))%p;
}
else return 0;
}
void update(int u,int L,int R,int l,int r,ll x,ll y){
if(l<=L&&R<=r)
maketag(u,R-L+1,x,y);
else if(!(L>r||R<l)){
int m=(L+R)>>1;
pd(u,L,R);
update(u<<1,L,m,l,r,x,y);
update((u<<1)+1,m+1,R,l,r,x,y);
pu(u);
}
}
void build(int u,int l,int r){
mul[u]=1;
if(l==r){w[u]=a[l];return;}
int m=(l+r)>>1;
build((u<<1),l,m);
build((u<<1)+1,m+1,r);
pu(u);
}
int main(){
int n,m;
scanf("%d%d%lld",&n,&m,&p);
for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
build(1,1,n);
for(int i=1;i<=m;i++){
int t,x,y;
ll k;
scanf("%d",&t);
if(t==1){
scanf("%d%d%lld",&x,&y,&k);
update(1,1,n,x,y,0,k);
}
else if(t==2){
scanf("%d%d%lld",&x,&y,&k);
update(1,1,n,x,y,k,1);
}
else if(t==3){
scanf("%d%d",&x,&y);
printf("%lld\n",query(1,1,n,x,y));
}
}
return 0;
}
接下来看几道例题:
\(P4145\) 上帝造题的七分钟 2 / 花神游历各国
这是一道很明显的线段树题,题目概述如下:
需要编写一个数据结构并支持以下操作:
- 对 \([l,r]\) 区间进行开方运算并向下取整。
- 求 \(\sum\limits_{i=l}\limits^{r}a[i]\) 并输出。
我们可以用线段树来维护区间求和,然后暴力单点开方。但这样会T。
区间求和已经没法再优化了(至少不在我的能力范围之内),考虑优化开方操作。
由于区间开方无法通过打标记的方式优化,所以我们只能从开方的性质入手。
我们已知 \(\sqrt{1}=1\),那么我们可以利用这个性质来优化线段树。
当一个区间的最大值都小于等于 \(1\) 时,这个区间内就只有 \(1\)(开方不可能开出负数),这是我们就可以直接放弃对这一个区间的修改。
由于一个小于等于 \({10}^{12}\) 的数最多被开方 \(6\) 次后向下取整就变成了 \(1\),而每一次查询的复杂度都为 \(O(n\log_2{n})\)。
至于修改操作,在最坏情况下复杂度为 \(O(n\log_2{n})\),总复杂度为 \(O(6n\log_2{n})\)。
那么这种做法的总复杂度即为 \(O(m\log_2{n}+6n\log_2{n})\)。对于本题是能够通过的。
AC代码
#include<iostream>
#include<cmath>
#define ll long long
using namespace std;
const int N(1e5+3);
int n,m;
struct node{
ll mx,sum;
} t[N<<2];
#define ls (u<<1)
#define rs (u<<1|1)
#define mid ((l+r)>>1)
void pushup(int u){
t[u].sum=t[ls].sum+t[rs].sum;
t[u].mx=max(t[ls].mx,t[rs].mx);
}
void build(int u,int l,int r){
if(l==r){
scanf("%lld",&t[u].mx);t[u].sum=t[u].mx;
return ;
}
build(ls,l,mid);build(rs,mid+1,r);
pushup(u);
}
void change(int u,int l,int r,int L,int R){
if(l==r){
t[u].mx=sqrt(t[u].mx);t[u].sum=sqrt(t[u].sum);
return ;
}
if(L<=mid&&t[ls].mx>1) change(ls,l,mid,L,R);
if(R>mid&&t[rs].mx>1) change(rs,mid+1,r,L,R);
pushup(u);
}
ll qry(int u,int l,int r,int L,int R){
if(L<=l&&r<=R) return t[u].sum;
ll s=0;
if(L<=mid) s+=qry(ls,l,mid,L,R);
if(R>mid) s+=qry(rs,mid+1,r,L,R);
return s;
}
int main(){
#ifdef ytxy
freopen("in.txt","r",stdin);
#endif
scanf("%d",&n);
build(1,1,n);
scanf("%d",&m);
while(m--){
int k,l,r;
scanf("%d%d%d",&k,&l,&r);
if(l>r) swap(l,r);
if(k==0){
change(1,1,n,l,r);
}
else{
printf("%lld\n",qry(1,1,n,l,r));
}
}
}
\(P2572\) [SCOI2010] 序列操作:
题面很清楚,不多赘述。
很有挑战性的一道题,看了小粉兔的题解后恍然大悟,原来线段树还能这么写……
这一题就难在 maketag() 与 pushup() 不过我们可以用另一种方式完成 pushup() 的工作。
我们可以考虑 \(t[u]=merge(t[ls],t[rs])\),即合并 \(u\) 的儿子区间赋给 \(u\) 节点(\(t[u]\) 表示当前节点结构体)。
而 merge() 的代码实现并不难:
inline node merge(node x,node y){
return node(//构造函数
x.w+y.w,x.b+y.b,
(x.b?x.lw:x.w+y.lw),(x.w?x.lb:x.b+y.lb),
(y.b?y.rw:y.w+x.rw),(y.w?y.rb:y.b+x.rb),
max(max(x.mw,y.mw),x.rw+y.lw),
max(max(x.mb,y.mb),x.rb+y.lb)
);
}
区间赋值并不难,区间 \(xor\) 才是最难的。我们可以考虑同时维护一个区间内的 \(0\) 的总数、最长连续、前缀长度、后缀长度,对于区间内的 \(1\) 也维护同样的值。这样一来在进行区间 \(xor\) 时就只需要对维护的 \(0\),\(1\) 对应的值进行交换即可。
那么我们可以开两个数组来保存赋值标记与异或标记。打标记就变得很简单了:
inline void mt(int u,int ty,int len){
if(ty==0) t1[u]=0,t2[u]=0,
t[u]=node(0,len,0,len,0,len,0,len);
else if(ty==1) t1[u]=1,t2[u]=0,
t[u]=node(len,0,len,0,len,0,len,0);
else if(ty==2) t2[u]^=1,
swap(t[u].w,t[u].b),swap(t[u].lw,t[u].lb),
swap(t[u].rw,t[u].rb),swap(t[u].mw,t[u].mb);
}
在查询时我们可以将 \([l,r]\) 区间用先前的 merge() 合并出来,这样就可以将 \(3,4\) 操作合在一起,在输出时进行判断。
总的复杂的即为 \(O(m\log{n})\)。
AC代码
#include<iostream>
using namespace std;
const int N(1e5+5);
int n,m;
struct node{
int w,b,lw,lb,rw,rb,mw,mb;
node(int w=0,int b=0,int lw=0,int lb=0,//构造函数方便合并区间
int rw=0,int rb=0,int mw=0,int mb=0):
w(w),b(b),lw(lw),lb(lb),
rw(rw),rb(rb),mw(mw),mb(mb){}
} t[N<<2];
int t1[N<<2],t2[N<<2];
inline node merge(node x,node y){//合并区间
return node(
x.w+y.w,x.b+y.b,
(x.b?x.lw:x.w+y.lw),(x.w?x.lb:x.b+y.lb),//考虑整个子区间都是1或0的情况
(y.b?y.rw:y.w+x.rw),(y.w?y.rb:y.b+x.rb),
max(max(x.mw,y.mw),x.rw+y.lw),//考虑左、右两节点与横跨两个子区间的
max(max(x.mb,y.mb),x.rb+y.lb)//同上
);
}
inline void mt(int u,int ty,int len){
if(ty==0) t1[u]=0,t2[u]=0,
t[u]=node(0,len,0,len,0,len,0,len);
else if(ty==1) t1[u]=1,t2[u]=0,
t[u]=node(len,0,len,0,len,0,len,0);
else if(ty==2) t2[u]^=1,
swap(t[u].w,t[u].b),swap(t[u].lw,t[u].lb),//交换0与1的对应值
swap(t[u].rw,t[u].rb),swap(t[u].mw,t[u].mb);
}
#define ls (u<<1)
#define rs (u<<1|1)
#define mid (l+r>>1)
void pushdown(int u,int l,int r){
if(t1[u]!=-1)mt(ls,t1[u],mid-l+1),mt(rs,t1[u],r-mid);
if(t2[u]) mt(ls,2,mid-l+1),mt(rs,2,r-mid);
t1[u]=-1,t2[u]=0;
}
void build(int u,int l,int r){
t1[u]=-1;
if(l==r){
int a;scanf("%d",&a);
t[u]=node(a,a^1,a,a^1,a,a^1,a,a^1);return ;
}
build(ls,l,mid);build(rs,mid+1,r);
t[u]=merge(t[ls],t[rs]);
}
void change(int u,int l,int r,int L,int R,int ty){
if(r<L||R<l) return ;
if(L<=l&&r<=R){mt(u,ty,r-l+1);return ;}
pushdown(u,l,r);
change(ls,l,mid,L,R,ty);change(rs,mid+1,r,L,R,ty);
t[u]=merge(t[ls],t[rs]);
}
node qry(int u,int l,int r,int L,int R){
if(r<L||R<l) return node();
if(L<=l&&r<=R) return t[u];
pushdown(u,l,r);
return merge(qry(ls,l,mid,L,R),qry(rs,mid+1,r,L,R));
}
int main(){
#ifdef ytxy
freopen("in.txt","r",stdin);
#endif
scanf("%d%d",&n,&m);
build(1,1,n);
while(m--){
int opt,l,r;
scanf("%d%d%d",&opt,&l,&r);
//题目有点坑,测试点的区间是[0,n-1],所以需要加1来方便线段树维护
if(opt<3) change(1,1,n,l+1,r+1,opt);
else{
node x=qry(1,1,n,l+1,r+1);
printf("%d\n",(opt==3?x.w:x.mw));
}
}
}

浙公网安备 33010602011771号