SPOJ COT2 Count on a tree 2
树上的链询问, 可以在括号序上跑莫队。
括号序就是这样一个队列 Q:对整棵树从根开始 dfs, 进入一个点就 Q.push_back(+x), 退出一个点就 Q.push_back(-x)。
对于一个询问 (u,v), 假设 +u < +v。
若 u 是 v 的祖先, 考虑在 Q 上用 [+u,+v] 表示这段路, 则之前 dfs 的时候从 u 开始访问 v 所在的子树之前若访问了其它子树 x, 则 +u 和 +v 之间必然有 +x 和 -x。
反之, 考虑在 Q 上用 [-u,+v] 表示这段路, 少了 lca(u,v), 要补上。
#include <bits/stdc++.h>
using namespace std;
const int N = 4e5 + 23, M = 4e5 + 23;
int n, m, B, a[N], b[N];
int ecnt, hd[N], nt[N*2+1], vr[N*2+1];
void ad (int u, int v) { nt[++ ecnt] = hd[u], hd[u] = ecnt, vr[ecnt] = v; }
int dep[N], siz[N], fa[N], son[N], tp[N], dfntot, in[N], out[N], Q[N * 2 + 23];
void dfs1 (int x, int F, int D) {
dep[x] = D, siz[x] = 1, fa[x] = F;
for (int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if (y == F) continue;
dfs1 (y, x, D + 1);
siz[x] += siz[y];
if (siz[y] > siz[son[x]]) son[x] = y;
}
}
void dfs2 (int x, int T) {
in[x] = ++ dfntot; Q[dfntot] = x;
tp[x] = T;
if (son[x]) dfs2 (son[x], T);
for (int i = hd[x]; i; i = nt[i]) {
int y = vr[i];
if (y == fa[x] || y == son[x]) continue;
dfs2 (y, y);
}
out[x] = ++ dfntot; Q[dfntot] = x;
}
int lca (int x, int y) {
while (tp[x] ^ tp[y]) dep[tp[x]] > dep[tp[y]] ? x = fa[tp[x]] : y = fa[tp[y]];
return dep[x] > dep[y] ? y : x;
}
struct ask { int id, l, r, els, belong; } q[M];
bool cmp (ask x, ask y) { return x.belong ^ y.belong ? x.l < y.l : x.r < y.r; }
int vis[N], ans[N], ccnt[N], tot;
void upd (int x) {
vis[x] ^= 1;
if (!vis[x] && (-- ccnt[a[x]]) == 0) -- tot;
if (vis[x] && (ccnt[a[x]] ++) == 0) ++ tot;
}
int main()
{
scanf ("%d%d", &n, &m); B = sqrt (m);
for (int i = 1; i <= n; ++ i) scanf ("%d", & b[i]), a[i] = b[i];
sort (b + 1, b + 1 + n);
int nanachi = unique (b + 1, b + 1 + n) - b - 1;
for (int i = 1; i <= n; ++ i) a[i] = lower_bound (b + 1, b + nanachi + 1, a[i]) - b;
for (int i = 1, x, y; i < n; ++ i) scanf ("%d%d", & x, & y), ad (x, y), ad (y, x);
dfs1 (1, 0, 1), dfs2 (1, 1);
for (int i = 1; i <= m; ++ i) {
int x, y, A; scanf ("%d%d", & x, & y); A = lca (x, y);
q[i].id = i;
if (in[y] < in[x]) swap (x, y);
if (x == A || y == A)
q[i].l = in[x], q[i].r = in[y], q[i].els = 0;
else
q[i].l = out[x], q[i].r = in[y], q[i].els = A;
q[i].belong = q[i].l / B;
}
sort (q + 1, q + m + 1, cmp);
for (int i = 1, L = 1, R = 0; i <= m; ++ i) {
while (R < q[i].r) upd (Q[++ R]);
while (R > q[i].r) upd (Q[R --]);
while (L < q[i].l) upd (Q[L ++]);
while (L > q[i].l) upd (Q[-- L]);
if (q[i].els) upd (q[i].els);
ans[q[i].id] = tot;
if (q[i].els) upd (q[i].els);
}
for (int i = 1; i <= m; ++ i) cout << ans[i] << '\n';
return 0;
}