bzoj3626: [LNOI2014]LCA 树上差分

Description

给出一个n个节点的有根树(编号为0到n-1,根节点为0)。一个点的深度定义为这个节点到根的距离+1。
设dep[i]表示点i的深度,LCA(i,j)表示i与j的最近公共祖先。
有q次询问,每次询问给出l r z,求sigma_{l<=i<=r}dep[LCA(i,z)]。
(即,求在[l,r]区间内的每个节点i与z的最近公共祖先的深度之和)

Input

第一行2个整数n q。
接下来n-1行,分别表示点1到点n-1的父节点编号。
接下来q行,每行3个整数l r z。

Output

输出q行,每行表示一个询问的答案。每个答案对201314取模输出

题解

题解参考http://hzwer.com/3891.html

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
const int maxn=1e6+10; 
const int mod=201314;
int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int n,m;
int head[maxn],cnt;
struct edge{int to,nxt;}e[maxn];
void add(int x,int y){
    e[++cnt]=(edge){y,head[x]};
    head[x]=cnt;
}
int fa[maxn],son[maxn],size[maxn],top[maxn];
int rnk[maxn],id[maxn],tot,dep[maxn];
struct ques{
    int p,f,num;
    bool tag;
    bool operator <(const ques&a)const{
        return p<a.p;
    }
}q[maxn*2];
int ans1[maxn],ans2[maxn];
void dfs1(int x,int f,int d){
    fa[x]=f;dep[x]=d;size[x]=1;
    for(int i=head[x];i;i=e[i].nxt){
        int u=e[i].to;
        if(u==f)continue;
        dfs1(u,x,d+1);
        size[x]+=size[u]; 
        if(!son[x]||size[son[x]]<size[u])son[x]=u;
    }
}
void dfs2(int x,int t){
    top[x]=t;
    id[x]=++tot;
    rnk[tot]=x;
    if(!son[x])return ;
    dfs2(son[x],t);
    for(int i=head[x];i;i=e[i].nxt){
        int u=e[i].to;
        if(u==fa[x]||u==son[x])continue;
        dfs2(u,u);
    }
}
int tag[maxn*4],sum[maxn*4];

void pushdown(int id,int l,int r){
    if(tag[id]){
        int tl=id<<1,tr=id<<1|1;
        int mid=(l+r)>>1;
        (sum[tl]+=tag[id]*(mid-l+1)%mod)%=mod;
        (sum[tr]+=tag[id]*(r-mid)%mod)%=mod;
        (tag[tl]+=tag[id])%=mod;
        (tag[tr]+=tag[id])%=mod;
        tag[id]=0;
    }
}
void pushup(int id){
    int l=id<<1,r=id<<1|1;
    sum[id]=(sum[l]+sum[r])%mod;
}
void update(int id,int l,int r,int tl,int tr){
    if(tl<=l&&r<=tr){
        tag[id]++;(sum[id]+=r-l+1)%=mod;
        return ;
    }
    int mid=(l+r)>>1;
    pushdown(id,l,r);
    if(tr<=mid)update(id<<1,l,mid,tl,tr);
    else if(tl>mid)update(id<<1|1,mid+1,r,tl,tr);
    else{
        update(id<<1,l,mid,tl,tr);
        update(id<<1|1,mid+1,r,tl,tr);
    }
    pushup(id);
}
int query(int id,int l,int r,int tl,int tr){
    if(tl<=l&&r<=tr)return sum[id];
    pushdown(id,l,r);
    int mid=(l+r)>>1;
    if(tr<=mid)return query(id<<1,l,mid,tl,tr);
    else if(tl>mid)return query(id<<1|1,mid+1,r,tl,tr);
    else return (query(id<<1,l,mid,tl,tr)+query(id<<1|1,mid+1,r,tl,tr))%mod;
}
void tree_update(int x){
    int y=1;
    int fx=top[x],fy=top[y];
    while(fx!=fy){
        if(dep[fx]>dep[fy]){
            update(1,1,n,id[fx],id[x]);
            x=fa[fx];
        }
        else{
            update(1,1,n,id[fy],id[y]);
            y=fa[fy];
        }
        fx=top[x];fy=top[y];
    }
    if(dep[x]>dep[y])update(1,1,n,id[y],id[x]);
    else update(1,1,n,id[x],id[y]);
}
int tree_query(int x){
    int y=1;
    int fx=top[x],fy=top[y];
    int ans=0;
    while(fx!=fy){
        if(dep[fx]>dep[fy]){
            (ans+=query(1,1,n,id[fx],id[x]))%=mod;
            x=fa[fx];
        }
        else{
            (ans+=query(1,1,n,id[fy],id[y]))%=mod;
            y=fa[fy];
        }
        fx=top[x];fy=top[y];
    }
    if(dep[x]>dep[y])(ans+=query(1,1,n,id[y],id[x]))%=mod;
    else (ans+=query(1,1,n,id[x],id[y]))%=mod;
    return ans;
}
int main(){
    n=read();m=read();
    for(int i=2;i<=n;i++){
        int x=read()+1;
        add(x,i);
    }
    dfs1(1,0,0);dfs2(1,1);
    for(int i=1;i<=m;i++){
        int l=read()+1,r=read()+1,z=read()+1;
        q[i*2-1].p=l-1;q[i*2-1].f=z;q[i*2-1].tag=0;q[i*2-1].num=i;
        q[i*2].p=r;q[i*2].f=z;q[i*2].tag=1;q[i*2].num=i;
    }
    sort(q+1,q+2*m+1);
    int l=1;
    for(int i=1;i<=2*m;i++){
        while(l<=q[i].p){tree_update(l);l++;}
        if(q[i].tag==0)ans1[q[i].num]=tree_query(q[i].f);
        else ans2[q[i].num]=tree_query(q[i].f);
    }
    for(int i=1;i<=m;i++)printf("%d\n",(ans2[i]-ans1[i]+mod)%mod);
    return 0;
}
posted @ 2018-10-01 18:39  南城ㄱ  阅读(125)  评论(0编辑  收藏  举报