BZOJ 4719: [Noip2016]天天爱跑步

题解咕掉了

CODE

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
// Capacity of every per-node and per-query array below; main() indexes them from 1.
const int MAXN = 300005;
// Number of ancestor levels kept per node in f[][] (2^19 = 524288 > MAXN).
const int LOG = 19;
// n, m   : the two counts read first in main() (n nodes, m start/end pairs).
// ans[v] : result printed for node v at the end of main().
// w[v]   : per-node value read from input; getans() combines it with dep[v]
//          to choose which counter slots node v watches.
// fir/to/nxt/cnt : adjacency list stored as a linked pool of arcs. fir[u] is the
//          most recently added arc leaving u, to[e] is the node arc e points to,
//          nxt[e] is the next arc leaving the same node (0 terminates the chain),
//          and cnt is the number of arcs allocated so far.
int n, m, ans[MAXN], w[MAXN], fir[MAXN], to[MAXN<<1], nxt[MAXN<<1], cnt;
// Inserts the undirected edge u-v into the adjacency list as two directed arcs.
// Each arc takes the next free slot of the pool and becomes the new head of its
// tail node's chain.
inline void link(int u, int v) {
	// arc u -> v
	++cnt;
	to[cnt] = v;
	nxt[cnt] = fir[u];
	fir[u] = cnt;

	// arc v -> u
	++cnt;
	to[cnt] = u;
	nxt[cnt] = fir[v];
	fir[v] = cnt;
}
// f[v][k]    : ancestor table. dfs() stores the parent of v in f[v][0]; main() then
//              fills f[v][k] = f[f[v][k-1]][k-1], i.e. the 2^k-th ancestor of v,
//              or 0 once the walk leaves the tree.
// dep[v]     : depth of v as computed by dfs() (the root gets depth 1; index 0 keeps depth 0).
// S[i], T[i] : the two endpoints read for pair i.
// LCA[i]     : lca(S[i], T[i]).
// Len[i]     : dep[S[i]] + dep[T[i]] - 2*dep[LCA[i]], the number of edges on the S[i]-T[i] path.
int f[MAXN][LOG], dep[MAXN], S[MAXN], T[MAXN], LCA[MAXN], Len[MAXN];
// Depth-first walk from u, whose parent is ff (0 for the root).
// Records the parent in f[u][0] and the depth in dep[u] = dep[ff] + 1.
// Recursion depth equals the height of the tree.
void dfs(int u, int ff) {
	f[u][0] = ff;
	dep[u] = dep[ff] + 1;
	for (int e = fir[u]; e != 0; e = nxt[e]) {
		const int child = to[e];
		if (child == ff) continue;  // do not walk back up the edge we came from
		dfs(child, u);
	}
}
// Returns the lowest common ancestor of x and y using the ancestor table f[][]
// and the depths in dep[]; both must already be filled in.
inline int lca(int x, int y) {
	// Keep y as the deeper (or equally deep) node.
	if (dep[x] > dep[y]) swap(x, y);

	// Raise y to the depth of x, trying the largest jumps first.
	for (int k = LOG - 1; k >= 0; --k) {
		const int up = f[y][k];
		if (dep[up] >= dep[x]) y = up;
	}
	if (x == y) return x;

	// Raise both nodes together while their ancestors still differ;
	// they end up as two distinct children of the answer.
	for (int k = LOG - 1; k >= 0; --k) {
		const int ax = f[x][k];
		const int ay = f[y][k];
		if (ax != ay) {
			x = ax;
			y = ay;
		}
	}
	return f[x][0];
}
// t[0][k] / t[1][k] : running counters maintained by getans(). Row 0 is indexed by
//                     dep[S[i]]; row 1 by Len[i] - dep[T[i]] + MAXN (the MAXN offset
//                     keeps that index non-negative).
int t[2][MAXN<<1];
// vec[u] : events attached to node u by main(): (0, key) at S[i], (1, key) at T[i],
//          and (-1, i) at LCA[i].
vector<pair<int,int> >vec[MAXN];
// Post-order walk that adds to ans[u] the number of events inside u's subtree
// that landed in the two counter slots u watches:
//   t[0][w[u] + dep[u]]          and
//   t[1][w[u] - dep[u] + MAXN].
// The slot values are sampled before descending and compared afterwards, so only
// increments made while processing u's subtree are counted. Pairs whose LCA is u
// are withdrawn from both counter rows before returning, so they are not seen by
// u's ancestors.
void getans(int u, int ff) {
	const int upSlot = w[u] + dep[u];
	const int downSlot = w[u] - dep[u] + MAXN;
	const int upBefore = t[0][upSlot];
	const int downBefore = t[1][downSlot];

	for (int e = fir[u]; e != 0; e = nxt[e]) {
		const int child = to[e];
		if (child != ff) getans(child, u);
	}

	// Apply the start/end events stored at u (kind 0 or 1); kind -1 is handled below.
	for (int i = static_cast<int>(vec[u].size()) - 1; i >= 0; --i) {
		const int kind = vec[u][i].first;
		if (kind != -1) {
			++t[kind][vec[u][i].second];
		}
	}

	ans[u] += t[0][upSlot] - upBefore;
	ans[u] += t[1][downSlot] - downBefore;

	// Retire every pair whose LCA is u: undo its start and end increments.
	for (int i = static_cast<int>(vec[u].size()) - 1; i >= 0; --i) {
		if (vec[u][i].first != -1) continue;
		const int id = vec[u][i].second;
		--t[0][dep[S[id]]];
		--t[1][Len[id] - dep[T[id]] + MAXN];
	}
}
// Reads the tree (n-1 edges), the per-node values w[], and m start/end pairs;
// prepares depths and the ancestor table; registers three events per pair;
// runs getans() from node 1 and prints ans[1..n] on one line.
int main () {
	scanf("%d%d", &n, &m);

	for (int e = 1; e < n; ++e) {
		int x, y;
		scanf("%d%d", &x, &y);
		link(x, y);
	}
	for (int v = 1; v <= n; ++v) {
		scanf("%d", &w[v]);
	}

	// Depths and parents first, then the higher levels of the ancestor table.
	dfs(1, 0);
	for (int level = 1; level < LOG; ++level) {
		for (int v = 1; v <= n; ++v) {
			f[v][level] = f[f[v][level-1]][level-1];
		}
	}

	for (int i = 1; i <= m; ++i) {
		scanf("%d%d", &S[i], &T[i]);
		const int src = S[i];
		const int dst = T[i];
		const int top = lca(src, dst);
		LCA[i] = top;

		// Both the start event and the end event of this pair can match at its
		// LCA; pre-subtract one so the LCA is not counted twice by getans().
		if (dep[src] - dep[top] == w[top]) {
			--ans[top];
		}

		Len[i] = dep[src] + dep[dst] - 2 * dep[top];

		vec[src].push_back(pair<int,int>(0, dep[src]));
		vec[dst].push_back(pair<int,int>(1, Len[i] - dep[dst] + MAXN));
		vec[top].push_back(pair<int,int>(-1, i));
	}

	getans(1, 0);

	for (int v = 1; v <= n; ++v) {
		printf("%d%c", ans[v], " \n"[v==n]);
	}
}
posted @ 2019-12-14 14:50  _Ark  阅读(109)  评论(0)  编辑  收藏  举报