CF1904E Tree Queries 题解
假如每次询问的点 是同一个点怎么做?我们考虑将这棵无根树看作以 为根的树,设根 的深度为 ,那么每个点到 的距离就是这个点的深度 。删掉一个点 本质上就是删掉以 为根的这棵子树,询问相当于问全局最大值。每个删除的点显然在 DFS 序上是一个区间,于是可以求出没有被删除的所有区间。于是我们容易在 的时间复杂度内单次回答。
但是问题在于, 可能不同。比较显然的做法是离线换根。类似换根 DP。我们设树根为 。现在我们要将根从 变成其某个儿子 。考虑深度的变化。比较套路的是以 为子树的点深度减 ,其他点加 。这都可以用线段树维护 DFS 序上的区间加维护。
考虑第二部分,即对于当前换到的根 ,删掉点 的本质。如果 不在 的路径上,即 不是 的祖先,那么删掉的就是 的子树,否则 为 的祖先。那么 只能到 下面往 走一步的那个点的子树内,即相当于保留一个区间,就是删掉 减去这个区间。于是整个题就做完了。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>
#include <array>
using namespace std;
const int N = 2e5 + 5;
vector<int> G[N];
int n, q;
int dep[N], id[N], sz[N], idx;
int fa[N][21];
struct Query
{
vector<int> ver;
int id;
Query()
{
id = 0;
ver.clear();
}
Query(vector<int>& v, int id) : ver(v), id(id) {}
};
vector<Query> qry[N];
class SegmentTree
{
public:
struct Node
{
int l, r, tag, maxn;
}tr[N << 2];
void pushup(int u)
{
tr[u].maxn = max(tr[u << 1].maxn, tr[u << 1 | 1].maxn);
}
void pushtag(int u, int v)
{
tr[u].tag += v;
tr[u].maxn += v;
}
void pushdown(int u)
{
if (tr[u].tag)
{
pushtag(u << 1, tr[u].tag);
pushtag(u << 1 | 1, tr[u].tag);
tr[u].tag = 0;
}
}
void build(int u, int l, int r, int* a)
{
tr[u] = { l, r, 0, a[l] };
if (l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid, a);
build(u << 1 | 1, mid + 1, r, a);
pushup(u);
}
void update(int u, int l, int r, int c)
{
if (tr[u].l >= l and tr[u].r <= r)
{
pushtag(u, c);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) update(u << 1, l, r, c);
if (r > mid) update(u << 1 | 1, l, r, c);
pushup(u);
}
int query(int u, int l, int r)
{
if (tr[u].l >= l and tr[u].r <= r) return tr[u].maxn;
pushdown(u);
int res = 0, mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) res = query(u << 1, l, r);
if (r > mid) res = max(res, query(u << 1 | 1, l, r));
return res;
}
}sgt;
int na[N];
array<int, N> ans;
int main()
{
ios::sync_with_stdio(0), cin.tie(0);
cin >> n >> q;
for (int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
G[u].emplace_back(v);
G[v].emplace_back(u);
}
for (int i = 1; i <= q; i++)
{
int u;
cin >> u;
vector<int> v;
int k;
cin >> k;
for (int j = 1; j <= k; j++)
{
int b;
cin >> b;
v.emplace_back(b);
}
qry[u].emplace_back(Query(v, i));
}
auto dfs = [&](auto self, int u, int f)->void
{
fa[u][0] = f;
dep[u] = dep[f] + 1;
sz[u] = 1;
id[u] = ++idx;
for_each(G[u].begin(), G[u].end(), [&](const auto& j) {
if (j != f)
{
self(self, j, u);
sz[u] += sz[j];
}
});
};
dep[0] = -1;
dfs(dfs, 1, 0);
for (int j = 1; j <= 20; j++)
{
for (int i = 1; i <= n; i++) fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
auto kth_anc = [&](int u, int k)
{
int c = 0;
while (k)
{
if (k & 1) u = fa[u][c];
c++;
k >>= 1;
}
return u;
};
for (int i = 1; i <= n; i++)
{
na[id[i]] = dep[i];
}
sgt.build(1, 1, n, na);
auto solve = [&](auto self, int u, int f)->void
{
for (auto& [v, idx] : qry[u])
{
vector<pair<int, int>> segs;
int ans = 0;
int maxr = 0;
for (auto& j : v)
{
if (j == u)
{
::ans[idx] = 0;
goto E;
}
if (id[u] >= id[j] && id[u] < id[j] + sz[j])
{
int dis = dep[u] - dep[j];
int sj = kth_anc(u, dis - 1);
// 只能去 [id[sj], id[sj] + sz[sj])
if (id[sj] != 1) segs.emplace_back(make_pair(1, id[sj] - 1));
if (id[sj] + sz[sj] != n + 1) segs.emplace_back(make_pair(id[sj] + sz[sj], n));
}
else
{
segs.emplace_back(make_pair(id[j], id[j] + sz[j] - 1));
}
}
for (auto& [l, r] : segs) maxr = max(maxr, r);
sort(segs.begin(), segs.end());
if (segs.size() && segs.front().first != 1) ans = sgt.query(1, 1, segs.front().first - 1);
if (segs.size() && maxr != n) ans = max(ans, sgt.query(1, maxr + 1, n));
if (segs.empty()) ans = max(ans, sgt.tr[1].maxn);
maxr = 0;
for (int k = 0; k + 1 < segs.size(); k++)
{
maxr = max(maxr, segs[k].second);
if (segs[k + 1].first > maxr)
{
int l = maxr + 1, r = segs[k + 1].first - 1;
if (l <= r) ans = max(ans, sgt.query(1, l, r));
}
}
::ans[idx] = ans;
E:;
}
for (auto& j : G[u])
{
if (j == f)
{
continue;
}
sgt.update(1, id[j], id[j] + sz[j] - 1, -1);
if (id[j] != 1) sgt.update(1, 1, id[j] - 1, 1);
if (id[j] + sz[j] <= n) sgt.update(1, id[j] + sz[j], n, 1);
self(self, j, u);
sgt.update(1, id[j], id[j] + sz[j] - 1, 1);
if (id[j] != 1) sgt.update(1, 1, id[j] - 1, -1);
if (id[j] + sz[j] <= n) sgt.update(1, id[j] + sz[j], n, -1);
}
};
solve(solve, 1, 0);
for (int i = 1; i <= q; i++)
{
cout << ans[i] << "\n";
}
return 0;
}

浙公网安备 33010602011771号