树链剖分

写在前面

某位菜鸡花了半天的时间打完了树剖,又用了半天的时间di了无数个bug后,终于获得了MLE,RE并存的喜人成绩,最终在 \(ljc\) 大佬的指点下才发现原来是函数void写成int了并且没写返回值 T_T

正题

懒得自己写一篇博客了,就把我当时学习树剖时的一篇写的非常好的博客拿出来欣赏吧
博客原文

几条重要的规则

  • 一个字树内的dfs序连续,线段树维护3,4操作
  • 节点数多的叫做重儿子,重儿子到父亲节点的边叫做重边
  • 每一条链dfs序连续(先遍历重儿子)
  • 跳到同一条重链上深度小的即为最近公共祖先
dfs1():

统计每个节点的深度 \(deep[ ]\) ,每个节点父亲 \(fa[ ]\) ,每个节点子树大小 \(size[ ]\) ,重儿子编号 \(mson[]\)

dfs2():

每个旧节点新编号 \(id[]\) ,每个新节点的旧编号 \(di[]\) ,新编号的值 \(newa[]\) ,每个节点所在重链顶端 \(top[]\)

警示后人

一定要处理好新编号和旧编号的关系,最好采用( \(id\)\(di\) )双重索引,注意你的变量是哪一次dfs()维护的,注意新旧节点的转换

代码

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5;
int n,m,r,p,cnt,ans;
int dfn[N],fa[N],siz[N],dep[N],son[N],top[N],to[N],a[N];
vector<int>b[N];
void dfs1(int u,int f){
    fa[u]=f;
    siz[u]=1;
    dep[u]=dep[f]+1;
    int mx=0;
    for(int v:b[u]){
        if(v==f)  continue;
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>mx)  mx=siz[v],son[u]=v;
    }
}
void dfs2(int u,int tp){
    top[u]=tp;
    dfn[u]=++cnt;
    to[cnt]=u;
    if(son[u])  dfs2(son[u],tp);
    for(int v:b[u]){
        if(v==fa[u]||v==son[u])  continue;
        dfs2(v,v);
    }
}
struct dot{
    int x,add;
}tr[N*4];
struct Tree{
    void build(int k,int l,int r){
        if(l==r){
            tr[k]={a[to[l]],0};
            return;
        }
        int mid=(l+r)>>1;
        build(k*2,l,mid);
        build(k*2+1,mid+1,r);
        tr[k].x=(tr[k*2].x+tr[k*2+1].x)%p;
    }
    void Add(int k,int l,int r,int z){
        tr[k].add+=z;
        tr[k].add%=p;
        tr[k].x+=z*(r-l+1);
        tr[k].x%=p;
    }
    void pushdown(int k,int l,int r,int mid){
        if(!tr[k].add)  return;
        Add(k*2,l,mid,tr[k].add);
        Add(k*2+1,mid+1,r,tr[k].add);
        tr[k].add=0;
    }
    void longchange(int k,int l,int r,int x,int y,int z){
        if(x<=l&&r<=y){
            Add(k,l,r,z);
            return;
        }
        int mid=(l+r)>>1;
        pushdown(k,l,r,mid);
        if(x<=mid)  longchange(k*2,l,mid,x,y,z);
        if(y>mid)  longchange(k*2+1,mid+1,r,x,y,z);
        tr[k].x=(tr[k*2].x+tr[k*2+1].x)%p;
    }
    int longquery(int k,int l,int r,int x,int y){
        if(x<=l&&r<=y){
            return tr[k].x;
        }
        int mid=(l+r)>>1,res=0;
        pushdown(k,l,r,mid);
        if(x<=mid)  res+=longquery(k*2,l,mid,x,y);
        if(y>mid)  res+=longquery(k*2+1,mid+1,r,x,y);
        return res%p;
    }
}tree;
void opert(int l,int r,int op,int z){
    if(op==0)  ans+=tree.longquery(1,1,n,l,r),ans%=p;
    else  tree.longchange(1,1,n,l,r,z);
}
void lca(int x,int y,int op,int z){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])  swap(x,y);
        opert(dfn[top[x]],dfn[x],op,z);
        x=fa[top[x]];
    }
    if(dep[x]<dep[y])  swap(x,y);
    opert(dfn[y],dfn[x],op,z);
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>n>>m>>r>>p;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        b[u].push_back(v);
        b[v].push_back(u);
    }
    dfs1(r,0);
    dfs2(r,r);
    tree.build(1,1,n);
    for(int i=1;i<=m;i++){
        int op,x,y,z;
        ans=0;
        cin>>op>>x;
        if(op==1){
            cin>>y>>z;
            lca(x,y,1,z);
        }
        else if(op==2){
            cin>>y;
            lca(x,y,0,0);
            cout<<ans%p<<'\n';
        }
        else if(op==3){
            cin>>z;
            tree.longchange(1,1,n,dfn[x],dfn[x]+siz[x]-1,z);
        }
        else{
            ans=tree.longquery(1,1,n,dfn[x],dfn[x]+siz[x]-1);
            cout<<ans%p<<'\n';
        }
    }
}

