*题解:P11364 [NOIP2024] 树上查询

原题链接

困难题。

解析

考虑处理出形如 \((u,v,d)\) 的三元组表示编号从 \(u\)\(v\) 的点的 LCA 深度为 \(d\),且区间 \([u,v]\) 是极长的。怎么处理呢?肯定要利用子树信息,我们尝试进一步合并子树已经合并出的区间,并将合并出的更大区间记录。由于合并 \(n-1\) 次就变成 \([1,n]\) 了,所以最终处理出的三元组数量是 \(O(n)\) 的。运用启发式合并的思想,对于点 \(x\) 遍历轻儿子所在子树中的结点并尝试扩展其所在区间,看 \(u-1\)\(v + 1\) 是否还在 \(x\) 的子树内,是则合并,合并的过程可以使用并查集来维护。

时间复杂度 \(O(n\log^2 n)\)。(但是貌似有些地方说是单 \(\log\) 的,然后又没有同时写路径压缩和按秩合并,搞不太懂。作者写 \(\log^2\) 是因为懒得写按秩合并。)

然后处理询问,需要讨论询问区间 \([l,r]\) 与三元组区间 \([u,v]\) 的相交情况:

\(u \le l \le r \le v\),则我们希望统计 \((u,v,d)\)\(d\) 的最大值。考虑扫描线,从左往右扫,遇到 \(u\) 就在 \(v\) 位置上做出 \(d\) 的贡献,遇到 \(l\) 就对区间 \([r,n]\) 取最大值,可以使用线段树来维护。

\(l \le u \le r \le v\),则此时有额外限制 \(r - u + 1 \ge k\)。由于每次询问的 \(k\) 不一样,考虑对 \(k\) 从大到小做扫描线。对于每个 \(k\),加入所有区间长度为 \(k\) 的三元组,在 \(u\) 处做 \(d\) 的贡献,对于询问参数为 \(k\) 的询问,只需查询 \([l,r - k + 1]\) 中的最大值,同样可以使用线段树来维护。

\(u \le l \le v \le r\),则此时有额外限制 \(v - l + 1 \ge k\)。处理方法同上,最终是在 \(v\) 处作贡献,查询 \([l + k - 1,r]\) 中的最大值。

\(l \le u \le v \le r\),则可以发现这种情况被包含在第二第三种情况中了,且不影响正确性。

这部分的时间复杂度是 \(O(n \log n)\)

代码

#include <bits/stdc++.h>
#define ls(p) ((p) << 1)
#define rs(p) (((p) << 1) | 1)
#define mid ((l + r) >> 1)
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int N = 5e5 + 5,M = 5e5,mod = 998244353;
vector<int> t[N];
int dep[N],dfn[N],siz[N],son[N];
int res[N];
int cnt = 0;
struct Seg{
	int l = 0,r = 0,d = 0;
	Seg(){}
	Seg(int _l,int _r,int _d) : l(_l),r(_r),d(_d){}
	friend bool operator < (Seg a,Seg b){
		return a.l < b.l;
	}
};
vector<Seg> s,len[N],lf[N];
struct Query{
	int l,r,k,id;
};
vector<Query> q1[N],q2[N];
struct SMT{
	int mx[N << 2];
	SMT(){
		memset(mx,0,sizeof(mx));
	}
	void push_up(int p){
		mx[p] = max(mx[ls(p)],mx[rs(p)]);
	}
	void modi(int p,int l,int r,int k,int x){
		if(l > k || r < k) return;
		if(l == r){
			mx[p] = max(mx[p],x);
			return;
		}
		modi(ls(p),l,mid,k,x),modi(rs(p),mid + 1,r,k,x);
		push_up(p);
	}
	int ask(int p,int l,int r,int L,int R){
		if(l > R || r < L) return 0;
		if(l >= L && r <= R){
			return mx[p];
		}
		return max(ask(ls(p),l,mid,L,R),ask(rs(p),mid + 1,r,L,R));
	}
}t1,t2,t3;
void dfs(int x,int fa){
	dep[x] = dep[fa] + 1;
	dfn[x] = ++cnt;
	siz[x] = 1;
	for(int nx : t[x])if(nx != fa){
		dfs(nx,x);
		siz[x] += siz[nx]; 
		if(siz[son[x]] < siz[nx]) son[x] = nx;
	}
}
int ft[N],bk[N];
int find(int x,int f[]){
	if(f[x] != x) f[x] = find(f[x],f);
	return f[x];
}

void merge(int x,int u){
	int fnt = find(x,ft),bck = find(x,bk);
	bool f = false;
	while(dfn[fnt - 1] >= dfn[u] && dfn[fnt - 1] <= dfn[u] + siz[u] - 1){
		ft[fnt] = find(fnt - 1,ft);
		fnt = ft[fnt];
		bk[fnt] = bck;
		f = true;
	}
	while(dfn[bck + 1] >= dfn[u] && dfn[bck + 1] <= dfn[u] + siz[u] - 1){
		bk[bck] = find(bck + 1,bk);
		bck = bk[bck];	
		ft[bck] = fnt;
		f = true;
	}
	if(f){
		s.push_back({fnt,bck,dep[u]});
		len[bck - fnt + 1].push_back({fnt,bck,dep[u]});
		lf[fnt].push_back({fnt,bck,dep[u]});
	}
}
void sol(int x,int fa,int u){
	for(int nx : t[x])if(nx != fa){
		sol(nx,x,u);
	}
	merge(x,u);
}
void dfs2(int x,int fa){
	for(int nx : t[x])if(nx != fa){
		dfs2(nx,x);
	}
	for(int nx : t[x])if(nx != fa && nx != son[x]){
		sol(nx,x,x);
	}
	merge(x,x);
	len[1].push_back({x,x,dep[x]});
	lf[x].push_back({x,x,dep[x]});
}
int read(){
	int a = 1,x = 0;
	char ch = getchar();
	while(ch > '9' || ch < '0'){
		if(ch == '-') a = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9'){
		x = x * 10 + ch - '0';
		ch = getchar(); 
	}
	return a * x;
} 
int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
//	freopen("in.txt","r",stdin);
//	freopen("out.txt","w",stdout);
	int n,q;
	n = read();
	for(int i=1;i<=n;i++){
		ft[i] = bk[i] = i;
	} 
	for(int i=1;i<n;i++){
		int u,v;
		u = read(),v = read();
		t[u].push_back(v);
		t[v].push_back(u);
	} 
	dfs(1,0);
	dfs2(1,0);
	q = read();
	for(int i=1;i<=q;i++){
		int l,r,k;
		l = read(),r = read(),k = read();
		q1[l].push_back((Query){l,r,k,i});
		q2[k].push_back((Query){l,r,k,i});
	}
	for(int i=1;i<=n;i++){
		for(Seg x : lf[i]){
			t1.modi(1,1,n,x.r,x.d);
		}
		for(Query x : q1[i]){
			res[x.id] = max(res[x.id],t1.ask(1,1,n,x.r,n));
		}
	}
	for(int i=n;i>=1;i--){
		for(Seg x : len[i]){
			t2.modi(1,1,n,x.l,x.d);
			t3.modi(1,1,n,x.r,x.d);
		}
		for(Query x : q2[i]){
			res[x.id] = max({res[x.id],t2.ask(1,1,n,x.l,x.r - i + 1),t3.ask(1,1,n,x.l + i - 1,x.r)});
		}
	}
	for(int i=1;i<=q;i++){
		cout<<res[i]<<'\n';
	}
	return 0;
}
posted @ 2025-11-16 23:03  yutar  阅读(22)  评论(0)    收藏  举报