POJ 1987 Distance Statistics

http://poj.org/problem?id=1987

题意：给定一棵带边权的树，求有多少对（无序）节点之间的距离 ≤ K。

思路：点分治。对当前分治块的重心，把块内所有点到重心的距离存起来并排序，再用双指针一遍扫描统计距离之和 ≤ K 的点对数；统计时要去掉点与自身配对的情况。此外，两端落在重心同一棵子树内的点对并不经过重心，必须去掉：对每棵子树，以连接重心的那条边的边权作为初始距离再统计一次，并从答案中减去。复杂度：$O(n\log^2 n)$。

#include<cstdio>
#include<cmath>
#include<cstring>
#include<iostream>
#include<algorithm>
// Shared state for the centroid-decomposition solution.
typedef long long ll;      // a typedef (not a macro) so the alias is scoped and type-checked
const int MAXN=1000005;    // capacity of every array below (vertices and directed edges)
// Forward-star adjacency list: edge e leads to go[e] with weight val[e];
// first[v] is the most recently added edge of v, next[e] the edge added
// before e (0 terminates the list); tot counts edges added so far.
int tot,go[MAXN],first[MAXN],next[MAXN];
// st[1..cnt]: scratch buffer of distances that are sorted and counted.
ll st[MAXN],val[MAXN];
// sum: size of the component findroot is working on; son[v]: DFS subtree
// size of v; root: best centroid candidate so far; n: vertex count;
// F[v]: size of the largest piece left if v is removed; c: BFS queue.
int sum,son[MAXN],root,n,F[MAXN],c[MAXN];
// pd[v]==sz marks v as reached in the current BFS; vis[v]!=0 means v has
// already been used as a centroid and acts as a wall for later traversals.
int pd[MAXN],sz,vis[MAXN];
ll dis[MAXN];              // distance of each reached vertex from the BFS start
int cnt,K,ans;             // cnt: entries in st; K: distance threshold; ans: answer
// Reads a decimal integer from stdin. Any non-digit characters before the
// first digit are skipped; a '-' among them makes the result negative.
// Returns 0 if end of input is reached before any digit is seen
// (the previous version looped forever at EOF).
int read(){
    int sign=1,value=0;
    int ch=getchar();   // int, not char, so EOF is distinguishable from data
    while (ch!=EOF&&(ch<'0'||ch>'9')){
        if (ch=='-') sign=-1;
        ch=getchar();
    }
    while (ch>='0'&&ch<='9'){
        value=value*10+(ch-'0');
        ch=getchar();
    }
    return value*sign;
}
// Appends the directed edge x -> y with weight z to x's adjacency list.
// New edges are pushed onto the front of the list headed by first[x].
void insert(int x,int y,int z){
    const int e=++tot;
    go[e]=y;
    val[e]=z;
    next[e]=first[x];
    first[x]=e;
}
// Stores the undirected road {x, y} of length z as two directed edges.
void add(int x,int y,int z){
    insert(x,y,z);   // x -> y
    insert(y,x,z);   // y -> x
}
// DFS from x (parent fa) over vertices not yet used as centroids.
// Computes son[] (subtree sizes) and F[] (largest remaining piece if the
// vertex were removed, using the global component size `sum`), and moves
// `root` to any vertex whose F is strictly smaller than F[root].
void findroot(int x,int fa){
    int heaviest=0;
    son[x]=1;
    for (int e=first[x];e;e=next[e]){
        const int v=go[e];
        if (v!=fa&&!vis[v]){
            findroot(v,x);
            son[x]+=son[v];
            if (son[v]>heaviest) heaviest=son[v];
        }
    }
    // The part "above" x in this DFS has sum-son[x] vertices.
    F[x]=std::max(heaviest,sum-son[x]);
    if (F[x]<F[root]) root=x;
}
void bfs(int x){
    int h=1,t=1;c[1]=x;pd[x]=sz;dis[x]=0;
    while (h<=t){
        int now=c[h++];
        for (int i=first[now];i;i=next[i]){
            int pur=go[i];
            if (vis[pur]||pd[pur]==sz) continue;
            pd[pur]=sz;
            dis[pur]=dis[now]+val[i];
            c[++t]=pur;
            st[++cnt]=dis[pur];
        }
    }
    std::sort(st+1,st+1+cnt);
    int j=cnt,res=0,Cnt=0;
    for (int i=1;i<=t;i++){
        while (j>1&&st[i]+st[j]>K) j--;
        if (st[i]+st[j]<=K) res+=j;
        if (st[i]+st[i]<=K) Cnt++;
    }
    res-=Cnt;
    ans+=res/2;
}
int del(int x,int Dis){
    dis[x]=Dis;sz++;
    int h=1,t=1;cnt=1;st[cnt]=Dis;
    pd[x]=sz;c[1]=x;
    while (h<=t){
       int now=c[h++];
       for (int i=first[now];i;i=next[i]){
          int pur=go[i];
          if (pd[pur]==sz||vis[pur]) continue;
          dis[pur]=dis[now]+val[i];
          st[++cnt]=dis[pur];
          pd[pur]=sz;
          c[++t]=pur;
       }
    }
    int j=cnt,res=0,Cnt=0;
    std::sort(st+1,st+1+cnt);
    for (int i=1;i<=t;i++){
       while (j>1&&st[i]+st[j]>K) j--;
       if (st[i]+st[j]<=K) res+=j;
       if (st[i]+st[i]<=K) Cnt++;
    }
    res-=Cnt;
    return res/2;
}
// Processes the component whose centroid is x, then recurses into the
// pieces left after removing x.
// Pairs through x are counted by bfs(). For each neighbour subtree, the
// pairs lying entirely inside it are subtracted (del()), which is the
// inclusion-exclusion step. Each piece's size is derived from son[] as
// computed by the findroot() that selected x.
void solve(int x){
    vis[x]=1;
    ++sz;
    cnt=1;
    st[cnt]=0;
    bfs(x);
    for (int e=first[x];e;e=next[e]){
        int v=go[e];
        if (!vis[v]) ans-=del(v,val[e]);
    }
    const int whole=sum;
    for (int e=first[x];e;e=next[e]){
        int v=go[e];
        if (vis[v]) continue;
        // A neighbour with a larger son[] was x's parent in that DFS, so its
        // piece is everything in the component outside x's subtree.
        sum=(son[v]>son[x])?whole-son[x]:son[v];
        root=0;
        findroot(v,x);
        solve(root);
    }
}
// Input: "n m", then m roads "a b length direction" (the direction letter
// is read and ignored), then K.
// Output: the number of unordered vertex pairs whose tree distance is <= K.
int main(){
    int m;
    char s[20];
    if (scanf("%d%d",&n,&m)!=2) return 1;
    // Read exactly the m roads announced in the header (instead of assuming
    // n-1) so that K is taken from the correct position in the input.
    for (int i=0;i<m;i++){
        int x,y,z;
        if (scanf("%d%d%d%19s",&x,&y,&z,s)!=4) return 1;
        add(x,y,z);
    }
    if (scanf("%d",&K)!=1) return 1;
    F[0]=0x7fffffff;   // sentinel: any real vertex beats "root = 0"
    root=0;
    sum=n;
    findroot(1,0);
    solve(root);
    printf("%d\n",ans);
    return 0;
}

 

posted @ 2016-07-14 19:16  GFY  阅读(267)  评论(0)  编辑  收藏  举报