SPOJ COT2 Count on a tree 2

树上的链询问, 可以在括号序上跑莫队。

括号序就是这样一个队列 Q:对整棵树从根开始 dfs, 进入一个点就 Q.push_back(+x), 退出一个点就 Q.push_back(-x)。

对于一个询问 (u,v), 假设 +u < +v。

若 u 是 v 的祖先, 考虑在 Q 上用 [+u,+v] 表示这段路, 则之前 dfs 的时候从 u 开始访问 v 所在的子树之前若访问了其它子树 x, 则 +u 和 +v 之间必然有 +x 和 -x。

反之, 考虑在 Q 上用 [-u,+v] 表示这段路, 少了 lca(u,v), 要补上。

#include <bits/stdc++.h>
using namespace std;
const int N = 4e5 + 23, M = 4e5 + 23;
int n, m, B, a[N], b[N];
int ecnt, hd[N], nt[N*2+1], vr[N*2+1];
void ad (int u, int v) { nt[++ ecnt] = hd[u], hd[u] = ecnt, vr[ecnt] = v; }

int dep[N], siz[N], fa[N], son[N], tp[N], dfntot, in[N], out[N], Q[N * 2 + 23];
void dfs1 (int x, int F, int D) {
	dep[x] = D, siz[x] = 1, fa[x] = F;
	for (int i = hd[x]; i; i = nt[i]) {
		int y = vr[i];
		if (y == F) continue;
		dfs1 (y, x, D + 1);
		siz[x] += siz[y];
		if (siz[y] > siz[son[x]]) son[x] = y;
	}
}
void dfs2 (int x, int T) {
	in[x] = ++ dfntot; Q[dfntot] = x;
	tp[x] = T;
	if (son[x]) dfs2 (son[x], T);
	for (int i = hd[x]; i; i = nt[i]) {
		int y = vr[i];
		if (y == fa[x] || y == son[x]) continue;
		dfs2 (y, y);
	}
	out[x] = ++ dfntot; Q[dfntot] = x;
}
int lca (int x, int y) {
	while (tp[x] ^ tp[y]) dep[tp[x]] > dep[tp[y]] ? x = fa[tp[x]] : y = fa[tp[y]];
	return dep[x] > dep[y] ? y : x;
}

struct ask {	int id, l, r, els, belong; } q[M];
bool cmp (ask x, ask y) { return x.belong ^ y.belong ? x.l < y.l : x.r < y.r; }

int vis[N], ans[N], ccnt[N], tot;
void upd (int x) {
	vis[x] ^= 1;
	if (!vis[x] && (-- ccnt[a[x]]) == 0) -- tot;
	if (vis[x] && (ccnt[a[x]] ++) == 0) ++ tot; 
}

int main()
{
	scanf ("%d%d", &n, &m); B = sqrt (m);
	for (int i = 1; i <= n; ++ i) scanf ("%d", & b[i]), a[i] = b[i];
	sort (b + 1, b + 1 + n);
	int nanachi = unique (b + 1, b + 1 + n) - b - 1;
	for (int i = 1; i <= n; ++ i) a[i] = lower_bound (b + 1, b + nanachi + 1, a[i]) - b;
	for (int i = 1, x, y; i < n; ++ i) scanf ("%d%d", & x, & y), ad (x, y), ad (y, x);
	dfs1 (1, 0, 1), dfs2 (1, 1);
	for (int i = 1; i <= m; ++ i) {
		int x, y, A; scanf ("%d%d", & x, & y); A = lca (x, y);
		q[i].id = i;
		if (in[y] < in[x]) swap (x, y);
		if (x == A || y == A)
			q[i].l = in[x], q[i].r = in[y], q[i].els = 0;
		else
			q[i].l = out[x], q[i].r = in[y], q[i].els = A;
		q[i].belong = q[i].l / B;
	}
	sort (q + 1, q + m + 1, cmp);
	for (int i = 1, L = 1, R = 0; i <= m; ++ i) {
		while (R < q[i].r) upd (Q[++ R]);
		while (R > q[i].r) upd (Q[R --]);
		while (L < q[i].l) upd (Q[L ++]);
		while (L > q[i].l) upd (Q[-- L]);
		if (q[i].els) upd (q[i].els);
		ans[q[i].id] = tot;
		if (q[i].els) upd (q[i].els);
	}
	for (int i = 1; i <= m; ++ i) cout << ans[i] << '\n';
	return 0;
}
posted @ 2021-03-03 17:41  xwmwr  阅读(50)  评论(0编辑  收藏  举报