Tree(POJ-1741)

题目描述:

Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.

输入

The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.

输出

For each test case output the answer on a single line.

思路:

比较普遍的一道点分治题,考虑每一棵树,以重心为根,预处理出每个点的深度,再把每个点扔到一个数组中进行线性计算,算出满足条件的所有点对,方法可以是将其排序,用两个指针从两边往中间推着计算。

不过这时候会有小问题,会多算一种情况,就是他们的LCA不是重心的情况,这时候就需要采用容斥原理的思想,在每个重心的子树中计算一遍上述的操作(注意加上重心到根节点的距离),再在答案中对应地减去,便能得到最终答案!

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
bool mem1;
const int N=100005;
struct Graph{
	int tot,to[N<<1],nxt[N<<1],len[N<<1],head[N];
	void add(int x,int y,int z){tot++;to[tot]=y;nxt[tot]=head[x];len[tot]=z;head[x]=tot;}
	void clear(){tot=0;memset(head,-1,sizeof(head));}
}G;
bool vis[N];
int ans,sz[N],mx[N],t_sz,center;
int arr[N],dep[N];
int n,k;
bool mem2;
void make_dep(int x,int f){
	arr[++arr[0]]=dep[x];
	for(int i=G.head[x];i!=-1;i=G.nxt[i]){
		int v=G.to[i];
		if(v==f||vis[v])continue;
		dep[v]=dep[x]+G.len[i];
		make_dep(v,x);
	}
}
void get_center(int x,int f){
	sz[x]=1,mx[x]=0;
	for(int i=G.head[x];i!=-1;i=G.nxt[i]){
		int v=G.to[i];
		if(v==f||vis[v])continue;
		get_center(v,x);
		sz[x]+=sz[v];
		mx[x]=max(mx[x],sz[v]);
	}
	mx[x]=max(mx[x],t_sz-sz[x]);
	if(!center||mx[x]<mx[center])center=x;
}
int calc(int x,int dis){
	dep[x]=dis,arr[0]=0;
	make_dep(x,0);
	sort(arr+1,arr+arr[0]+1);
	int j=arr[0],ret=0;
	for(int i=1;i<=arr[0];i++){
		while(j>i&&arr[i]+arr[j]>k)j--;
		ret+=max(0,j-i);
	}
	return ret;
}
void solve(int x){
	vis[x]=1;
	ans+=calc(x,0);
	for(int i=G.head[x];i!=-1;i=G.nxt[i]){
		int v=G.to[i];
		if(vis[v])continue;
		ans-=calc(v,G.len[i]);
		center=0,t_sz=sz[v];
		get_center(v,x);
		solve(center);
	}
}
int main(){
	while(scanf("%d%d",&n,&k)==2){
		if(!n&&!k)break;
		G.clear();
		memset(vis,0,sizeof vis);
		for(int i=1;i<n;i++){
			int x,y,z;
			scanf("%d%d%d",&x,&y,&z);
			G.add(x,y,z),G.add(y,x,z);
		}
		center=0,t_sz=n,ans=0;
		get_center(1,0);
		solve(center);
		printf("%d\n",ans);
	}
	return 0;
}
posted @ 2019-03-01 21:38  Hëinz  阅读(215)  评论(0)    收藏  举报