【bzoj3697】采药人的路径 树的点分治

题目描述

给出一棵 $n$ 个点的树,每条边的边权为1或0。求有多少点对 $(i,j)$ ,使得:$i$ 到 $j$ 的简单路径上存在点 $k$ (异于 $i$ 和 $j$ ),使得 $i$ 到 $k$ 的简单路径上0和1数目相等,$j$ 到 $k$ 的简单路径上0和1数目也相等。

输入

第1行包含一个整数N。
接下来N-1行,每行包含三个整数a_i、b_i和t_i,表示这条路上药材的类型。

输出

输出符合采药人要求的路径数目。

样例输入

7
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1

样例输出

1


题解

树的点分治

求满足条件的路径数目,可以考虑点分治,每次求过根节点的方案数,再递归处理子树。

设 $f[x][0/1]$ 表示 $1$ 的数目比 $0$ 的数目多 $x$ ,且 否/是 有满足条件的 $k$ 节点的点数。
设 $g[x][0/1]$ 表示 $0$ 的数目比 $1$ 的数目多 $x$ ,且 否/是 有满足条件的 $k$ 节点的点数。

然后相应的答案贡献就是 $f[x][0]·g[x][1]+f[x][1]·g[x][0]+f[x][1]·g[x][1]$ 。

注意特殊处理数目相等的情况,以及分治中心作为路径端点的情况。

dfs整棵子树,求出答案,减去两端在同一子树内的答案,递归子树。

时间复杂度 $O(n\log n)$ 

#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 100010
using namespace std;
int head[N] , to[N << 1] , val[N << 1] , next[N << 1] , cnt , vis[N] , si[N] , ms[N] , sum , root , md;
long long f[N][2] , g[N][2] , ans;
inline void add(int x , int y , int z)
{
	to[++cnt] = y , val[cnt] = z , next[cnt] = head[x] , head[x] = cnt;
}
void getroot(int x , int fa)
{
	int i;
	si[x] = 1 , ms[x] = 0;
	for(i = head[x] ; i ; i = next[i])
		if(!vis[to[i]] && to[i] != fa)
			getroot(to[i] , x) , si[x] += si[to[i]] , ms[x] = max(ms[x] , si[to[i]]);
	ms[x] = max(ms[x] , sum - si[x]);
	if(ms[x] < ms[root]) root = x;
}
void calc(int x , int fa , int now , int cnt)
{
	int i;
	if(now == 0)
	{
		if(cnt >= 2) ans ++ ;
		cnt ++ ;
	}
	for(i = head[x] ; i ; i = next[i])
		if(!vis[to[i]] && to[i] != fa)
			calc(to[i] , x , now + 2 * val[i] - 1 , cnt);
}
void dfs(int x , int fa , int now , int l , int r)
{
	int i;
	if(now >= l && now <= r)
	{
		if(now >= 0) f[now][1] ++ ;
		else g[-now][1] ++ ;
	}
	else
	{
		if(now >= 0) f[now][0] ++ ;
		else g[-now][0] ++ ;
	}
	l = min(l , now) , r = max(r , now) , md = max(md , max(-l , r));
	for(i = head[x] ; i ; i = next[i])
		if(!vis[to[i]] && to[i] != fa)
			dfs(to[i] , x , now + val[i] * 2 - 1 , l , r);
}
void solve(int x)
{
	int i , j;
	vis[x] = 1 , md = 0 , calc(x , 0 , 0 , 0);
	dfs(x , 0 , 0 , 1 , -1) , ans += f[0][1] * (f[0][1] - 1) / 2 , f[0][0] = f[0][1] = 0;
	for(i = 1 ; i <= md ; i ++ ) ans += f[i][0] * g[i][1] + f[i][1] * g[i][0] + f[i][1] * g[i][1] , f[i][0] = f[i][1] = g[i][0] = g[i][1] = 0;
	for(i = head[x] ; i ; i = next[i])
	{
		if(!vis[to[i]])
		{
			md = 0 , dfs(to[i] , 0 , 2 * val[i] - 1 , 0 , 0) , ans -= f[0][1] * (f[0][1] - 1) / 2 , f[0][0] = f[0][1] = 0;
			for(j = 0 ; j <= md ; j ++ ) ans -= f[j][0] * g[j][1] + f[j][1] * g[j][0] + f[j][1] * g[j][1] , f[j][0] = f[j][1] = g[j][0] = g[j][1] = 0;
			sum = si[to[i]] , root = 0 , getroot(to[i] , 0) , solve(root);
		}
	}
}
int main()
{
	int n , i , x , y , z;
	scanf("%d" , &n);
	for(i = 1 ; i < n ; i ++ ) scanf("%d%d%d" , &x , &y , &z) , add(x , y , z) , add(y , x , z);
	sum = n , root = 0 , ms[0] = n , getroot(1 , 0) , solve(root);
	printf("%lld\n" , ans);
	return 0;
}

 

 

posted @ 2018-03-20 21:00  GXZlegend  阅读(303)  评论(0编辑  收藏