「SPOJ10707」Count on a tree II

「SPOJ10707」Count on a tree II

传送门：[SPOJ COT2 - Count on a tree II](https://www.spoj.com/problems/COT2/)
树上莫队板子题：借助括号序（每个点在进入和离开时各记录一次的欧拉序），把树上路径询问转化为序列上的区间询问（若两端点不是祖先关系，还需单独计入它们的 LCA），再套用普通莫队即可。
锻炼基础，没什么好说的。
参考代码：

#include <algorithm>
#include <cstdio>
#include <cmath>
#define rg register
#define file(x) freopen(x".in", "r", stdin), freopen(x".out", "w", stdout)
using namespace std;
// Reads one decimal integer from stdin into s.
// Any non-digit characters before the number are skipped; if a '-' appears
// among them the result is negated. Reading stops at the first non-digit
// after the number. At end of input the function returns (leaving s == 0)
// instead of looping forever: c is an int so EOF is distinguishable from
// every real character, whether plain char is signed or unsigned.
template < class T > inline void read(T& s) {
	s = 0;
	int f = 0;
	int c = getchar();
	while (c != EOF && (c < '0' || c > '9')) {
		f |= c == '-';
		c = getchar();
	}
	while (c >= '0' && c <= '9') {
		s = s * 10 + (c - '0');
		c = getchar();
	}
	if (f) s = -s;
}

// Capacity of the per-vertex arrays (tree size bound plus slack).
const int _ = 40000 + 5;
// Capacity of the per-query arrays (query count bound plus slack).
// NOTE(review): `_` at global scope and any name containing "__" are reserved
// for the implementation in C++; renaming them means touching every user.
const int __ = 100000 + 5;

// Chained ("forward-star") adjacency list.
struct Edge {
	int ver; // target vertex of this directed edge
	int nxt; // index of the next edge leaving the same source (0 ends the list)
};
int tot;           // number of directed edges stored; edges live at indices 1..tot
int head[_];       // head[u] = index of the most recently added edge leaving u (0 if none)
Edge edge[_ << 1]; // each undirected tree edge is stored once per direction

// Prepends the directed edge u -> v to u's adjacency list.
// Uses standard aggregate initialization instead of the original
// `(Edge) { ... }` C99 compound literal, which C++ accepts only as an extension.
inline void Add_edge(int u, int v) {
	edge[++tot] = Edge{ v, head[u] };
	head[u] = tot;
}

// Input: vertex count, query count, vertex colours (compressed to 1..X0 in
// main), and the sorted, deduplicated colour table used for that compression.
int n, q;
int a[_];
int X0;
int X[_];

// Tree data filled by dfs(): entry/exit positions in the Euler sequence,
// per-vertex "currently inside the window" flags toggled by calc(), depths,
// and the binary-lifting table fa[k][u] (the 2^k-th ancestor of u).
int fir[_], las[_];
int vis[_];
int dep[_];
int fa[17][_];

// Euler sequence (every vertex appears twice) and the Mo block of each position.
int len;
int ord[_ << 1];
int m;
int pos[_ << 1];

// Mo state: distinct colours in the current window, per-colour occurrence
// counts, and answers stored by original query index.
int ans;
int cnt[_];
int res[__];

// One offline query mapped onto the Euler sequence.
struct node {
	int l, r; // inclusive range of Euler-sequence positions
	int lca;  // vertex to count on top of the range (0 when none is needed)
	int id;   // original input index of the query
};
node t[__];
// Mo's query order: sort by the block of the left endpoint. Inside a block,
// the right endpoint ascends for odd blocks and descends for even blocks, so
// the right pointer sweeps back and forth instead of resetting each block.
inline bool cmp(const node& x, const node& y) {
	const int bx = pos[x.l];
	const int by = pos[y.l];
	if (bx != by) return bx < by;
	if (bx & 1) return x.r < y.r;
	return x.r > y.r;
}

inline void dfs(int u, int f) {
	fir[u] = ++len, ord[len] = u;
	dep[u] = dep[f] + 1, fa[0][u] = f;
	for (rg int i = 1; i <= 16; ++i) fa[i][u] = fa[i - 1][fa[i - 1][u]];
	for (rg int i = head[u]; i; i = edge[i].nxt) if (edge[i].ver != f) dfs(edge[i].ver, u);
	las[u] = ++len, ord[len] = u;
}

// Lowest common ancestor of x and y by binary lifting.
// Requires dfs(1, 0) to have filled dep[] and fa[][].
// (The `register` hint was dropped: it is ill-formed since C++17.)
inline int LCA(int x, int y) {
	if (dep[x] < dep[y]) swap(x, y);
	// Lift the deeper vertex to y's depth. A jump above the root lands on
	// vertex 0 (depth 0), which fails the test, so it is never taken.
	for (int i = 16; i >= 0; --i)
		if (dep[fa[i][x]] >= dep[y]) x = fa[i][x];
	if (x == y) return x;
	// Lift both while their ancestors differ; they stop just below the LCA.
	for (int i = 16; i >= 0; --i)
		if (fa[i][x] != fa[i][y]) x = fa[i][x], y = fa[i][y];
	return fa[0][x];
}

// Toggles vertex x in or out of the current Mo window and keeps `ans`
// (the number of distinct colours in the window) up to date. A vertex seen
// twice in the Euler range is toggled out again, so it does not count.
inline void calc(int x) {
	int& occurrences = cnt[a[x]];
	if (vis[x]) {
		// Removing x: the colour disappears when its last occurrence goes.
		--occurrences;
		if (occurrences == 0) --ans;
	} else {
		// Adding x: the colour is new if it had no occurrences yet.
		if (occurrences == 0) ++ans;
		++occurrences;
	}
	vis[x] ^= 1;
}

int main() {
	read(n), read(q);
	for (rg int i = 1; i <= n; ++i) read(a[i]), X[i] = a[i];
	sort(X + 1, X + n + 1);
	X0 = unique(X + 1, X + n + 1) - X - 1;
	for (rg int i = 1; i <= n; ++i) a[i] = lower_bound(X + 1, X + X0 + 1, a[i]) - X;
	for (rg int x, y, i = 1; i < n; ++i) read(x), read(y), Add_edge(x, y), Add_edge(y, x);
	dfs(1, 0);
	for (rg int x, y, lca, i = 1; i <= q; ++i) {
		read(x), read(y), lca = LCA(x, y);
		if (fir[x] > fir[y]) swap(x, y);
		if (x == lca)
			t[i].l = fir[x], t[i].r = fir[y], t[i].lca = 0;
		else 
			t[i].l = las[x], t[i].r = fir[y], t[i].lca = lca;
		t[i].id = i;
	}
	m = sqrt(1.0 * len);
	for (rg int i = 1; i <= len; ++i) pos[i] = (i - 1) / m + 1;
	sort(t + 1, t + q + 1, cmp);
	for (rg int l = 1, r = 0, i = 1; i <= q; ++i) {
		while (l > t[i].l) calc(ord[--l]);
		while (r < t[i].r) calc(ord[++r]);
		while (l < t[i].l) calc(ord[l++]);
		while (r > t[i].r) calc(ord[r--]);
		if (t[i].lca) calc(t[i].lca);
		res[t[i].id] = ans;
		if (t[i].lca) calc(t[i].lca);
	}
	for (rg int i = 1; i <= q; ++i) printf("%d\n", res[i]);
	return 0;
}
posted @ 2020-01-23 23:32  Sangber  阅读(111)  评论(0)  编辑  收藏  举报