洛谷P6071 Treequery(树链剖分+线段树合并+可持久化)

(在接下来的描述和图片中,用$dis_u$表示根(1号点)到$u$的路径的边权和;$[l,r]$范围内的点用蓝色表示,$p$用红色表示)

我们考虑$lca_{l,l+1,\cdots,r}$(以下记为$lca$,用橙色表示),有以下情况:

1.$lca$为$p$的孩子(或$p$)

 

这表明$[l,r]$均为$p$的孩子(或$p$),那么答案为$dis_{lca}-dis_p$

2.$[l,r]$中既有$p$的孩子(或$p$),又有点不是$p$的孩子(或$p$)

 

那么这两种点到$p$的路径无公共边,答案为0

3.$[l,r]$中没有$p$的孩子

这种情况较为复杂,我们还要找出$p$的祖先中最深的一个点,使它也为$[l,r]$中至少一点的祖先(即$lca_{p,l},lca_{p,l+1},\cdots,lca_{p,r}$中最深的一个,以下记为$fa$,用黄色表示)

它和$lca$的位置关系有两种情况:

3.1.它是$lca$的祖先

 

答案为$dis_p+dis_{lca}-2dis_{fa}$

3.2.它是$lca$的孩子

 

答案为$dis_p-dis_{fa}$

分析完毕。接下来研究如何实现:

求$lca$:树链剖分+线段树。

判断一个点的孩子中是否有$[l,r]$中的点:用线段树合并配合可持久化就可以在线处理。

求$fa$:从$p$开始,沿着重链向上跳,发现这条链上有答案时再在链上二分一下。

时间复杂度$O(nlogn+qlog^2n)$,空间复杂度$O(nlogn)$。

这样就可以愉快地上代码了:

#include<cstdio>
#define For(i,A,B) for(i=(A);i<=(B);++i)
#define Go(u) for(i=G[u];i;i=nxt[i])if((v=to[i])!=f[u])
const int N=200050;
char rB[1<<21],*rS,*rT,wB[(1<<21)+50];
int wp=-1;
inline char gc(){return rS==rT&&(rT=(rS=rB)+fread(rB,1,1<<21,stdin),rS==rT)?EOF:*rS++;}
inline void flush(){fwrite(wB,1,wp+1,stdout);wp=-1;}
inline int rd(){
    char c=gc();
    while(c<48||c>57)c=gc();
    int x=c&15;
    for(c=gc();c>=48&&c<=57;c=gc())x=(x<<3)+(x<<1)+(c&15);
    return x;
}
short buf[15];
inline void wt(int x){
    if(wp>(1<<21))flush();
    short l=-1;
    while(x>9){
        buf[++l]=x%10;
        x/=10;
    }
    wB[++wp]=x|48;
    while(l>=0)wB[++wp]=buf[l--]|48;
    wB[++wp]='\n';
}
int G[N],to[N<<1],w[N<<1],nxt[N<<1],sz,sum[N],f[N],dep[N],dfn[N],siz[N],son[N],top[N],ps[N],dfsc,sl[N<<2],rt[N],lc[N*40],rc[N*40],tot,x,y,n;
inline void Swap(int &a,int &b){int t=a;a=b;b=t;}
inline void adde(int u,int v,int c){
    to[++sz]=v;w[sz]=c;nxt[sz]=G[u];G[u]=sz;
    to[++sz]=u;w[sz]=c;nxt[sz]=G[v];G[v]=sz;
}
//可持久化线段树合并部分
int add(int pre,int L,int R,int x){
    int o=++tot;
    if(L<R){
        int M=L+R>>1;
        if(x<=M){lc[o]=add(lc[pre],L,M,x);rc[o]=rc[pre];}
        else{rc[o]=add(rc[pre],M+1,R,x);lc[o]=lc[pre];}
    }
    return o;
}
int merg(int u,int v,int L,int R){
    if(!u||!v)return u|v;
    int o=++tot;
    if(L<R){
        int M=L+R>>1;
        lc[o]=merg(lc[u],lc[v],L,M);
        rc[o]=merg(rc[u],rc[v],M+1,R);
    }
    return o;
}
int ask(int o,int L,int R){
    if(!o)return 0;
    if(x<=L&&y>=R)return 1;
    int M=L+R>>1;
    if(x<=M&&ask(lc[o],L,M))return 1;
    return y>M&&ask(rc[o],M+1,R);
}
//树链剖分部分
void dfs1(int u,int fa){
    int i,v,maxn=0;
    siz[u]=1;
    dep[u]=dep[f[u]=fa]+1;
    rt[u]=add(0,1,n,u);
    Go(u){
        sum[v]=sum[u]+w[i];
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>maxn){son[u]=v;maxn=siz[v];}
        rt[u]=merg(rt[u],rt[v],1,n);
    }
}
void dfs2(int u,int topf){
    top[ps[dfn[u]=++dfsc]=u]=topf;
    if(son[u]){
        int i,v;
        dfs2(son[u],topf);
        Go(u)if(v!=son[u])dfs2(v,v);
    }
}
inline int lca(int u,int v){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])Swap(u,v);
        u=f[top[u]];
    }
    return dep[u]<dep[v]?u:v;
}
//lca线段树部分
void build(int o,int L,int R){
    if(L==R)sl[o]=L;
    else{
        int lc=o<<1,rc=lc|1,M=L+R>>1;
        build(lc,L,M);build(rc,M+1,R);
        sl[o]=lca(sl[lc],sl[rc]);
    }
}
int query(int o,int L,int R){
    if(x<=L&&y>=R)return sl[o];
    int lc=o<<1,rc=lc|1,M=L+R>>1;
    if(x<=M)if(y>M)return lca(query(lc,L,M),query(rc,M+1,R));
    else return query(lc,L,M);
    else return query(rc,M+1,R);
}
inline int getfa(int u){
    while(u){
        if(ask(rt[top[u]],1,n)){
            int l=dfn[top[u]],r=dfn[u],mid;
            while(l<r){
                mid=l+r+1>>1;
                if(ask(rt[ps[mid]],1,n))l=mid;
                else r=mid-1;
            }
            return ps[l];
        }
        u=f[top[u]];
    }
}
int main(){
    int q,i,j,u,v,c,lans=0,p;
    n=rd();q=rd();
    For(i,2,n){
        u=rd();v=rd();c=rd();
        adde(u,v,c);
    }
    dfs1(1,0);
    dfs2(1,1);
    build(1,1,n);
    while(q--){
        p=rd()^lans;x=rd()^lans;y=rd()^lans;
        if(dfn[u=query(1,1,n)]>=dfn[p]&&dfn[u]<=dfn[p]+siz[p]-1)lans=sum[u]-sum[p];  //1.
        else if(ask(rt[p],1,n))lans=0;  //2.
        else if(dfn[v=getfa(p)]>=dfn[u]&&dfn[v]<=dfn[u]+siz[u]-1)lans=sum[p]-sum[v];  //3.2.
        else lans=sum[p]+sum[u]-(sum[v]<<1);  //3.1.
        wt(lans);
    }
    flush();
    return 0;
}
View Code

 

posted @ 2020-02-10 14:41  wangyuchen  阅读(162)  评论(0编辑  收藏  举报