[基础例题] 线段树
背景:最近czy刚讲了segmenttree,整理例题。
线段树操作要考虑:
0.记录必要数值
1.支持区间合并
2.标记覆盖
3.标记下放(cpu监控),数值上推(楼房重建)
4.保证log
...........
T1:
n 个数, qqq 次操作
操作0 x y
把 Ax 修改为 y
操作1 l r
询问区间 [l,r] 的最大子段和
即:带修最长子段和。
分析:
显然,一个区间的最大子段和,假如我们像线段树操作一样,把它劈成两半,这个子段和,要么是从ls左端点开始的一段,要么是rs右端点开始一段,要么是ls右端点开始一段和rs左端点开始一段,或者是ls/rs最大子段和。
所以,线段树里记录sum,lmx,rmx,zui 分别是总和,左端点开始最大子段和,右端点开始的最大子段和,区间最大子段和。
难点在于query,其实本质就是根据已有区间的情况,重组一个新的区间情况。
czy们的做法是,传一个结构体,里面记录新拆分的区间的sum,lmx,rmx,zui。比较好的做法,体现了重组新区间的本质。
我的做法是:每次传入c,表示要处理出区间的左端点开始最大值、右端点开始最大值、或者最大子段和。
注意的是,如果要传入查询最大子段和(c=3),左端点开始(c=1),或者右端点开始时,
对于区间分割的情况不同(L<=mid/mid<R),根据当前的c不同,传入的c也是不同的。
否则会出现的错误是:把查询区间拆成两半,结果可能取了两边儿子的各自最大子段和做和,但是可能这两个子区间并不是连续的。
详见代码:
#include<bits/stdc++.h> using namespace std; const int N=50000+10; typedef long long ll; const ll inf=2e18+1; struct node{ ll sum,lmx,rmx; ll zui; }t[4*N]; int n,m; ll a[N]; void pushup(int x) { int ls=x<<1,rs=x<<1|1; t[x].sum=t[ls].sum+t[rs].sum; t[x].lmx=max(t[ls].sum+t[rs].lmx,t[ls].lmx); t[x].rmx=max(t[rs].sum+t[ls].rmx,t[rs].rmx); t[x].zui=max(t[rs].zui,max(t[ls].zui,t[ls].rmx+t[rs].lmx)); } void build(int x,int l,int r) { if(l==r){ t[x].lmx=t[x].rmx=t[x].sum=a[l]; t[x].zui=a[l]; return; } int mid=(l+r)>>1; build(x<<1,l,mid); build(x<<1|1,mid+1,r); pushup(x); } void ch(int x,int l,int r,int to,ll val) { if(l==r){ t[x].sum=t[x].rmx=t[x].lmx=val; t[x].zui=val; return; } int mid=(l+r)>>1; if(to<=mid) ch(x<<1,l,mid,to,val); else ch(x<<1|1,mid+1,r,to,val); pushup(x); } ll query(int x,int l,int r,int L,int R,int c) { if(L<=l&&r<=R){ if(c==3) return t[x].zui; if(c==1) return t[x].lmx; if(c==2) return t[x].rmx; if(c==4) return t[x].sum; } int mid=(l+r)>>1; ll ret=-inf; if(L<=mid&&mid<R){//注意这里的判断 if(c==3){ ret=max(ret,query(x<<1,l,mid,L,R,3)); ret=max(ret,query(x<<1|1,mid+1,r,L,R,3)); ret=max(ret,query(x<<1,l,mid,L,R,2)+query(x<<1|1,mid+1,r,L,R,1)); } else if(c==2){ ret=max(ret,query(x<<1,l,mid,L,R,2)+t[x<<1|1].sum); ret=max(ret,query(x<<1|1,mid+1,r,L,R,2)); } else{ ret=max(ret,t[x<<1].sum+query(x<<1|1,mid+1,r,L,R,1)); ret=max(ret,query(x<<1,l,mid,L,R,1)); } } else if(L<=mid){ret=max(ret,query(x<<1,l,mid,L,R,c));} else if(mid<R){ret=max(ret,query(x<<1|1,mid+1,r,L,R,c));} return ret; } int main() { scanf("%d",&n); for(int i=1;i<=n;i++) scanf("%lld",&a[i]); build(1,1,n); scanf("%d",&m); int op,l,r,go; ll v; for(int i=1;i<=m;i++){ scanf("%d",&op); if(op){ scanf("%d%d",&l,&r); printf("%lld\n",query(1,1,n,l,r,3)); } else{ scanf("%d%lld",&go,&v); ch(1,1,n,go,v); } } return 0; }
T2:
题目大意:
给定长度为 𝑁 的实数序列 𝐴𝑖(1 ≤ 𝑖 ≤ 𝑁), 你需要在数列上进行两类操作:
1. 把 𝑙 ≤ 𝑖 ≤ 𝑟 中的每个 𝐴𝑖 加上实数 𝑣。
2. 求 𝑙 ≤ 𝑖 ≤ 𝑟 中 cos(𝐴𝑖) 的和。
分析:
线段树中,我们必然不会记录每个点的cos值,但是每个数加上v,又不能像加法一样直接加上去,但是我们肯定还要迅速找到区间的cos和,所以怎么办?
我们一定会记录一个cos总和
考虑两角和余弦公式:cos(a+b)=cosa*cosb-sina*sinb
所以,每个数加上一个b,新产生的cos总和=旧cossum*cosb+旧sinsum*sinb
再利用正弦公式,同样可以更新sin总和
所以我们一个区间里要记录cos和,sin和。
值得记住的是:
可以直接使用cos、sin以及反三角acos,但是,cos(a),其中会把a看做弧度制,所以如果输入的是角度,*PI/180就好。
PI=acos(-1.0)
eps要注意~!!!!否则可能-0.000
#include<cmath> #include<iostream> #include<cstdlib> #define PI acos(-1.0) using namespace std; int n; double s,c; double eps=0.0000001; int main() { scanf("%d",&n); s=sin((PI*n)/180.0); c=cos((PI*n)/180.0); printf("sin:%.3lf cos:%.3lf",s+eps,c+eps); return 0; }
不过这题就是弧度制,所以代码:
#include<bits/stdc++.h> using namespace std; const int N=200000+10; struct node{ double ssum,csum; double laz; }t[4*N]; int n,m; int T; double a[N]; void pushup(int x) { t[x].ssum=t[x<<1].ssum+t[x<<1|1].ssum; t[x].csum=t[x<<1].csum+t[x<<1|1].csum; } void build(int x,int l,int r) { if(l==r){ t[x].ssum=sin(a[l]); t[x].csum=cos(a[l]); t[x].laz=0.00000; return; } t[x].laz=t[x].ssum=t[x].csum=0.00000; int mid=(l+r)>>1; build(x<<1,l,mid); build(x<<1|1,mid+1,r); pushup(x); } void pushdown(int x) { double si=sin(t[x].laz); double co=cos(t[x].laz); for(int i=0;i<=1;i++) { int son=x<<1|i; t[son].laz+=t[x].laz; double ssin=t[son].ssum*co+t[son].csum*si; double scos=t[son].csum*co-t[son].ssum*si; t[son].ssum=ssin; t[son].csum=scos; } t[x].laz=0.00000; } void add(int x,int l,int r,int L,int R,double c) { if(L<=l&&r<=R) { double si=sin(c); double co=cos(c); double ssin=t[x].ssum*co+t[x].csum*si; double scos=t[x].csum*co-t[x].ssum*si; t[x].ssum=ssin; t[x].csum=scos; t[x].laz+=c; return; } pushdown(x); int mid=(l+r)>>1; if(L<=mid) add(x<<1,l,mid,L,R,c); if(mid<R) add(x<<1|1,mid+1,r,L,R,c); pushup(x); } double query(int x,int l,int r,int L,int R) { if(L<=l&&r<=R){ return t[x].csum; } pushdown(x); int mid=(l+r)>>1; double ret=0.0000; if(L<=mid) ret+=query(x<<1,l,mid,L,R); if(mid<R) ret+=query(x<<1|1,mid+1,r,L,R); return ret; } int main() { scanf("%d",&T); for(int o=1;o<=T;o++) { printf("Case #%d:\n",o); scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) scanf("%lf",&a[i]); build(1,1,n); int op,l,r; double v; while(m--) { scanf("%d",&op); if(op==1){ scanf("%d%d%lf",&l,&r,&v); add(1,1,n,l,r,v); } else{ scanf("%d%d",&l,&r); printf("%.3lf\n",query(1,1,n,l,r)); } } } return 0; }
T3:
动态gcd:
给你一棵 n 个结点的树,每一个结点上有一个正整数权值,其中第 i 个结点上的权值是 v[i] 。
你的程序必须维护 2 种操作:
一,表示为 F u v 的查找操作:找出从 u 到 v 的唯一路径上所有点权值的最大公约数。(包括 u 和 v)
二,表示为 C u v d 的修改操作:将从 u 到 v 的唯一路径上所有点权值加上 d。(包括 u 和 v)
分析:
这题还是比较综合的。
区间加v,gcd怎么维护???
一脸懵。。。
gcd(a,b)=gcd(a,a-b)
扩展:gcd(a,b,c,d,...n) = gcd( a, b-a ,c-b,...n-m)
这样,我们巧妙地利用性质,将区间加,变成了两个单点加,l+=v,r+1-=v;!!!!
这样,每次我们线段树到了叶子节点,才跟新值,pushdown拜拜,pushup直接gcd就好了。
这种处理值得注意,相似的,算法例如树状数组,例题例如借教室
1.树剖dfn化为dfn区间。
2.线段树维护区间最大公约数(每个数都是一个差分后的数)
3.根据输入,类似lca过程,边向lca靠拢,边更新答案。每次更新(dfn[top[x],dfn[x])这一段区间
4.不能直接找[l,r]区间gcd,因为这维护的是差分数组,最左边的点还涉及到l-1的值。并不对。应该是gcd(al,gcd[l+1,r])合适。我们还要知道al,
所以,差分树状数组,随时更新,以便logn查找必要的值。
注意事项:
1.差分可能出现负数,但是负数可以直接当做正数处理。
2.更新线段树l,r+1时,注意,新的差值并不能直接在叶子节点上的值加上v,因为,这可能本来是个负值,但是变成了正值,加上一个较小正数,实际上的值应该变小(绝对值变小)
但是直接加会变大,就WA了。所以,每要更新一个l,先把l-1值取出来,把l值取出来,l值+=v,与l-1值差分,取绝对值,然后放进去。
3.注意,dfn序列,编号对应的并不是原来的1~n,想清楚自己要处理的是什么。有了树剖,开始的树已经变成了dfn链,我们的记录,不论线段树,还是树状数组,都要维护dfn序列。
代码:194行
#include<bits/stdc++.h> using namespace std; const int N=50000+10; int n,m; struct node{ int g; }t[4*N]; struct bb{ int nxt,to; }bian[2*N]; int hd[N],cnt; int tot; int f[N],b[N]; int p[N]; void add(int x,int y){bian[++cnt].nxt=hd[x];bian[cnt].to=y;hd[x]=cnt;} int gcd(int x,int y){ if(x<y) swap(x,y); while(y){int d=x%y;x=y,y=d;} return x; } //------------------------------杂务------------------------------- void jia(int x,int val){for(;x<=n;x+=x&(-x)){f[x]+=val;}} int ask(int x) {int ret=0;for(;x>=1;x-=x&(-x)) ret+=f[x];return ret;} //------------------------------压行树状数组------------------------- int dfn[N],top[N],son[N],dep[N],fa[N],siz[N],fdfn[N]; int dfn2[N]; void dfs1(int x,int f,int d) { dep[x]=d; siz[x]=1; for(int i=hd[x];i;i=bian[i].nxt) { int y=bian[i].to; if(y==f) continue; fa[y]=x; dfs1(y,x,d+1); siz[x]+=siz[y]; if(siz[y]>siz[son[x]]){ son[x]=y; } } } void dfs2(int x) { dfn[x]=++tot; fdfn[tot]=x; if(!top[x]) top[x]=x; if(son[x]) {top[son[x]]=top[x],dfs2(son[x]);} for(int i=hd[x];i;i=bian[i].nxt) { int y=bian[i].to; if(y==fa[x]||y==son[x]) continue; dfs2(y); } dfn2[x]=tot; } //----------------------------------------以上树剖-------------------------------- void pushup(int x) { t[x].g=gcd(t[x<<1].g,t[x<<1|1].g); } void build(int x,int l,int r) { if(l==r){ t[x].g=abs(b[l]); return; } int mid=(l+r)>>1; build(x<<1,l,mid); build(x<<1|1,mid+1,r); pushup(x); } void update(int x,int l,int r,int to,int c) { if(l==r){ t[x].g=c; return; } int mid=(l+r)>>1; if(to<=mid) update(x<<1,l,mid,to,c); else if(to>mid) update(x<<1|1,mid+1,r,to,c); pushup(x); } int query(int x,int l,int r,int L,int R) { if(L<=l&&r<=R){ return t[x].g; } int mid=(l+r)>>1; int ret; if(L<=mid&&mid<R){ ret=gcd(query(x<<1,l,mid,L,R),query(x<<1|1,mid+1,r,L,R)); } else if(L<=mid){ ret=query(x<<1,l,mid,L,R); } else if(mid<R) ret=query(x<<1|1,mid+1,r,L,R); return ret; } //----------------------------------------以上线段树------------------------------ int work1(int x,int y) { int ans=-1; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); if(dfn[top[x]]==dfn[x]){ if(ans==-1) ans=ask(dfn[x]); else ans=gcd(ans,ask(dfn[x])); } else{ int now=gcd(ask(dfn[top[x]]),query(1,1,n,dfn[top[x]]+1,dfn[x])); if(ans==-1) ans=now; else ans=gcd(ans,now); } x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); int hew; if(x!=y) hew=gcd(ask(dfn[x]),query(1,1,n,dfn[x]+1,dfn[y])); else hew=ask(dfn[x]); if(ans==-1) ans=hew; else ans=gcd(ans,hew); return ans; } void work2(int x,int y,int v) { while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y);//dfn[x]~dfn[top[x] int las=ask(dfn[top[x]]-1),now=ask(dfn[top[x]]); int chan=abs(now+v-las); update(1,1,n,dfn[top[x]],chan); las=ask(dfn[x]),now=ask(dfn[x]+1); chan=abs(now-v-las); update(1,1,n,dfn[x]+1,chan); jia(dfn[top[x]],v);jia(dfn[x]+1,-v); x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); int las=ask(dfn[x]-1),now=ask(dfn[x]); int chan=abs(now+v-las); update(1,1,n,dfn[x],chan); las=ask(dfn[y]),now=ask(dfn[y]+1); chan=abs(now-v-las); update(1,1,n,dfn[y]+1,chan); jia(dfn[x],v),jia(dfn[y]+1,-v); } //-------------------------------------------操作处理--------------------------- int main() { scanf("%d",&n); int x,y; for(int i=1;i<=n-1;i++) { scanf("%d%d",&x,&y); x++,y++; add(x,y);add(y,x); } for(int i=1;i<=n;i++) scanf("%d",&p[i]); dfs1(1,0,1); dfs2(1); for(int i=1;i<=n;i++){ b[i]=p[fdfn[i]]-p[fdfn[i-1]]; jia(i,b[i]); } build(1,1,n); scanf("%d",&m); char q; int val; while(m--) { scanf("%c",&q);//读换行符 scanf("%c",&q); if(q=='F') { scanf("%d%d",&x,&y); x++,y++; printf("%d\n",work1(x,y)); } else{ scanf("%d%d%d",&x,&y,&val); x++,y++; work2(x,y,val); } } return 0; }
T4:
等差数列:
你需要维护一个数列,每次操作如下
1 l r item diff
第l个元素加上item,第l+1个元素加上item+diff,第l+2个元素加上item+diff+diff,诸如此类一直到r
2 l r
询问[l,r]的和
分析:
维护是很好想到的。
我是区间维护sum,标记laz,公差标记d,没了。
坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑坑的是:
还是当区间分割向下找区间的时候,右子区间首项其实已经变了。
必须从原始区间变过去,a0->a0+d*(mid-L+1) 对于L<=mid&&mid<=R,并且,为了保证下次的L是没有算过的,L->mid+1和区间左端点一致。
WA了无数次。对拍才找到。。。。
#include<bits/stdc++.h> using namespace std; const int N=1e6+20; typedef long long ll; int cnt; int n,m; int a[N]; ll sum[N]; struct node{ ll sum,d; ll laz; }t[4*N]; void pushup(int x){ t[x].sum=t[x<<1].sum+t[x<<1|1].sum; } void pushdown(int x,int l,int r) { int ls=x<<1,rs=x<<1|1; int mid=l+r>>1; t[ls].sum+=t[x].laz*(mid-l+1)+(mid-l+1)*(mid-l)*t[x].d/2; t[rs].sum+=(t[x].laz+(mid-l+1)*t[x].d)*(r-mid)+(r-mid)*(r-mid-1)*t[x].d/2; t[ls].d+=t[x].d;t[rs].d+=t[x].d; t[ls].laz+=t[x].laz; t[rs].laz+=t[x].laz+(mid-l+1)*t[x].d; t[x].d=0;t[x].laz=0; } void add(int x,int l,int r,int L,int R,ll a0,ll d){ int len=r-l+1; if(L<=l&&r<=R){ t[x].sum+=len*a0+len*(len-1)*d/2; t[x].d+=d; t[x].laz+=a0; return; } pushdown(x,l,r); int mid=(l+r)>>1; if(L<=mid&&mid<R){ add(x<<1,l,mid,L,mid,a0,d); add(x<<1|1,mid+1,r,mid+1,R,a0+(mid+1-L)*d,d);//一定注意。一般是不用L->mid+1的,但是这次必须用。这个L是相对的。自己画图体会。 } else if(L<=mid) add(x<<1,l,mid,L,R,a0,d); else if(mid<R) add(x<<1|1,mid+1,r,L,R,a0,d); pushup(x); } ll query(int x,int l,int r,int L,int R){ if(L<=l&&r<=R){ return t[x].sum; } pushdown(x,l,r); int mid=l+r>>1; ll ret=0; if(L<=mid) ret+=query(x<<1,l,mid,L,R); if(mid<R) ret+=query(x<<1|1,mid+1,r,L,R); return ret; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) scanf("%d",&a[i]),sum[i]=sum[i-1]+a[i]; int op,l,r; ll x,dif; for(int i=1;i<=m;i++){ cnt++; scanf("%d",&op); if(op==1){ scanf("%d%d%lld%lld",&l,&r,&x,&dif); add(1,1,n,l,r,x,dif); } else{ scanf("%d%d",&l,&r); printf("%lld\n",sum[r]-sum[l-1]+query(1,1,n,l,r)); } } return 0; }