*题解:P11364 [NOIP2024] 树上查询
困难题。
解析
考虑处理出形如 \((u,v,d)\) 的三元组表示编号从 \(u\) 到 \(v\) 的点的 LCA 深度为 \(d\),且区间 \([u,v]\) 是极长的。怎么处理呢?肯定要利用子树信息,我们尝试进一步合并子树已经合并出的区间,并将合并出的更大区间记录。由于合并 \(n-1\) 次就变成 \([1,n]\) 了,所以最终处理出的三元组数量是 \(O(n)\) 的。运用启发式合并的思想,对于点 \(x\) 遍历轻儿子所在子树中的结点并尝试扩展其所在区间,看 \(u-1\) 和 \(v + 1\) 是否还在 \(x\) 的子树内,是则合并,合并的过程可以使用并查集来维护。
时间复杂度 \(O(n\log^2 n)\)。(但是貌似有些地方说是单 \(\log\) 的,然后又没有同时写路径压缩和按秩合并,搞不太懂。作者写 \(\log^2\) 是因为懒得写按秩合并。)
然后处理询问,需要讨论询问区间 \([l,r]\) 与三元组区间 \([u,v]\) 的相交情况:
若 \(u \le l \le r \le v\),则我们希望统计 \((u,v,d)\) 中 \(d\) 的最大值。考虑扫描线,从左往右扫,遇到 \(u\) 就在 \(v\) 位置上做出 \(d\) 的贡献,遇到 \(l\) 就对区间 \([r,n]\) 取最大值,可以使用线段树来维护。
若 \(l \le u \le r \le v\),则此时有额外限制 \(r - u + 1 \ge k\)。由于每次询问的 \(k\) 不一样,考虑对 \(k\) 从大到小做扫描线。对于每个 \(k\),加入所有区间长度为 \(k\) 的三元组,在 \(u\) 处做 \(d\) 的贡献,对于询问参数为 \(k\) 的询问,只需查询 \([l,r - k + 1]\) 中的最大值,同样可以使用线段树来维护。
若 \(u \le l \le v \le r\),则此时有额外限制 \(v - l + 1 \ge k\)。处理方法同上,最终是在 \(v\) 处作贡献,查询 \([l + k - 1,r]\) 中的最大值。
若 \(l \le u \le v \le r\),则可以发现这种情况被包含在第二第三种情况中了,且不影响正确性。
这部分的时间复杂度是 \(O(n \log n)\)。
代码
#include <bits/stdc++.h>
#define ls(p) ((p) << 1)
#define rs(p) (((p) << 1) | 1)
#define mid ((l + r) >> 1)
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int N = 5e5 + 5,M = 5e5,mod = 998244353;
vector<int> t[N];
int dep[N],dfn[N],siz[N],son[N];
int res[N];
int cnt = 0;
struct Seg{
int l = 0,r = 0,d = 0;
Seg(){}
Seg(int _l,int _r,int _d) : l(_l),r(_r),d(_d){}
friend bool operator < (Seg a,Seg b){
return a.l < b.l;
}
};
vector<Seg> s,len[N],lf[N];
struct Query{
int l,r,k,id;
};
vector<Query> q1[N],q2[N];
struct SMT{
int mx[N << 2];
SMT(){
memset(mx,0,sizeof(mx));
}
void push_up(int p){
mx[p] = max(mx[ls(p)],mx[rs(p)]);
}
void modi(int p,int l,int r,int k,int x){
if(l > k || r < k) return;
if(l == r){
mx[p] = max(mx[p],x);
return;
}
modi(ls(p),l,mid,k,x),modi(rs(p),mid + 1,r,k,x);
push_up(p);
}
int ask(int p,int l,int r,int L,int R){
if(l > R || r < L) return 0;
if(l >= L && r <= R){
return mx[p];
}
return max(ask(ls(p),l,mid,L,R),ask(rs(p),mid + 1,r,L,R));
}
}t1,t2,t3;
void dfs(int x,int fa){
dep[x] = dep[fa] + 1;
dfn[x] = ++cnt;
siz[x] = 1;
for(int nx : t[x])if(nx != fa){
dfs(nx,x);
siz[x] += siz[nx];
if(siz[son[x]] < siz[nx]) son[x] = nx;
}
}
int ft[N],bk[N];
int find(int x,int f[]){
if(f[x] != x) f[x] = find(f[x],f);
return f[x];
}
void merge(int x,int u){
int fnt = find(x,ft),bck = find(x,bk);
bool f = false;
while(dfn[fnt - 1] >= dfn[u] && dfn[fnt - 1] <= dfn[u] + siz[u] - 1){
ft[fnt] = find(fnt - 1,ft);
fnt = ft[fnt];
bk[fnt] = bck;
f = true;
}
while(dfn[bck + 1] >= dfn[u] && dfn[bck + 1] <= dfn[u] + siz[u] - 1){
bk[bck] = find(bck + 1,bk);
bck = bk[bck];
ft[bck] = fnt;
f = true;
}
if(f){
s.push_back({fnt,bck,dep[u]});
len[bck - fnt + 1].push_back({fnt,bck,dep[u]});
lf[fnt].push_back({fnt,bck,dep[u]});
}
}
void sol(int x,int fa,int u){
for(int nx : t[x])if(nx != fa){
sol(nx,x,u);
}
merge(x,u);
}
void dfs2(int x,int fa){
for(int nx : t[x])if(nx != fa){
dfs2(nx,x);
}
for(int nx : t[x])if(nx != fa && nx != son[x]){
sol(nx,x,x);
}
merge(x,x);
len[1].push_back({x,x,dep[x]});
lf[x].push_back({x,x,dep[x]});
}
int read(){
int a = 1,x = 0;
char ch = getchar();
while(ch > '9' || ch < '0'){
if(ch == '-') a = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9'){
x = x * 10 + ch - '0';
ch = getchar();
}
return a * x;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
// freopen("in.txt","r",stdin);
// freopen("out.txt","w",stdout);
int n,q;
n = read();
for(int i=1;i<=n;i++){
ft[i] = bk[i] = i;
}
for(int i=1;i<n;i++){
int u,v;
u = read(),v = read();
t[u].push_back(v);
t[v].push_back(u);
}
dfs(1,0);
dfs2(1,0);
q = read();
for(int i=1;i<=q;i++){
int l,r,k;
l = read(),r = read(),k = read();
q1[l].push_back((Query){l,r,k,i});
q2[k].push_back((Query){l,r,k,i});
}
for(int i=1;i<=n;i++){
for(Seg x : lf[i]){
t1.modi(1,1,n,x.r,x.d);
}
for(Query x : q1[i]){
res[x.id] = max(res[x.id],t1.ask(1,1,n,x.r,n));
}
}
for(int i=n;i>=1;i--){
for(Seg x : len[i]){
t2.modi(1,1,n,x.l,x.d);
t3.modi(1,1,n,x.r,x.d);
}
for(Query x : q2[i]){
res[x.id] = max({res[x.id],t2.ask(1,1,n,x.l,x.r - i + 1),t3.ask(1,1,n,x.l + i - 1,x.r)});
}
}
for(int i=1;i<=q;i++){
cout<<res[i]<<'\n';
}
return 0;
}

浙公网安备 33010602011771号