【题解】Ynoi 2011 成都七中
"保留 \([l,r]\) 编号的节点,找到 \(x\) 所在的连通块"一类问题的常见套路就是在点分树上跳 \(x\) 的祖先,找到最上面的一个满足之间的路径全部存在的祖先 \(p\),然后原问题和 l r p 等价。证明不难。
接着建出点分树后,对于每个 \(p\) 就只需要考虑每个颜色的贡献了:对于颜色 \(c\) 的每个节点 \(t_1,t_2,\cdots,t_m\),找到 \(t_{k}\) 到 \(m\) 路径上点编号的最值,这个最值决定了询问区间 \([l,r]\) 能收到贡献的范围。
没有必要建出点分树,可以考虑将询问离线在 \(p\) 上,然后枚举 \(p\) 处理问题,用树状数组做即可。
时间复杂度 \(O(n\log n)\) 。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair <int, int> pii;
#define fi first
#define se second
#define rez resize
#define pb push_back
#define mkp make_pair
#define Lep(i, l, r) for (int i = l; i < r; ++ i)
#define Rep(i, r, l) for (int i = r; i > l; -- i)
#define lep(i, l, r) for (int i = l; i <= r; ++ i)
#define rep(i, r, l) for (int i = r; i >= l; -- i)
char _c; bool _f; template <class T> inline void IN (T & x) {
x = 0, _f = 0; while (_c = getchar (), ! isdigit (_c)) if (_c == '-') _f = 1;
while (isdigit (_c)) x = x * 10 + _c - '0', _c = getchar (); if (_f) x = -x;
}
template <class T> inline void chkmin (T & x, T y) { if (x > y) x = y; }
template <class T> inline void chkmax (T & x, T y) { if (x < y) x = y; }
const int N = 1e5 + 5;
int n, m, c[N];
vector <int> to[N];
bool vis[N];
int rt, all, siz[N], mxp[N];
struct Node { int t, l, r, id; }; vector <Node> anc[N], sta[N];
void getsz (int u, int pre) {
siz[u] = 1;
for (int v : to[u]) if (v != pre && ! vis[v]) getsz (v, u), siz[u] += siz[v];
}
void getrt (int u, int pre) {
mxp[u] = all - siz[u];
for (int v : to[u]) if (v != pre && ! vis[v]) getrt (v, u), chkmax (mxp[u], siz[v]);
if (mxp[u] < mxp[0]) mxp[0] = mxp[u], rt = u;
}
void addtag (int u, int pre, const int &nrt, int mi, int mx) {
anc[u].pb ((Node) {nrt, mi, mx, 0}), sta[nrt].pb ((Node) {c[u], mi, mx, 0});
for (int v : to[u]) if (v != pre && ! vis[v]) addtag (v, u, nrt, min (mi, v), max (mx, v));
}
void divide (int u) {
vis[u] = true, addtag (u, 0, u, u, u);
for (int v : to[u]) if (! vis[v]) getsz (v, 0), all = siz[v], mxp[0] = all + 1, getrt (v, 0), divide (rt);
}
int res, sum[N];
int lowbit (int x) { return x & (-x); }
void modify (int x, int y) { for (; x <= n; x += lowbit (x)) sum[x] += y; }
int query (int x) { res = 0; for (; x; x -= lowbit (x)) res += sum[x]; return res; }
int lst[N], ans[N];
int main () {
IN (n), IN (m);
lep (i, 1, n) IN (c[i]);
for (int u, v, i = 1; i < n; ++ i) IN (u), IN (v), to[u].pb (v), to[v].pb (u);
getsz (1, 0), all = siz[1], mxp[0] = all + 1, getrt (1, 0), divide (rt);
for (int l, r, x, i = 1; i <= m; ++ i) {
IN (l), IN (r), IN (x);
int top = x;
for (auto tmp : anc[x]) if (l <= tmp.l && tmp.r <= r) { top = tmp.t; break; }
sta[top].pb ((Node) { -top, l, r, i});
}
lep (i, 1, n) {
sort (sta[i].begin (), sta[i].end (), [&](const Node x, const Node y) {
return x.l == y.l ? (x.t > y.t) : (x.l > y.l);
});
for (auto tmp : sta[i]) {
if (tmp.t < 0) ans[tmp.id] = query (tmp.r);
else {
if (lst[tmp.t] > tmp.r) modify (lst[tmp.t], -1), lst[tmp.t] = 0;
if (! lst[tmp.t]) modify (tmp.r, 1), lst[tmp.t] = tmp.r;
}
}
for (auto tmp : sta[i]) if (tmp.t > 0 && lst[tmp.t]) modify (lst[tmp.t], -1), lst[tmp.t] = 0;
}
lep (i, 1, m) printf ("%d\n", ans[i]);
return 0;
}

浙公网安备 33010602011771号