P11364 [NOIP2024] 树上查询

题目

思路

首先有一个关键结论,\(dep_{lca(l,l+1,...,r)} = \min(dep_{lca(i,i+1)})\)

当然,在 \(l=r\)\(lca\) 就是 \(l\),这个单独特判掉。

否则考虑如何证明,设 \(lca(l,l+1,...,r)\)\(x\),则至少有两个点在不同的子树内,或 \(x\) 属于 \((l,r)\)

我们将不同子树的染一个颜色,然后将 \(x\) 在单独染一个颜色,由于这个序列至少有两种颜色,无论怎么重排,一定存在一个 \(i\) 使得 \(i\)\(i+1\) 两个的颜色不一样

我们设 \(v_i = dep_{lca(i,i+1)}\),再将 \(v_1 = v_n = -1\),找到一个最大的管辖区间 \(l_i,r_i\),使得 \(v_i\)\(l_i,r_i\) 中是最小的,每次询问 \(L,R,k\) 转化为求:

  1. \(R \le l_i ∧ R-k+1 \ge l_i\)

  2. \(L+k-1 \le r_i \le R ∧ k \le r_i-l_i+1\)

中满足任意一个条件的 \(v_i\) 的最大值。

可能有细心的小盆友会说:“万一满足条件的 \(i\) 不在 \(l,r-1\) 范围内怎么办捏”。

其实这是无所谓的,举 \(1\) 的例子,若 \(i\) 满足此条件且不在 \(L,R-1\) 里面,则 \(L,R\) 里面一定有连续 \(k\) 个数值 \(\ge v_i\),所以一定会有更优取法。

第二个读者可以自证。

我们可以直接扫描线解决,先用单调栈求出 \(l_i,r_i\),然后写个线段树维护区间最大值,注意要特判掉 \(k=1\) 的情况,因为在这种情况下 \(lca\) 是固定的,st 表预处理一下即可。

最后一个小细节是所有 \(r_i\) 要加一,因为 \(v_i = dep_{lca(i,i+1)}\),本就是一个区间。

code

