luogu 4004 Hello world! 分块 + 并查集 + 乱搞

其实呢,我也不理解这道题咋做,等以后有时间再研究研究

#include <bits/stdc++.h> 
#define ll long long 
#define maxn 100002  
using namespace std;       
void setIO(string s) {
    string in=s+".in"; 
    freopen(in.c_str(),"r",stdin); 
}  
struct Union { 
    int p[maxn];  
    void init() {
        for(int i=0;i<maxn;++i) p[i]=i;     
    } 
    int find(int x) {
        return p[x]==x?x:p[x]=find(p[x]);   
    }   
}tr; 
int n,edges,m; 
ll val[maxn]; 
int hd[maxn],to[maxn<<1],nex[maxn<<1],fa[21][maxn],nx[400][maxn]; 
int dep[maxn],key[maxn];  
void addedge(int u,int v) {
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; 
} 
void dfs(int u,int ff) {
    dep[u]=dep[ff]+1, fa[0][u]=ff, nx[1][u]=ff, nx[0][u]=u;   
    for(int i=2;i<=m;++i) nx[i][u]=nx[i-1][ff];        
    for(int i=1;i<21;++i) fa[i][u]=fa[i-1][fa[i-1][u]];        
    for(int i=hd[u];i;i=nex[i]) {
        int v=to[i]; 
        if(v^ff) dfs(v, u);    
    }
}
int LCA(int x,int y) { 
    if(dep[x]^dep[y]) { 
        if(dep[x] > dep[y]) swap(x,y);    
        for(int i=20;i>=0;--i) if(dep[fa[i][y]]>=dep[x])  y=fa[i][y];   
    }
    if(x==y) return x; 
    for(int i=20;i>=0;--i) if(fa[i][x] ^ fa[i][y]) x=fa[i][x],y=fa[i][y];
    return fa[0][y];       
}
int up(int x,int k) {
    if(k<=m) return nx[k][x];           
    for(int i=20;i>=0;--i) { 
        if(key[i]<=k) x=fa[i][x], k-=key[i];   
        if(!k) break; 
    }      
    return x;    
}
void modify(int x) {
    if(val[x]==1) return; 
    val[x]=sqrt(val[x]); 
    if(val[x]==1) tr.p[x]=tr.find(fa[0][x]);   
}
int jump(int x,int y,int f,int k) {
    if(dep[y]-dep[f]>=k) return up(y,k); 
    return up(x,dep[x]+dep[y]-(dep[f]<<1)-k); 
}
int get(int x, int k) {
    if (k > m) return up(x, k);
    int y = tr.find(fa[0][x]);
    return up(y, (k - (dep[x] - dep[y]) % k) % k);
}
void update(int x,int y,int k) { 
    int f=LCA(x,y), len=dep[x]+dep[y]-(dep[f]<<1);  
    if(len%k) modify(y),y=jump(x,y,f,len%k),f=LCA(x,y);   
    while(dep[x]>=dep[f]) modify(x),x=get(x,k); 
    while(dep[y]>dep[f]) modify(y),y=get(y,k);    
}
ll query(int x,int y,int k) {
    int f=LCA(x,y),len=dep[x]+dep[y]-(dep[f]<<1);  
    ll res=0;    
    if(len%k) {
        int a=len%k; 
        res+=val[y]; 
        // printf("%d %d\n",dep[x]-dep[y],a);  
        y=jump(x,y,f,len%k);
        // y=up(x,11);     
        f=LCA(x,y);  
    }
    res+=(dep[x]+dep[y]-(dep[f]<<1))/k+1; 
    while(dep[x]>=dep[f]) res+=val[x]-1,x=get(x,k); 
    while(dep[y]>dep[f]) res+=val[y]-1,y=get(y,k); 
    return res;     
}
int main() {
    // setIO("input");       
    scanf("%d",&n),m=233;  
    key[0]=1; 
    for(int i=1;i<=22;++i) key[i]=key[i-1]*2;   
    for(int i=1;i<=n;++i) scanf("%lld",&val[i]); 
    for(int i=1;i<n;++i) {
        int u,v; 
        scanf("%d%d",&u,&v); 
        addedge(u,v),addedge(v,u); 
    }   
    dfs(1,0); 
    for(int i=1;i<=n;++i) {
        if(val[i]==1) tr.p[i]=fa[0][i]; 
        else tr.p[i]=i; 
    }
    int Q; 
    scanf("%d",&Q); 
    for(int i=1;i<=Q;++i) {
        int op,x,y,k; 
        scanf("%d%d%d%d",&op,&x,&y,&k);    
        // printf("%d %d %d %d\n",i,op,x,y); 
        if(op==0) update(x,y,k);   
        else printf("%lld\n",query(x,y,k)); 
    }
    return 0; 
}

  

posted @ 2019-07-28 22:23  EM-LGH  阅读(169)  评论(0编辑  收藏  举报