金牌导航题目

T1:

作为停了3,4个月信竞回来做的第一道题,一道板子打了3个小时,还让deepseek帮了忙,非常不牛

但是代码相较以前做出了一些优化:

点击查看代码
#include<bits/stdc++.h>
#define dep(x) dot[x].dep
#define top(x) dot[x].top
#define fa(x) dot[x].fa
#define dfn(x) dot[x].dfn
#define mx(x) tr[x].mx
#define sum(x) tr[x].sum
#define val(x) tr[x].val
using namespace std;
const int N=1e5+5;
struct Node{
    int dfn,dep,top,son,w,fa,siz;
}dot[N];
struct tree{
    int val,sum=0,mx=-1e5;
}tr[N*4];
vector<int>b[N];
int cnt,ans,n,q;
int id[N];
void dfs1(int x,int f){
    dot[x].dep=dot[f].dep+1;
    dot[x].fa=f;
    int maxn=0,maxd=0;
    for(int v:b[x]){
        if(v==f)  continue;
        dfs1(v,x);
        if(dot[v].siz>maxn){
            maxn=dot[v].siz;
            maxd=v;
        }
        dot[x].siz+=dot[v].siz;
    }
    dot[x].siz++;
    dot[x].son=maxd;
    return;
}
void dfs2(int x,int topy){
    dot[x].dfn=++cnt;//一定要按照树链剖分的顺序排dfn,所以要在dfs2中
    id[cnt]=x;
    dot[x].top=topy;
    if(!dot[x].son)  return;
    dfs2(dot[x].son,topy);
    for(int v:b[x]){
        if(v==dot[x].fa||v==dot[x].son)  continue;
        dfs2(v,v);
    }
    return;
}
void pushup(int k,int l,int r){
    sum(k)=sum(k*2)+sum(k*2+1);
    mx(k)=max(mx(k*2),mx(k*2+1));
}
void build(int k,int l,int r){
    if(l==r){
        int w=dot[id[l]].w;
        tr[k].mx=w;
        tr[k].sum=w;
        tr[k].val=w;
        return;
    }
    int mid=(l+r)>>1;
    build(k*2,l,mid);
    build(k*2+1,mid+1,r);
    pushup(k,l,r);
}
int longquerymx(int k,int l,int r,int x,int y){
    int mx=-1e5;
    if(x<=l&&r<=y){
        mx=max(mx,mx(k));
        return mx;
    }
    int mid=(l+r)>>1;
    if(x<=mid)  mx=max(mx,longquerymx(k*2,l,mid,x,y));
    if(y>mid)  mx=max(mx,longquerymx(k*2+1,mid+1,r,x,y));
    return mx;
}
int longquerysum(int k,int l,int r,int x,int y){
    int sum=0;
    if(x<=l&&r<=y){
        return sum(k);
    }
    int mid=(l+r)>>1;
    if(x<=mid)  sum+=longquerysum(k*2,l,mid,x,y);
    if(y>mid)  sum+=longquerysum(k*2+1,mid+1,r,x,y);
    return sum;
}
void dotchange(int k,int l,int r,int x,int val){
    if(l==r){
        tr[k].mx=val;
        tr[k].sum=val;
        tr[k].val=val;
        return;
    }
    int mid=(l+r)>>1;
    if(x<=mid)  dotchange(k*2,l,mid,x,val);
    else  dotchange(k*2+1,mid+1,r,x,val);
    pushup(k,l,r);
    return;
}
void operat(int x,int y,int op){
    if(op)  ans=max(ans,longquerymx(1,1,n,x,y));
    else  ans+=longquerysum(1,1,n,x,y);
}
void jump(int x,int y,int op){
    while(top(x)!=top(y)){
        if(dep(top(x))<dep(top(y)))  swap(x,y);
        operat(dfn(top(x)),dfn(x),op);
        x=fa(top(x));
    }
    if(dep(x)<dep(y))  swap(x,y);
    operat(dfn(y),dfn(x),op);
    return;
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<n;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        b[u].push_back(v);
        b[v].push_back(u);
    }
    for(int i=1;i<=n;i++){
        scanf("%d",&dot[i].w);
    }
    dfs1(1,0);
    dfs2(1,1);
    build(1,1,n);
    scanf("%d",&q);
    for(int i=1;i<=q;i++){
        char c[10];
        int u,t;
        scanf("%s",c);
        scanf("%d%d",&u,&t);
        if(c[1]=='H')  dotchange(1,1,n,dot[u].dfn,t);//以旧编号索引新编号
        else if(c[1]=='M'){
            ans=-1e5;//题目有负值
            jump(u,t,1);
            printf("%d\n",ans);
        }
        else{
            ans=0;
            jump(u,t,0);
            printf("%d\n",ans);
        }
    }
}

