某模拟赛C题 树上路径统计 (点分治)

题意

给定一棵有n个节点的无根树,树上的每个点有一个非负整数点权。定义一条路径的价值为路径上的点权和-路径上的点权最大值。 给定参数P,我!=们想知道,有多少不同的树上简单路径,满足它的价值恰好是P的倍数。 注意:单点算作一条路径;u!=v时,(u,v)和(v,u)只算一次。

题解

树上路径统计,解法是点分治。点分的时候求出根到每个点路径最大值和权值和。排一序,然后开个桶,就能计算了。去重就套路的减去没棵子树里面的答案。

CODE

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100005;
typedef long long LL;
LL ans;
int n, mod, fir[MAXN], nxt[MAXN<<1], to[MAXN<<1], cnt, val[MAXN];
inline void link(int x, int y) {
	to[++cnt] = y; nxt[cnt] = fir[x]; fir[x] = cnt;
	to[++cnt] = x; nxt[cnt] = fir[y]; fir[y] = cnt;
}
bool ban[MAXN];
int getsz(int u, int ff) {
	int re = 1;
	for(int v, i = fir[u]; i; i = nxt[i])
		if((v=to[i]) != ff && !ban[v])
			re += getsz(v, u);
	return re;
}
int getrt(int u, int ff, int &rt, int Size) {
	int re = 1; bool can = 1;
	for(int v, tmp, i = fir[u]; i; i = nxt[i])
		if((v=to[i]) != ff && !ban[v]) {
			re += (tmp = getrt(v, u, rt, Size));
			if((tmp<<1) > Size) can = 0;
		}
	if(((Size-re)<<1) > Size) can = 0;
	if(can) rt = u;
	return re;
}
struct node {
	int mx, v;
	inline bool operator <(const node &o)const {
		return mx < o.mx;
	}
}seq[MAXN], vv[MAXN];
int tot;
void dfs(int u, int ff, int mx, int vs) {
	vs = (vs + val[u]) % mod;
	mx = max(mx, val[u]);
	vv[u] = (node){ mx, vs };
	for(int v, i = fir[u]; i; i = nxt[i])
		if((v=to[i]) != ff && !ban[v])
			dfs(v, u, mx, vs);
}
void push(int u, int ff) {
	seq[++tot] = vv[u];
	for(int v, i = fir[u]; i; i = nxt[i])
		if((v=to[i]) != ff && !ban[v])
			push(v, u);
}
int bin[10000005];
LL calc(int rt, int o) {
	tot = 0; push(rt, 0);
	sort(seq + 1, seq + tot + 1);
	LL re = 0;
	for(int i = 1; i <= tot; ++i) {
		re += bin[((seq[i].mx+o-seq[i].v)%mod+mod)%mod];
		++bin[seq[i].v%mod];
	}
	for(int i = 1; i <= tot; ++i) --bin[seq[i].v%mod];
	return re;
}
void solve(int x) {
	dfs(x, 0, 0, 0);
	ans += calc(x, val[x]);
	ban[x] = 1;
	for(int v, i = fir[x]; i; i = nxt[i])
		if(!ban[v=to[i]]) ans -= calc(v, val[x]);
}
void TDC(int x) {
	int Size = getsz(x, 0);
	getrt(x, 0, x, Size);
	solve(x);
	for(int v, i = fir[x]; i; i = nxt[i])
		if(!ban[v=to[i]]) TDC(v);
}
int main () {
	scanf("%d%d", &n, &mod);
	for(int i = 1, x, y; i < n; ++i)
		scanf("%d%d", &x, &y), link(x, y);
	for(int i = 1; i <= n; ++i) scanf("%d", &val[i]);
	TDC(1);
	printf("%lld\n", ans+n);
}
posted @ 2019-12-21 11:18  _Ark  阅读(168)  评论(0编辑  收藏  举报