洛谷 P5901 [IOI2009] regions

洛谷传送门

经典根号分治题。

思路

显然有两种暴力:

  1. 对于每个地区为 \(r_1\) 的结点,计算它的子树内有多少个地区为 \(r_2\) 的结点。

  2. 对于每个地区为 \(r_2\) 的结点,计算它到祖先的链上有多少个地区为 \(r_1\) 的结点。

\(cnt_i\) 为第 \(i\) 个地区的数量。若 \(cnt_{r_2} \ge \sqrt{n}\),意味着这样的 \(r_2\) 数量 \(\le \sqrt{n}\)。因此可以将所有地区为 \(r_1\) 的点直接挂上询问然后跑暴力 \(1\),易知这样的询问次数不超过 \(n \sqrt{n}\)。计算子树内权值出现的数量可以差分做。

\(cnt_{r_1} < \sqrt{n}\),那么直接跑暴力 \(2\),dfs 时维护当前结点的所有祖先的地区出现次数即可。

时间复杂度 \(O(n \sqrt{n})\)

代码

code
/*

p_b_p_b txdy
AThousandSuns txdy
Wu_Ren txdy
Appleblue17 txdy

*/

#include <bits/stdc++.h>
#define pb push_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;

const int maxn = 200100;
const int maxm = 25010;

int n, m, q, a[maxn], cnt[maxn], ans[maxn], fa[maxn];
int head[maxn], len, ccnt[maxn];
vector<pii> vc1[maxm], vc2[maxm];
vector<int> col[maxm];

struct edge {
	int to, next;
} edges[maxn << 1];

int find(int x) {
	return fa[x] == x ? x : fa[x] = find(fa[x]);
}

void merge(int x, int y) {
	x = find(x);
	y = find(y);
	if (x != y) {
		fa[x] = y;
	}
}

void add_edge(int u, int v) {
	edges[++len].to = v;
	edges[len].next = head[u];
	head[u] = len;
}

void dfs(int u, int f) {
	for (pii p : vc1[a[u]]) {
		ans[p.scd] += ccnt[p.fst];
	}
	for (pii p : vc2[a[u]]) {
		ans[p.scd] -= cnt[p.fst];
	}
	++ccnt[a[u]];
	for (int i = head[u]; i; i = edges[i].next) {
		int v = edges[i].to;
		if (v == f) {
			continue;
		}
		dfs(v, u);
	}
	--ccnt[a[u]];
	for (pii p : vc2[a[u]]) {
		ans[p.scd] += cnt[p.fst];
	}
	++cnt[a[u]];
}

void solve() {
	scanf("%d%d%d%d", &n, &m, &q, &a[1]);
	int B = sqrt(n);
	for (int i = 2, p; i <= n; ++i) {
		scanf("%d%d", &p, &a[i]);
		add_edge(p, i);
	}
	for (int i = 1; i <= n; ++i) {
		++cnt[a[i]];
		col[a[i]].pb(i);
	}
	for (int i = 1; i <= q; ++i) {
		fa[i] = i;
	}
	map<pii, int> mp;
	for (int i = 1; i <= q; ++i) {
		int x, y;
		scanf("%d%d", &x, &y);
		if (mp.find(make_pair(x, y)) != mp.end()) {
			merge(i, mp[make_pair(x, y)]);
			continue;
		}
		if (cnt[y] <= B) {
			vc1[y].pb(make_pair(x, i));
		} else {
			vc2[x].pb(make_pair(y, i));
		}
	}
	memset(cnt, 0, sizeof(cnt));
	dfs(1, -1);
	for (int i = 1; i <= q; ++i) {
		printf("%d\n", ans[find(i)]);
	}
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}
posted @ 2022-07-19 14:11  zltzlt  阅读(62)  评论(0)    收藏  举报