树链剖分
写在前面
某位菜鸡花了半天的时间打完了树剖,又用了半天的时间di了无数个bug后,终于获得了MLE,RE并存的喜人成绩,最终在 \(ljc\) 大佬的指点下才发现原来是函数void写成int了并且没写返回值 T_T
正题
懒得自己写一篇博客了,就把我当时学习树剖时的一篇写的非常好的博客拿出来欣赏吧
博客原文
几条重要的规则
- 一个字树内的dfs序连续,线段树维护3,4操作
- 节点数多的叫做重儿子,重儿子到父亲节点的边叫做重边
- 每一条链dfs序连续(先遍历重儿子)
- 跳到同一条重链上深度小的即为最近公共祖先
dfs1():
统计每个节点的深度 \(deep[ ]\) ,每个节点父亲 \(fa[ ]\) ,每个节点子树大小 \(size[ ]\) ,重儿子编号 \(mson[]\)
dfs2():
每个旧节点新编号 \(id[]\) ,每个新节点的旧编号 \(di[]\) ,新编号的值 \(newa[]\) ,每个节点所在重链顶端 \(top[]\)
警示后人
一定要处理好新编号和旧编号的关系,最好采用( \(id\) 和 \(di\) )双重索引,注意你的变量是哪一次dfs()维护的,注意新旧节点的转换
代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5;
int n,m,r,p,cnt,ans;
int dfn[N],fa[N],siz[N],dep[N],son[N],top[N],to[N],a[N];
vector<int>b[N];
void dfs1(int u,int f){
fa[u]=f;
siz[u]=1;
dep[u]=dep[f]+1;
int mx=0;
for(int v:b[u]){
if(v==f) continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>mx) mx=siz[v],son[u]=v;
}
}
void dfs2(int u,int tp){
top[u]=tp;
dfn[u]=++cnt;
to[cnt]=u;
if(son[u]) dfs2(son[u],tp);
for(int v:b[u]){
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
struct dot{
int x,add;
}tr[N*4];
struct Tree{
void build(int k,int l,int r){
if(l==r){
tr[k]={a[to[l]],0};
return;
}
int mid=(l+r)>>1;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
tr[k].x=(tr[k*2].x+tr[k*2+1].x)%p;
}
void Add(int k,int l,int r,int z){
tr[k].add+=z;
tr[k].add%=p;
tr[k].x+=z*(r-l+1);
tr[k].x%=p;
}
void pushdown(int k,int l,int r,int mid){
if(!tr[k].add) return;
Add(k*2,l,mid,tr[k].add);
Add(k*2+1,mid+1,r,tr[k].add);
tr[k].add=0;
}
void longchange(int k,int l,int r,int x,int y,int z){
if(x<=l&&r<=y){
Add(k,l,r,z);
return;
}
int mid=(l+r)>>1;
pushdown(k,l,r,mid);
if(x<=mid) longchange(k*2,l,mid,x,y,z);
if(y>mid) longchange(k*2+1,mid+1,r,x,y,z);
tr[k].x=(tr[k*2].x+tr[k*2+1].x)%p;
}
int longquery(int k,int l,int r,int x,int y){
if(x<=l&&r<=y){
return tr[k].x;
}
int mid=(l+r)>>1,res=0;
pushdown(k,l,r,mid);
if(x<=mid) res+=longquery(k*2,l,mid,x,y);
if(y>mid) res+=longquery(k*2+1,mid+1,r,x,y);
return res%p;
}
}tree;
void opert(int l,int r,int op,int z){
if(op==0) ans+=tree.longquery(1,1,n,l,r),ans%=p;
else tree.longchange(1,1,n,l,r,z);
}
void lca(int x,int y,int op,int z){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
opert(dfn[top[x]],dfn[x],op,z);
x=fa[top[x]];
}
if(dep[x]<dep[y]) swap(x,y);
opert(dfn[y],dfn[x],op,z);
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n>>m>>r>>p;
for(int i=1;i<=n;i++){
cin>>a[i];
}
for(int i=1;i<n;i++){
int u,v;
cin>>u>>v;
b[u].push_back(v);
b[v].push_back(u);
}
dfs1(r,0);
dfs2(r,r);
tree.build(1,1,n);
for(int i=1;i<=m;i++){
int op,x,y,z;
ans=0;
cin>>op>>x;
if(op==1){
cin>>y>>z;
lca(x,y,1,z);
}
else if(op==2){
cin>>y;
lca(x,y,0,0);
cout<<ans%p<<'\n';
}
else if(op==3){
cin>>z;
tree.longchange(1,1,n,dfn[x],dfn[x]+siz[x]-1,z);
}
else{
ans=tree.longquery(1,1,n,dfn[x],dfn[x]+siz[x]-1);
cout<<ans%p<<'\n';
}
}
}
金牌导航题目
T1:
作为停了3,4个月信竞回来做的第一道题,一道板子打了3个小时,还让deepseek帮了忙,非常不牛
但是代码相较以前做出了一些优化:
点击查看代码
#include<bits/stdc++.h>
#define dep(x) dot[x].dep
#define top(x) dot[x].top
#define fa(x) dot[x].fa
#define dfn(x) dot[x].dfn
#define mx(x) tr[x].mx
#define sum(x) tr[x].sum
#define val(x) tr[x].val
using namespace std;
const int N=1e5+5;
struct Node{
int dfn,dep,top,son,w,fa,siz;
}dot[N];
struct tree{
int val,sum=0,mx=-1e5;
}tr[N*4];
vector<int>b[N];
int cnt,ans,n,q;
int id[N];
void dfs1(int x,int f){
dot[x].dep=dot[f].dep+1;
dot[x].fa=f;
int maxn=0,maxd=0;
for(int v:b[x]){
if(v==f) continue;
dfs1(v,x);
if(dot[v].siz>maxn){
maxn=dot[v].siz;
maxd=v;
}
dot[x].siz+=dot[v].siz;
}
dot[x].siz++;
dot[x].son=maxd;
return;
}
void dfs2(int x,int topy){
dot[x].dfn=++cnt;//一定要按照树链剖分的顺序排dfn,所以要在dfs2中
id[cnt]=x;
dot[x].top=topy;
if(!dot[x].son) return;
dfs2(dot[x].son,topy);
for(int v:b[x]){
if(v==dot[x].fa||v==dot[x].son) continue;
dfs2(v,v);
}
return;
}
void pushup(int k,int l,int r){
sum(k)=sum(k*2)+sum(k*2+1);
mx(k)=max(mx(k*2),mx(k*2+1));
}
void build(int k,int l,int r){
if(l==r){
int w=dot[id[l]].w;
tr[k].mx=w;
tr[k].sum=w;
tr[k].val=w;
return;
}
int mid=(l+r)>>1;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
pushup(k,l,r);
}
int longquerymx(int k,int l,int r,int x,int y){
int mx=-1e5;
if(x<=l&&r<=y){
mx=max(mx,mx(k));
return mx;
}
int mid=(l+r)>>1;
if(x<=mid) mx=max(mx,longquerymx(k*2,l,mid,x,y));
if(y>mid) mx=max(mx,longquerymx(k*2+1,mid+1,r,x,y));
return mx;
}
int longquerysum(int k,int l,int r,int x,int y){
int sum=0;
if(x<=l&&r<=y){
return sum(k);
}
int mid=(l+r)>>1;
if(x<=mid) sum+=longquerysum(k*2,l,mid,x,y);
if(y>mid) sum+=longquerysum(k*2+1,mid+1,r,x,y);
return sum;
}
void dotchange(int k,int l,int r,int x,int val){
if(l==r){
tr[k].mx=val;
tr[k].sum=val;
tr[k].val=val;
return;
}
int mid=(l+r)>>1;
if(x<=mid) dotchange(k*2,l,mid,x,val);
else dotchange(k*2+1,mid+1,r,x,val);
pushup(k,l,r);
return;
}
void operat(int x,int y,int op){
if(op) ans=max(ans,longquerymx(1,1,n,x,y));
else ans+=longquerysum(1,1,n,x,y);
}
void jump(int x,int y,int op){
while(top(x)!=top(y)){
if(dep(top(x))<dep(top(y))) swap(x,y);
operat(dfn(top(x)),dfn(x),op);
x=fa(top(x));
}
if(dep(x)<dep(y)) swap(x,y);
operat(dfn(y),dfn(x),op);
return;
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
b[u].push_back(v);
b[v].push_back(u);
}
for(int i=1;i<=n;i++){
scanf("%d",&dot[i].w);
}
dfs1(1,0);
dfs2(1,1);
build(1,1,n);
scanf("%d",&q);
for(int i=1;i<=q;i++){
char c[10];
int u,t;
scanf("%s",c);
scanf("%d%d",&u,&t);
if(c[1]=='H') dotchange(1,1,n,dot[u].dfn,t);//以旧编号索引新编号
else if(c[1]=='M'){
ans=-1e5;//题目有负值
jump(u,t,1);
printf("%d\n",ans);
}
else{
ans=0;
jump(u,t,0);
printf("%d\n",ans);
}
}
}
T2:
线段树上维护颜色段数量总会吧
注意树刨跳重链时,从左往右跳它的序列是反着的,需要先反转再合并
注意细节处理,还有一定要注意在主函数里写初始化代码,已经连续两次犯这个问题导致RE了
点击查看代码
#include<bits/stdc++.h>
#define dfn(x) dot[x].dfn
#define siz(x) dot[x].siz
#define top(x) dot[x].top
#define fa(x) dot[x].fa
#define son(x) dot[x].son
#define dep(x) dot[x].dep
#define lc(x) tr[x].lc
#define rc(x) tr[x].rc
#define num(x) tr[x].num
using namespace std;
const int N=1e5+5;
struct Node{
int dfn,siz,c,top,fa,son,dep;
}dot[N];
struct tree{
int lc,rc,num;
void clear(){
lc=rc=num=0;
}
}tr[N*4];
int cnt,col,n,m;
int id[N],tag[4*N];
vector<int>b[N];
tree ans1,ans2,ans;
void dfs1(int x,int f){
fa(x)=f;
dep(x)=dep(f)+1;
int mson=0,nson=0;
for(int v:b[x]){
if(v==f) continue;
dfs1(v,x);
if(siz(v)>nson){
nson=siz(v);
mson=v;
}
siz(x)+=siz(v);
}
son(x)=mson;
siz(x)++;
return;
}
void dfs2(int x,int topy){
top(x)=topy;
dfn(x)=++cnt;
id[cnt]=x;
if(!son(x)) return;
dfs2(son(x),topy);
for(int v:b[x]){
if(v==fa(x)||v==son(x)) continue;
dfs2(v,v);
}
return;
}
tree flip(tree x){
return (tree){x.rc,x.lc,x.num};
}
tree pushup(tree x,tree y){
if(!x.num) return y;
else if(!y.num) return x;
return (tree){x.lc,y.rc,x.num+y.num-(x.rc==y.lc)};
}
void build(int k,int l,int r){
if(l==r){
int c=dot[id[l]].c;
lc(k)=rc(k)=c;
num(k)=1;
return;
}
int mid=(l+r)>>1;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
tr[k]=pushup(tr[k*2],tr[k*2+1]);
}
void change(int k,int c){
lc(k)=rc(k)=c;
num(k)=1;
tag[k]=c;
}
void pushdown(int k){
if(!tag[k]) return;
change(k*2,tag[k]);
change(k*2+1,tag[k]);
tag[k]=0;
}
void longchange(int k,int l,int r,int x,int y,int c){
if(x<=l&&r<=y){
change(k,c);
return;
}
pushdown(k);
int mid=(l+r)>>1;
if(x<=mid) longchange(k*2,l,mid,x,y,c);
if(y>mid) longchange(k*2+1,mid+1,r,x,y,c);
tr[k]=pushup(tr[k*2],tr[k*2+1]);
}
tree longquery(int k,int l,int r,int x,int y){
if(x<=l&&r<=y){
return tr[k];
}
pushdown(k);
int mid=(l+r)>>1;
tree cnt;
cnt.clear();
if(x<=mid) cnt=longquery(k*2,l,mid,x,y);
if(y>mid) cnt=pushup(cnt,longquery(k*2+1,mid+1,r,x,y));
return cnt;
}
void operat(int x,int y,int op,int side){
// printf("%d %d %d\n",x,y,op);
if(!op){
longchange(1,1,n,x,y,col);
}
else{
if(!side){
ans1=pushup(ans1,flip(longquery(1,1,n,x,y)));
}
else{
ans2=pushup(longquery(1,1,n,x,y),ans2);
}
}
}
void jump(int x,int y,int op){
while(top(x)!=top(y)){
if(dep(top(x))>=dep(top(y))){
operat(dfn(top(x)),dfn(x),op,0);
x=fa(top(x));
}
else{
operat(dfn(top(y)),dfn(y),op,1);
y=fa(top(y));
}
}
if(dep(x)>dep(y)){
operat(dfn(y),dfn(x),op,0);
}
else{
operat(dfn(x),dfn(y),op,1);
}
ans=pushup(ans1,ans2);
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%d",&dot[i].c);
}
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
b[u].push_back(v);
b[v].push_back(u);
}
dfs1(1,0);
dfs2(1,1);
build(1,1,n);
// for(int i=1;i<=n;i++){
// printf("%d %d %d\n",dfn(i),top(i),son(i));
// }
for(int i=1;i<=m;i++){
char c[10];
scanf("%s",c);
int a,b;
if(c[0]=='C'){
scanf("%d%d%d",&a,&b,&col);
jump(a,b,0);
}
else{
ans.clear();
ans1.clear();
ans2.clear();
scanf("%d%d",&a,&b);
jump(a,b,1);
printf("%d\n",ans.num);
}
}
}
T3:
对于每一种宗教开一颗线段树,动态开点维护,
由于动态开点原来不会,现学习了一下
均摊下来是 \(O(nlogn)\)
一次宗教改变的事件即为先删除原宗教上的点,后加入新的点
一下是调代码时出现的错误,deepseek指出后就过了,可以参考一下
- 删除操作未正确回收节点
原删除函数 delate 未传递节点指针的引用,导致:
无法更新父节点的子指针
无法回收空节点
未清除城市对应的线段树节点指针
- 删除逻辑错误
原代码判断 if(sum(now)==val) 来决定是否回收节点是错误的。正确逻辑是:
删除叶子节点后回收
向上回溯时,若节点变为空(无左右子树)则回收
- 未更新根节点引用
在宗教修改操作中,删除旧宗教节点时未传递根节点引用,导致根节点无法被置空
点击查看代码
#include<bits/stdc++.h>
#define fa(x) dot[x].fa
#define son(x) dot[x].son
#define siz(x) dot[x].siz
#define dfn(x) dot[x].dfn
#define top(x) dot[x].top
#define dep(x) dot[x].dep
#define ls(x) tr[x].ls
#define rs(x) tr[x].rs
#define sum(x) tr[x].sum
#define mx(x) tr[x].mx
using namespace std;
const int N=2e6+5,M=1e5+5;
struct Node{
int dfn,top,fa,son,siz,dep,w,c,now;
}dot[N];
struct tree{
int ls,rs,sum,mx,w;
void clear(){
ls=rs=sum=mx=w=0;
}
}tr[N];
vector<int>b[M];
int dotcnt,stktop,n,q,ans;
int st[M],stk[N],id[M];
void dfs1(int x,int f){
fa(x)=f;
dep(x)=dep(f)+1;
int mson=0,nson=0;
for(int v:b[x]){
if(v==f) continue;
dfs1(v,x);
if(nson<siz(v)){
nson=siz(v);
mson=v;
}
siz(x)+=siz(v);
}
siz(x)++;
son(x)=mson;
}
void dfs2(int x,int topy){
top(x)=topy;
dfn(x)=++dotcnt;
id[dotcnt]=x;
if(!son(x)) return;
dfs2(son(x),topy);
for(int v:b[x]){
if(v==fa(x)||v==son(x)) continue;
dfs2(v,v);
}
}
void initstk(){
for(int i=1;i<=N-5;i++) stk[i]=i;
stktop=N-5;
}
void putstk(int x){
stk[++stktop]=x;
}
int outstk(){
return stk[stktop--];
}
int newnode(){
int now;
now=outstk(),tr[now].clear();
return now;
}
void delnode(int now){
putstk(now);
tr[now].clear();
}
void update(int now){
int sum=0,mx=0;
if(ls(now)){
sum+=sum(ls(now));
mx=max(mx,mx(ls(now)));
}
if(rs(now)){
sum+=sum(rs(now));
mx=max(mx,mx(rs(now)));
}
sum(now)=sum;
mx(now)=mx;
}
void insert(int &now,int l,int r,int x,int val){
if(!now) now=newnode();
if(l==r){
ls(now)=rs(now)=0;
sum(now)=val;
mx(now)=val;
tr[now].w=val;
dot[id[x]].now=now;
return;
}
int mid=(l+r)>>1;
if(x<=mid) insert(ls(now),l,mid,x,val);
else insert(rs(now),mid+1,r,x,val);
update(now);
}
void delate(int &now,int l,int r,int x,int val){
if(l==r){
delnode(now),now=0,dot[id[x]].now=0;
return;
}
int mid=(l+r)>>1;
if(x<=mid) delate(ls(now),l,mid,x,val);
else delate(rs(now),mid+1,r,x,val);
update(now);
if(!ls(now)&&!rs(now)) delnode(now),now=0;
}
int querymx(int now,int l,int r,int x,int y){
if(!now) return 0;
if(x<=l&&r<=y){
return mx(now);
}
int mid=(l+r)>>1;
int mx=0;
if(x<=mid) mx=max(mx,querymx(ls(now),l,mid,x,y));
if(y>mid) mx=max(mx,querymx(rs(now),mid+1,r,x,y));
return mx;
}
int querysum(int now,int l,int r,int x,int y){
if(!now) return 0;
if(x<=l&&r<=y){
return sum(now);
}
int mid=(l+r)>>1;
int sum=0;
if(x<=mid) sum+=querysum(ls(now),l,mid,x,y);
if(y>mid) sum+=querysum(rs(now),mid+1,r,x,y);
return sum;
}
void operat(int x,int y,int op,int c){
if(op==0){
ans+=querysum(st[c],1,n,x,y);
// printf("%d %d %d %d\n",x,y,c,ans);
}
else{
ans=max(ans,querymx(st[c],1,n,x,y));
}
}
void jump(int x,int y,int op,int c){
while(top(x)!=top(y)){
if(dep(top(x))<dep(top(y))) swap(x,y);
operat(dfn(top(x)),dfn(x),op,c);
x=fa(top(x));
}
if(dep(x)<dep(y)) swap(x,y);
operat(dfn(y),dfn(x),op,c);
}
int main(){
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++){
scanf("%d%d",&dot[i].w,&dot[i].c);
}
for(int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
b[x].push_back(y);
b[y].push_back(x);
}
initstk();
dfs1(1,0);
dfs2(1,1);
for(int i=1;i<=n;i++){
insert(st[dot[i].c],1,n,dfn(i),dot[i].w);
// printf("%d %d\n",dfn(i),i);
// printf("%d %d %d %d %d\n",tr[dot[i].now].w,dot[i].w,dot[i].c,dot[i].now,dot[i].w);
}
for(int i=1;i<=q;i++){
char s[10];
int x,c;
scanf("%s",s);
scanf("%d%d",&x,&c);
ans=0;
if(s[1]=='C'){
delate(st[dot[x].c],1,n,dfn(x),dot[x].w);
dot[x].c=c;
insert(st[c],1,n,dfn(x),dot[x].w);
}
else if(s[1]=='W'){
dot[x].w=c;
insert(st[dot[x].c],1,n,dfn(x),c);
}
else if(s[1]=='S'){
jump(x,c,0,dot[x].c);
printf("%d\n",ans);
}
else{
jump(x,c,1,dot[x].c);
printf("%d\n",ans);
}
}
}

浙公网安备 33010602011771号