# [BZOJ3697]采药人的路径

Problem: [BZOJ3697] 采药人的路径 ("The Herb Gatherer's Path")

Sample input:
7
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1

Sample output:
1

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;

// Reads one (optionally negative) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. Returns 0 if EOF is reached before any
// digit (instead of looping forever waiting for one).
int read() {
	int x = 0, f = 1;
	// Kept as int (not char) so EOF is distinguishable and isdigit() always
	// receives a value representable as unsigned char or EOF.
	int c = getchar();
	while(!isdigit(c)) {
		if(c == EOF) return 0;
		if(c == '-') f = -1;
		c = getchar();
	}
	while(isdigit(c)) {
		x = x * 10 + (c - '0');
		c = getchar();
	}
	return x * f;
}

// Capacity limits: maxn bounds vertex ids, maxm bounds directed edge slots
// (AddEdge stores every undirected edge twice).
const int maxn = 100010;
const int maxm = 200010;
typedef long long LL;

// n: number of vertices; m: number of edge slots used so far (1-based, 0 = none).
int n, m;
// Adjacency lists as linked edge arrays: head[v] is v's first slot,
// next[e] the following slot, to[e] the target vertex, dist[e] the edge weight.
int head[maxn], next[maxm], to[maxm], dist[maxm];
// Running answer (count of qualifying paths).
LL ans;

void AddEdge(int a, int b, int c) {
to[++m] = b; dist[m] = c; next[m] = head[a]; head[a] = m;
swap(a, b);
to[++m] = b; dist[m] = c; next[m] = head[a]; head[a] = m;
return ;
}

bool vis[maxn];
int root, size, f[maxn], siz[maxn];
void getroot(int u, int fa) {
siz[u] = 1; f[u] = 0;
for(int e = head[u]; e; e = next[e]) if(to[e] != fa && !vis[to[e]]) {
getroot(to[e], u);
siz[u] += siz[to[e]];
f[u] = max(f[u], siz[to[e]]);
}
f[u] = max(f[u], size - siz[u]);
if(f[root] > f[u]) root = u;
return ;
}
int has[maxn<<1], A[2][maxn<<1], B[2][maxn<<1], mxd, mnd;
void dfs(int u, int fa, int d) {
//	printf("(d)%d(%d) ", d, has[d+n]);
mxd = max(mxd, d); mnd = min(mnd, d);
A[has[d+n]?1:0][d+n]++;
has[d+n]++;
for(int e = head[u]; e; e = next[e]) if(to[e] != fa && !vis[to[e]])
dfs(to[e], u, d + dist[e]);
has[d+n]--;
return ;
}
void solve(int u) {
//	printf("u: %d\n", u);
vis[u] = 1;
bool fir = 1;
int Mxd = -n - 1, Mnd = n + 1;
for(int e = head[u]; e; e = next[e]) if(!vis[to[e]]) {
mxd = -n - 1; mnd = n + 1;
dfs(to[e], u, dist[e]);
Mxd = max(Mxd, mxd); Mnd = min(Mnd, mnd);
if(fir) ;
else {
ans += (LL)A[0][n] * B[0][n];
//			printf("%d ", A[0][n] * B[0][n]);
for(int i = n + mnd; i <= n + mxd; i++) {
int d = i - n;
ans += (LL)A[0][i] * B[1][n-d] + A[1][i] * B[0][n-d] + A[1][i] * B[1][n-d];
//				printf("(%d)%d ", d, A[0][i] * B[1][n-d] + A[1][i] * B[0][n-d] + A[1][i] * B[1][n-d]);
}
}
ans += (LL)A[1][n];
fir = 0;
for(int i = n + mnd; i <= n + mxd; i++)
B[0][i] += A[0][i], B[1][i] += A[1][i], A[0][i] = A[1][i] = 0, has[i] = 0;
//		putchar('\n');
}
for(int i = n + Mnd; i <= n + Mxd; i++) B[0][i] = B[1][i] = 0;
for(int e = head[u]; e; e = next[e]) if(!vis[to[e]]) {
root = 0; f[0] = n + 1; size = siz[u]; getroot(to[e], u);
solve(root);
}
return ;
}

int main() {
for(int i = 1; i < n; i++) {
}

root = 0; f[0] = n + 1; size = n; getroot(1, 0);
solve(root);

printf("%lld\n", ans);

return 0;
}


posted @ 2016-09-18 22:15  xjr01  阅读(82)  评论(0)  编辑  收藏