#include<bits/stdc++.h>
#define int long long
#define mid ((c[p].l+c[p].r)>>1)
#define ls (p<<1)
#define rs ((p<<1)+1)
using namespace std;
#define getchar() (p1 == p2 && (p2 = (p1 = buf1) + fread(buf1, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf1[1 << 23], *p1 = buf1, *p2 = buf1, ubuf[1 << 23], *u = ubuf;
namespace IO
{
	template<typename T>
	void read(T &_x){_x=0;int _f=1;char ch=getchar();while(!isdigit(ch)) _f=(ch=='-'?-1:_f),ch=getchar();while(isdigit(ch)) _x=_x*10+(ch^48),ch=getchar();_x*=_f;}
	template<typename T,typename... Args>
	void read(T &_x,Args&...others){Read(_x);Read(others...);}
	const int BUF=20000000;char buf[BUF],to,stk[32];int plen;
	#define pc(x) buf[plen++]=x
	#define flush(); fwrite(buf,1,plen,stdout),plen=0;
	template<typename T>inline void print(T x){if(!x){pc(48);return;}if(x<0) x=-x,pc('-');for(;x;x/=10) stk[++to]=48+x%10;while(to) pc(stk[to--]);}
}
using namespace IO;
const int N = 5e5+10;
int n,m,x,y,head[N],fa[N],dfn[N],ans[N],dep[N],cnt,st[20][N],lg[N],k,t[N],l,mx;
struct w
{
	int to,nxt;
}b[N<<1];
struct w1
{
	int l,r,v,id;
}a[N],q[N];
struct w2
{
	int l,r,mx;
}c[N<<2];
void build(int p,int l,int r)
{
	c[p].l = l,c[p].r = r;
	if(l == r)
	{
		c[p].mx = 0;
		return;
	}
	build(ls,l,mid),build(rs,mid+1,r);
	c[p].mx = max(c[ls].mx,c[rs].mx);
}
void change(int p,int x,int k)
{
	if(c[p].l == c[p].r)
	{
		c[p].mx = max(c[p].mx,k);
		return;
	}
	if(x <= mid) change(ls,x,k);
	else change(rs,x,k);
	c[p].mx = max(c[ls].mx,c[rs].mx);
}
int query(int p,int l,int r)
{
	int mx = 0;
	if(l <= c[p].l && c[p].r <= r) return c[p].mx;
	if(l <= mid) mx = query(ls,l,r);
	if(mid < r) mx = max(mx,query(rs,l,r));
	return mx;
}
inline int Min(int x,int y){return (dfn[x] < dfn[y]) ? x : y;}
inline void add(int x,int y)
{
	b[++cnt].nxt = head[x];
	b[cnt].to = y;
	head[x] = cnt;
}
void dfs(int x,int y)
{
	dfn[x] = ++cnt,fa[x] = y; dep[x] = dep[y]+1;
	for(int i = head[x];i;i = b[i].nxt)
		if(b[i].to != y)
			dfs(b[i].to,x);
}
inline bool cmp1(w1 x,w1 y){return x.r > y.r;}
inline bool cmp2(w1 x,w1 y){return x.v > y.v;}
inline bool cmp3(w1 x,w1 y){return x.r-x.l > y.r-y.l;}
inline int lca(int x,int y)
{
	x = dfn[x],y = dfn[y];
	if(x > y) swap(x,y);
	x++; k = lg[y-x+1];
	return Min(st[k][x],st[k][y-(1<<k)+1]);
}
signed main()
{
//	freopen("query2.in","r",stdin);
//	freopen("query.out","w",stdout);
	read(n);
	for(int i = 1;i < n;i++) read(x),read(y),add(x,y),add(y,x);
	for(int i = 2;i <= n;i++) lg[i] = lg[i/2]+1;
	cnt = 0; dfs(1,0);
	for(int i = 1;i <= n;i++) st[0][dfn[i]] = fa[i];
	for(int i = 1;i <= lg[n];i++)
		for(int j = 1;j+(1<<i)-1 <= n;j++)
			st[i][j] = Min(st[i-1][j],st[i-1][j+(1<<(i-1))]);
	read(m);
	for(int i = 1;i < n;i++)
		a[i].v = dep[lca(i,i+1)],a[i].l = 1,a[i].r = n-1;
	cnt = 0; 
	for(int i = 1;i < n;i++)
	{
		while(cnt && a[i].v < a[t[cnt]].v) a[t[cnt]].r = i-1,cnt--;
		t[++cnt] = i; 
	} 
	cnt = 0;
	for(int i = n-1;i >= 1;i--)
	{
		while(cnt && a[i].v < a[t[cnt]].v) a[t[cnt]].l = i+1,cnt--;
		t[++cnt] = i; 
	}
	for(int i = 1;i < n;i++) a[i].r++;
	for(int i = 1;i <= n;i++) st[0][i] = dep[i];
	for(int i = 1;i <= lg[n];i++)
		for(int j = 1;j+(1<<i)-1 <= n;j++)
			st[i][j] = max(st[i-1][j],st[i-1][j+(1<<(i-1))]);
	for(int i = 1;i <= m;i++) 
	{
		read(q[i].l),read(q[i].r),read(q[i].v),q[i].id = i;
		if(q[i].v == 1) 
		{
			k = lg[q[i].r-q[i].l+1]; 
			ans[i] = max(st[k][q[i].l],st[k][q[i].r-(1<<k)+1]);
		}
	}
	sort(a+1,a+n,cmp1); sort(q+1,q+1+m,cmp1); l = 1;
	build(1,1,n);
	for(int i = 1;i <= m;i++)
	{
		if(q[i].v == 1) continue;
		while(l < n && q[i].r <= a[l].r) change(1,a[l].l,a[l].v),l++;
		ans[q[i].id] = max(ans[q[i].id],query(1,1,q[i].r-q[i].v+1));
	} 
	sort(a+1,a+n,cmp3); sort(q+1,q+1+m,cmp2); l = 1;
	build(1,1,n);
	for(int i = 1;i <= m;i++)
	{
		if(q[i].v == 1) continue;
		while(l < n && q[i].v <= a[l].r-a[l].l+1) change(1,a[l].r,a[l].v),l++;
		ans[q[i].id] = max(ans[q[i].id],query(1,q[i].l+q[i].v-1,q[i].r));
	}
	for(int i = 1;i <= m;i++) print(ans[i]),pc('\n');
	flush();
	return 0;
}
/*
首先有一个关键结论,dep_{lca(l~r)} = min(dep_{lca(i~i+1)})(l <= i < r)
O对了,l==r要特判,输出dep_l
其他情况下,考虑画图证明,这里稍微文字描述一下,可以自行画图
设lca(l,r)为x,则至少有两个点在不同的子树内,或x属于(l,r) 
若x属于(l,r)则一定存在一次lca(i~i+1)会得到x
否则,一定存在i和i+1不在同一个子树内,这是显然的吧 
那我们就有了一个nq做法,考虑优化
我们设 v_i = dep_{lca(i~i+1)},同时找到他的存在区间(l_i,r_i),即它在(l_i,r_i)里是最小的
那么每次询问变成了求:
R <= r_i && R-k+1 >= l_i
L+k-1 <= r_i <= R && k <= r_i-l_i+1
中v_i的最大值  
*/
posted @ 2025-01-21 21:27  kkxacj  阅读(93)  评论(1)    收藏  举报