【poj1741】Tree 树的点分治

Given a tree with n vertices, each edge has a length (a positive integer less than 1001).
Define dist(u,v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.
Write a program that counts how many valid pairs there are in a given tree.

The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contain three integers u, v, l, which means there is an edge between nodes u and v of length l.
The last test case is followed by two zeros.

For each test case output the answer on a single line.

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

8

#include <cstdio>
#include <cstring>
#include <algorithm>
// Capacity of every per-vertex array; vertices are indexed from 1, and index 0
// is used as a "none" sentinel by the code below.
#define N 10010
using namespace std;

// Distance limit read in main(); calc() counts pairs whose summed depth is <= m.
int m;
// head[x]: index of the most recently added edge leaving x (0 = no edge).
int head[N];
// to[e] / len[e]: end vertex and length of edge e. Sized 2*N because main()
// stores each undirected edge as two directed ones.
int to[N << 1];
int len[N << 1];
// next[e]: the edge that follows e in the list starting at head[x] (0 = end).
// NOTE: with `using namespace std;` this name also matches std::next when the
// file is built as C++11 or later; `::next` selects this array unambiguously.
int next[N << 1];
// Number of edge slots in use; add() pre-increments it, so slot 0 stays free.
int cnt;
// si[x]: subtree size of x as computed by the latest getroot() that visited x.
int si[N];
// deep[x]: distance of x from the vertex the current getdeep() walk started at
// (plus whatever offset the caller stored in deep[] of that start vertex).
int deep[N];
// Best split vertex found so far by getroot(); 0 means "none yet".
int root;
// vis[x] != 0 once dfs() has processed x; such vertices block all traversals.
int vis[N];
// f[x]: size of the largest piece left if x were removed from its component.
// main() sets f[0] to INT_MAX so that any real vertex beats the sentinel root 0.
int f[N];
// Size of the component getroot() is currently examining.
int sn;
// d[1..tot]: depths collected by getdeep(); calc() sorts this range.
int d[N];
int tot;
// Running answer for the current test case.
int ans;
void add(int x , int y , int z)
{
to[++cnt] = y , len[cnt] = z , next[cnt] = head[x] , head[x] = cnt;
}
void getroot(int x , int fa)
{
f[x] = 0 , si[x] = 1;
int i;
for(i = head[x] ; i ; i = next[i])
if(to[i] != fa && !vis[to[i]])
getroot(to[i] , x) , si[x] += si[to[i]] , f[x] = max(f[x] , si[to[i]]);
f[x] = max(f[x] , sn - si[x]);
if(f[root] > f[x]) root = x;
}
void getdeep(int x , int fa)
{
d[++tot] = deep[x];
int i;
for(i = head[x] ; i ; i = next[i])
if(to[i] != fa && !vis[to[i]])
deep[to[i]] = deep[x] + len[i] , getdeep(to[i] , x);
}
// Returns the number of unordered pairs among the vertices reachable from x
// whose depths sum to at most m.
//   x : start vertex; the caller must have set deep[x] beforehand
// Side effects: rebuilds d[1..tot] through getdeep() and leaves it sorted.
int calc(int x)
{
    tot = 0;
    getdeep(x, 0);
    std::sort(d + 1, d + tot + 1);

    int pairs = 0;
    int lo = 1;
    int hi = tot;
    while (lo < hi)
    {
        if (d[lo] + d[hi] > m)
        {
            // Even the smallest remaining depth is too large for d[hi].
            --hi;
            continue;
        }
        // d[lo] fits with d[hi], hence with every entry in (lo, hi].
        pairs += hi - lo;
        ++lo;
    }
    return pairs;
}
void dfs(int x)
{
deep[x] = 0 , vis[x] = 1 , ans += calc(x);
int i;
for(i = head[x] ; i ; i = next[i])
if(!vis[to[i]])
deep[to[i]] = len[i] , ans -= calc(to[i]) , sn = si[to[i]] , root = 0 , getroot(to[i] , 0) , dfs(root);
}
int main()
{
int n , i , x , y , z;
while(scanf("%d%d" , &n , &m) && (n || m))
{
memset(vis , 0 , sizeof(vis));
cnt = 0 , ans = 0;
for(i = 1 ; i < n ; i ++ )
scanf("%d%d%d" , &x , &y , &z) , add(x , y , z) , add(y , x , z);
f[0] = 0x7fffffff , sn = n;
root = 0 , getroot(1 , 0) , dfs(root);
printf("%d\n" , ans);
}
return 0;
}

posted @ 2017-03-29 20:56  GXZlegend  阅读(...)  评论(...编辑  收藏