T2:

线段树上维护颜色段数量总会吧

注意树刨跳重链时,从左往右跳它的序列是反着的,需要先反转再合并

注意细节处理,还有一定要注意在主函数里写初始化代码,已经连续两次犯这个问题导致RE了

点击查看代码
#include<bits/stdc++.h>
#define dfn(x) dot[x].dfn
#define siz(x) dot[x].siz
#define top(x) dot[x].top
#define fa(x) dot[x].fa
#define son(x) dot[x].son
#define dep(x) dot[x].dep
#define lc(x) tr[x].lc
#define rc(x) tr[x].rc
#define num(x) tr[x].num
using namespace std;
const int N=1e5+5;
struct Node{
    int dfn,siz,c,top,fa,son,dep;
}dot[N];
struct tree{
    int lc,rc,num;
    void clear(){
        lc=rc=num=0;
    }
}tr[N*4];
int cnt,col,n,m;
int id[N],tag[4*N];
vector<int>b[N];
tree ans1,ans2,ans;
void dfs1(int x,int f){
    fa(x)=f;
    dep(x)=dep(f)+1;
    int mson=0,nson=0;
    for(int v:b[x]){
        if(v==f)  continue;
        dfs1(v,x);
        if(siz(v)>nson){
            nson=siz(v);
            mson=v;
        }
        siz(x)+=siz(v);
    }
    son(x)=mson;
    siz(x)++;
    return;
}
void dfs2(int x,int topy){
    top(x)=topy;
    dfn(x)=++cnt;
    id[cnt]=x;
    if(!son(x))  return;
    dfs2(son(x),topy);
    for(int v:b[x]){
        if(v==fa(x)||v==son(x))  continue;
        dfs2(v,v);
    }
    return;
}
tree flip(tree x){
    return (tree){x.rc,x.lc,x.num};
}
tree pushup(tree x,tree y){
    if(!x.num)  return y;
    else if(!y.num)  return x;
    return (tree){x.lc,y.rc,x.num+y.num-(x.rc==y.lc)};
}
void build(int k,int l,int r){
    if(l==r){
        int c=dot[id[l]].c;
        lc(k)=rc(k)=c;
        num(k)=1;
        return;
    }
    int mid=(l+r)>>1;
    build(k*2,l,mid);
    build(k*2+1,mid+1,r);
    tr[k]=pushup(tr[k*2],tr[k*2+1]);
}
void change(int k,int c){
    lc(k)=rc(k)=c;
    num(k)=1;
    tag[k]=c;
}
void pushdown(int k){
    if(!tag[k])  return;
    change(k*2,tag[k]);
    change(k*2+1,tag[k]);
    tag[k]=0;
}
void longchange(int k,int l,int r,int x,int y,int c){
    if(x<=l&&r<=y){
        change(k,c);
        return;
    }
    pushdown(k);
    int mid=(l+r)>>1;
    if(x<=mid)  longchange(k*2,l,mid,x,y,c);
    if(y>mid)  longchange(k*2+1,mid+1,r,x,y,c);
    tr[k]=pushup(tr[k*2],tr[k*2+1]);
}
tree longquery(int k,int l,int r,int x,int y){
    if(x<=l&&r<=y){
        return tr[k];
    }
    pushdown(k);
    int mid=(l+r)>>1;
    tree cnt;
    cnt.clear();
    if(x<=mid)  cnt=longquery(k*2,l,mid,x,y);
    if(y>mid)  cnt=pushup(cnt,longquery(k*2+1,mid+1,r,x,y));
    return cnt;
}
void operat(int x,int y,int op,int side){
    // printf("%d %d %d\n",x,y,op);
    if(!op){
        longchange(1,1,n,x,y,col);
    }
    else{
        if(!side){
            ans1=pushup(ans1,flip(longquery(1,1,n,x,y)));
        }
        else{
            ans2=pushup(longquery(1,1,n,x,y),ans2);
        }
    }
}
void jump(int x,int y,int op){
    while(top(x)!=top(y)){
        if(dep(top(x))>=dep(top(y))){
            operat(dfn(top(x)),dfn(x),op,0);
            x=fa(top(x));
        }
        else{
            operat(dfn(top(y)),dfn(y),op,1);
            y=fa(top(y));
        }
    }
    if(dep(x)>dep(y)){
        operat(dfn(y),dfn(x),op,0);
    }
    else{
        operat(dfn(x),dfn(y),op,1);
    }
    ans=pushup(ans1,ans2);
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++){
        scanf("%d",&dot[i].c);
    }
    for(int i=1;i<n;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        b[u].push_back(v);
        b[v].push_back(u);
    }
    dfs1(1,0);
    dfs2(1,1);
    build(1,1,n);
    // for(int i=1;i<=n;i++){
    //     printf("%d %d %d\n",dfn(i),top(i),son(i));
    // }
    for(int i=1;i<=m;i++){
        char c[10];
        scanf("%s",c);
        int a,b;
        if(c[0]=='C'){
            scanf("%d%d%d",&a,&b,&col);
            jump(a,b,0);
        }
        else{
            ans.clear();
            ans1.clear();
            ans2.clear();
            scanf("%d%d",&a,&b);
            jump(a,b,1);
            printf("%d\n",ans.num);
        }
    }
}

