【bzoj3697】采药人的路径 树的点分治

样例输入 (Sample Input):
7
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1

样例输出 (Sample Output):
1

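思路：树的点分治。把边权 1 记为 +1、边权 0 记为 -1，则合法路径是总和为 0、且路径上存在一个非端点的“休息站”（从某一端点到它的前缀和也为 0）的路径。对当前重心：先单独统计以重心为一个端点的合法路径（calc）；再 dfs 整棵子树，按前缀和 d 以及“该前缀和是否已在其上方路径中出现过”分类计数，把前缀和为 d 与 -d 的点配对求出答案；然后减去两端在同一子树内的答案；最后递归各子树。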

#include <cstdio>
#include <cstring>
#include <algorithm>
// Array capacity (presumably the problem bound n <= 100000 plus slack).
const int N = 100010;
// Import only the std helpers this program calls unqualified. A blanket
// `using namespace std;` also exposes std::next (declared via <algorithm>
// since C++11) next to the global edge array `next` below. Every unqualified
// `next[i]` then becomes an ambiguous reference and fails to compile.
using std::max;
using std::min;
// Forward-star adjacency list. head[v] is the first edge id of v. For edge e,
// to[e] is its target, val[e] its value (expected 0 or 1) and next[e] the next
// edge id of the same source. Ids start at 1, so 0 ends a list; cnt is the
// last id handed out. Two directed edges are stored per tree edge.
int head[N];
int to[N << 1];
int val[N << 1];
int next[N << 1];
int cnt;
// vis[v]: v was already used as a centroid.
// si[v]: subtree size computed by the latest getroot pass.
// ms[v]: size of the largest piece left if v were removed (ms[0] is a sentinel).
// sum: size of the component being split; root: centroid chosen by getroot.
// md: largest |prefix sum| reached by the latest dfs (bounds the clearing loops).
int vis[N];
int si[N];
int ms[N];
int sum;
int root;
int md;
// f[d][k] / g[d][k]: number of vertices whose prefix sum is +d / -d. k is 1 when
// that prefix already occurred on the path above the vertex, 0 otherwise.
// ans: running count of valid paths.
long long f[N][2];
long long g[N][2];
long long ans;
// Appends the directed edge x -> y with value z to x's adjacency list.
// The new edge takes the next id from the global counter and becomes x's
// list head; its successor is x's previous head.
inline void add(int x , int y , int z)
{
    const int e = ++cnt;
    to[e] = y;
    val[e] = z;
    next[e] = head[x];
    head[x] = e;
}
void getroot(int x , int fa)
{
int i;
si[x] = 1 , ms[x] = 0;
for(i = head[x] ; i ; i = next[i])
if(!vis[to[i]] && to[i] != fa)
getroot(to[i] , x) , si[x] += si[to[i]] , ms[x] = max(ms[x] , si[to[i]]);
ms[x] = max(ms[x] , sum - si[x]);
if(ms[x] < ms[root]) root = x;
}
// Counts valid paths that have the current centroid as one endpoint.
// now:   prefix sum from the centroid down to x, where an edge of value v
//        contributes 2*v - 1 (+1 for value 1, -1 for value 0).
// zeros: how many vertices above x on that path had prefix sum 0.
// The path centroid..x is counted when its sum is 0 and zeros >= 2, i.e. the
// centroid plus at least one interior vertex with prefix 0 (the rest stop).
void calc(int x , int fa , int now , int zeros)
{
    if(now == 0)
    {
        if(zeros >= 2) ++ans;
        ++zeros; // x now counts as a zero-prefix vertex for its descendants
    }
    for(int e = head[x] ; e != 0 ; e = next[e])
    {
        const int y = to[e];
        if(vis[y] || y == fa) continue;
        calc(y , x , now + 2 * val[e] - 1 , zeros);
    }
}
// Buckets every vertex below (and including) x by its prefix sum `now`.
// Non-negative sums go to f[now], negative ones to g[-now]. The second index
// is 1 when `now` lies in [l, r], the range of prefix sums on the path above x
// (an empty range when l > r). Because each edge changes the sum by +-1 for
// 0/1 values, that means some vertex above x has the same prefix: a possible
// rest stop. Also raises the global md to the largest |prefix| reached.
void dfs(int x , int fa , int now , int l , int r)
{
    const int seen = (l <= now && now <= r) ? 1 : 0;
    if(now < 0) ++g[-now][seen];
    else ++f[now][seen];

    // Extend the range of prefixes seen so far to include x.
    if(now < l) l = now;
    if(now > r) r = now;
    if(-l > md) md = -l;
    if(r > md) md = r;

    for(int e = head[x] ; e != 0 ; e = next[e])
    {
        const int y = to[e];
        if(!vis[y] && y != fa) dfs(y , x , now + val[e] * 2 - 1 , l , r);
    }
}
void solve(int x)
{
int i , j;
vis[x] = 1 , md = 0 , calc(x , 0 , 0 , 0);
dfs(x , 0 , 0 , 1 , -1) , ans += f[0][1] * (f[0][1] - 1) / 2 , f[0][0] = f[0][1] = 0;
for(i = 1 ; i <= md ; i ++ ) ans += f[i][0] * g[i][1] + f[i][1] * g[i][0] + f[i][1] * g[i][1] , f[i][0] = f[i][1] = g[i][0] = g[i][1] = 0;
for(i = head[x] ; i ; i = next[i])
{
if(!vis[to[i]])
{
md = 0 , dfs(to[i] , 0 , 2 * val[i] - 1 , 0 , 0) , ans -= f[0][1] * (f[0][1] - 1) / 2 , f[0][0] = f[0][1] = 0;
for(j = 0 ; j <= md ; j ++ ) ans -= f[j][0] * g[j][1] + f[j][1] * g[j][0] + f[j][1] * g[j][1] , f[j][0] = f[j][1] = g[j][0] = g[j][1] = 0;
sum = si[to[i]] , root = 0 , getroot(to[i] , 0) , solve(root);
}
}
}
// Reads n and then n - 1 lines "x y z" (tree edge x-y with value z), runs the
// centroid decomposition on the component of vertex 1 and prints the count.
int main()
{
    int n;
    scanf("%d" , &n);
    for(int k = 1 ; k < n ; ++k)
    {
        int x , y , z;
        scanf("%d%d%d" , &x , &y , &z);
        add(x , y , z);
        add(y , x , z);
    }

    sum = n;
    root = 0;
    ms[0] = n; // sentinel larger than any real vertex's ms, so getroot picks a real vertex
    getroot(1 , 0);
    solve(root);

    printf("%lld\n" , ans);
    return 0;
}


posted @ 2018-03-20 21:00  GXZlegend  阅读(303)  评论(0)  编辑  收藏