zoj3195 联通树上三个点的路径长

输出有个坑,两个月之前就没对,,今天又被坑了一次

求联通树上三个点的路径长度,只要求两两点对的最短路径,加起来除以二即可

#include<iostream>
#include<cstring>
#include<cstdio>
#include<queue>
using namespace std;
#define maxn 50005
#define DEG 20
struct Edge{
    int to,next,w;
}edge[maxn*2];
int head[maxn],tot;
inline void addedge(int u,int v,int w){
    edge[tot].to=v;
    edge[tot].next=head[u];
    edge[tot].w=w;
    head[u]=tot++;
}
int deg[maxn],depth[maxn],fa[maxn][DEG];
int flag[maxn];
void bfs(int root){
    queue<int> que;
    deg[root]=depth[root]=0;
    fa[root][0]=root;
    que.push(root);
    while(!que.empty()){
        int tmp=que.front();que.pop();
        for(int i=1;i<DEG;i++)
            fa[tmp][i]=fa[fa[tmp][i-1]][i-1];
        for(int i=head[tmp];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==fa[tmp][0]) continue;
            deg[v]=deg[tmp]+1;depth[v]=depth[tmp]+edge[i].w;
            fa[v][0]=tmp;
            que.push(v);
        }
    }
}
int query(int u,int v){
    if(deg[u]>deg[v]) swap(u,v);
    int hu=deg[u],hv=deg[v],tu=u,tv=v;
    for(int det=hv-hu,i=0;det;det>>=1,i++)
        if(det&1) tv=fa[tv][i];
    if(tu==tv) return tu;
    for(int i=DEG-1;i>=0;i--){
        if(fa[tu][i]==fa[tv][i])continue;
        tu=fa[tu][i];tv=fa[tv][i];
    }
    return fa[tu][0];
}
void init(){
    tot=0;
    memset(head,-1,sizeof head);
    memset(depth,0,sizeof depth);
    memset(deg,0,sizeof deg);
} 
int main(){
    int n,q,u,v,w;
    int flagg=0;
    while(~scanf("%d",&n)){
        init();
        for(int i=1;i<n;i++){
            scanf("%d%d%d",&u,&v,&w);
            addedge(u,v,w);addedge(v,u,w);
            flag[v]=1;
        }
        int root;
        for(int i=0;i<n;i++) if(!flag[i]){root=i;break;}
        bfs(root);
        
        if(flagg) putchar('\n');
        else flagg=1;
        scanf("%d",&q);
        int a,b,c;
        while(q--){
            scanf("%d%d%d",&a,&b,&c);
            int tmp1=query(a,b);
            int dis1=depth[a]+depth[b]-depth[tmp1]*2;
            
            int tmp2=query(b,c);
            int dis2=depth[b]+depth[c]-depth[tmp2]*2;
            
            int tmp3=query(a,c);
            int dis3=depth[a]+depth[c]-depth[tmp3]*2;
            printf("%d\n",(dis1+dis2+dis3)/2);
        }
    }
}

 

posted on 2018-11-25 20:22  zsben  阅读(126)  评论(0编辑  收藏  举报

导航