T3:

对于每一种宗教开一颗线段树,动态开点维护,

由于动态开点原来不会,现学习了一下

均摊下来是 \(O(nlogn)\)

一次宗教改变的事件即为先删除原宗教上的点,后加入新的点

一下是调代码时出现的错误,deepseek指出后就过了,可以参考一下

  1. 删除操作未正确回收节点

原删除函数 delate 未传递节点指针的引用,导致:

无法更新父节点的子指针

无法回收空节点

未清除城市对应的线段树节点指针

  1. 删除逻辑错误

原代码判断 if(sum(now)==val) 来决定是否回收节点是错误的。正确逻辑是:

删除叶子节点后回收

向上回溯时,若节点变为空(无左右子树)则回收

  1. 未更新根节点引用

在宗教修改操作中,删除旧宗教节点时未传递根节点引用,导致根节点无法被置空

点击查看代码
#include<bits/stdc++.h>
#define fa(x) dot[x].fa
#define son(x) dot[x].son
#define siz(x) dot[x].siz
#define dfn(x) dot[x].dfn
#define top(x) dot[x].top
#define dep(x) dot[x].dep
#define ls(x) tr[x].ls
#define rs(x) tr[x].rs
#define sum(x) tr[x].sum
#define mx(x) tr[x].mx
using namespace std;
const int N=2e6+5,M=1e5+5;
struct Node{
    int dfn,top,fa,son,siz,dep,w,c,now;
}dot[N];
struct tree{
    int ls,rs,sum,mx,w;
    void clear(){
        ls=rs=sum=mx=w=0;
    }
}tr[N];
vector<int>b[M];
int dotcnt,stktop,n,q,ans;
int st[M],stk[N],id[M];
void dfs1(int x,int f){
    fa(x)=f;
    dep(x)=dep(f)+1;
    int mson=0,nson=0;
    for(int v:b[x]){
        if(v==f)  continue;
        dfs1(v,x);
        if(nson<siz(v)){
            nson=siz(v);
            mson=v;
        }
        siz(x)+=siz(v);
    }
    siz(x)++;
    son(x)=mson;
}
void dfs2(int x,int topy){
    top(x)=topy;
    dfn(x)=++dotcnt;
    id[dotcnt]=x;
    if(!son(x))  return;
    dfs2(son(x),topy);
    for(int v:b[x]){
        if(v==fa(x)||v==son(x))  continue;
        dfs2(v,v);
    }
}
void initstk(){
    for(int i=1;i<=N-5;i++)  stk[i]=i;
    stktop=N-5;
}
void putstk(int x){
    stk[++stktop]=x;
}
int outstk(){
    return stk[stktop--];
}
int newnode(){
    int now;
    now=outstk(),tr[now].clear();
    return now;
}
void delnode(int now){
    putstk(now);
    tr[now].clear();
}
void update(int now){
    int sum=0,mx=0;
    if(ls(now)){
        sum+=sum(ls(now));
        mx=max(mx,mx(ls(now)));
    }
    if(rs(now)){
        sum+=sum(rs(now));
        mx=max(mx,mx(rs(now)));
    }
    sum(now)=sum;
    mx(now)=mx;
}
void insert(int &now,int l,int r,int x,int val){
    if(!now)  now=newnode();
    if(l==r){
        ls(now)=rs(now)=0;
        sum(now)=val;
        mx(now)=val;
        tr[now].w=val;
        dot[id[x]].now=now;
        return;
    }
    int mid=(l+r)>>1;
    if(x<=mid)  insert(ls(now),l,mid,x,val);
    else  insert(rs(now),mid+1,r,x,val);
    update(now);
}
void delate(int &now,int l,int r,int x,int val){
    if(l==r){
        delnode(now),now=0,dot[id[x]].now=0;
        return;
    }
    int mid=(l+r)>>1;
    if(x<=mid)  delate(ls(now),l,mid,x,val);
    else  delate(rs(now),mid+1,r,x,val);
    update(now);
    if(!ls(now)&&!rs(now))  delnode(now),now=0;
}
int querymx(int now,int l,int r,int x,int y){
    if(!now)  return 0;
    if(x<=l&&r<=y){
        return mx(now);
    }
    int mid=(l+r)>>1;
    int mx=0;
    if(x<=mid)  mx=max(mx,querymx(ls(now),l,mid,x,y));
    if(y>mid)  mx=max(mx,querymx(rs(now),mid+1,r,x,y));
    return mx;
}
int querysum(int now,int l,int r,int x,int y){
    if(!now) return 0;
    if(x<=l&&r<=y){
        return sum(now);
    }
    int mid=(l+r)>>1;
    int sum=0;
    if(x<=mid)  sum+=querysum(ls(now),l,mid,x,y);
    if(y>mid)  sum+=querysum(rs(now),mid+1,r,x,y);
    return sum;
}
void operat(int x,int y,int op,int c){
    if(op==0){
        ans+=querysum(st[c],1,n,x,y);
        // printf("%d %d %d %d\n",x,y,c,ans);
    }
    else{
        ans=max(ans,querymx(st[c],1,n,x,y));
    }
}
void jump(int x,int y,int op,int c){
    while(top(x)!=top(y)){
        if(dep(top(x))<dep(top(y)))  swap(x,y);
        operat(dfn(top(x)),dfn(x),op,c);
        x=fa(top(x));
    }
    if(dep(x)<dep(y))  swap(x,y);
    operat(dfn(y),dfn(x),op,c);
}
int main(){
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;i++){
        scanf("%d%d",&dot[i].w,&dot[i].c);
    }
    for(int i=1;i<n;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        b[x].push_back(y);
        b[y].push_back(x);
    }
    initstk();
    dfs1(1,0);
    dfs2(1,1);
    for(int i=1;i<=n;i++){
        insert(st[dot[i].c],1,n,dfn(i),dot[i].w);
        // printf("%d %d\n",dfn(i),i);
        // printf("%d %d %d %d %d\n",tr[dot[i].now].w,dot[i].w,dot[i].c,dot[i].now,dot[i].w);
    }
    for(int i=1;i<=q;i++){
        char s[10];
        int x,c;
        scanf("%s",s);
        scanf("%d%d",&x,&c);
        ans=0;
        if(s[1]=='C'){
            delate(st[dot[x].c],1,n,dfn(x),dot[x].w);
            dot[x].c=c;
            insert(st[c],1,n,dfn(x),dot[x].w);
        }
        else if(s[1]=='W'){
            dot[x].w=c;
            insert(st[dot[x].c],1,n,dfn(x),c);
        }
        else if(s[1]=='S'){
            jump(x,c,0,dot[x].c);
            printf("%d\n",ans);
        }
        else{
            jump(x,c,1,dot[x].c);
            printf("%d\n",ans);
        }
    }
}
posted @ 2025-07-21 13:05  daydreamer_zcxnb  阅读(37)  评论(0)    收藏  举报