[Tricks-00005][NOIp2024]树上查询 思维方式还是要数形结合!
题目链接。
有一个经典结论是,在 \(l<r\) 的时候,\(dep_{\operatorname{LCA}(l,l+1,\dots,r)}=\min\limits_{i=l}^{r-1}dep_{\operatorname{LCA}(i,i+1)}\),证明也十分容易。
特判掉 \(k=1\) 的特殊情况后,问题则可以转化成:有一个序列 \(d_i=dep_{\operatorname{LCA}(i,i+1)}\),求 \(\max\limits_{i=l}^{r-k+1}(\min\limits_{j=i}^{i+k-2}d_j)\)。这样把树的限制完全去掉了,转化成了一个序列上的问题,后面再咋弄都是简单序列 DS,就很开心了。
直接地去入手,把最小值那个位置拎出来,设为 \(i\),则我只要哦找到一个包含 \(i\) 且以 \(i\) 为最小值点的区间,长度不小于 \(k-1\) 且包含于 \([l,r-1]\) 中。
用代数语言刻画出来:设 \(u_i\) 为 \(i\) 左边第一个 \(d_j<d_i\) 的点 \(j+1\),\(v_i\) 为 \(i\) 右边第一个 \(d_j<d_i\) 的点 \(j-1\)。则要求 \(|[u_i,v_i]\cap [l,r-1]|\geq k-1\)。这里有个小问题,如果交集不包含 \(i\) 不过仍然满足这个条件,怎么算?不过你会发现此时取 \(i\) 一定不如交集中的点优,所以并不需要在意这个 \(i\)。
因此我们只需要找到所有的满足 \(\min(r-1,v_i)-\max(l,u_i)+1\geq k-1\) 的所有 \(i\),把这些位置的 \(d_i\) 取个最大值即可。于是你就会非常容易写的 \(O(nq)\) 了!接下来考虑优化。
注意到形如 \(\min\geq\) 的条件可以直接拆开,于是我们就可以列出如下四个限制:
第一个条件是自动满足的,于是只需要满足:
这是个三维偏序,直接 cdq+树状数组,即可得到一个小常数 \(O(n\log^2n)\) 的做法。以下是考场代码复现:
#include<bits/stdc++.h>
using namespace std;
char ib[1<<24],*ip1=ib,*ip2=ib;
#define gc() (ip1==ip2&&(ip2=(ip1=ib)+fread(ib,1,1<<24,stdin)),ip1==ip2?EOF:*ip1++)
inline int read(){
int x=0;char c=gc();
while(c<'0'||c>'9')c=gc();
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^'0'),c=gc();
return x;
}
char ob[1<<24],*op=ob;
inline void pc(char c){
*op++=c;
}
void write(int x){
if(x>=10)write(x/10);
pc(x%10+'0');
}
void final_write(){
fwrite(ob,op-ob,1,stdout);
}
int n;
vector<int>g[500005];
int fa[19][500005],dep[500005],d[500005];
void dfs(int x,int la){
for(auto cu:g[x]){
if(cu==la)continue;
fa[0][cu]=x,dep[cu]=dep[x]+1;
dfs(cu,x);
}
}
int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=18;i>=0;--i)if(dep[x]-dep[y]>=(1<<i))x=fa[i][x];
if(x==y)return x;
for(int i=18;i>=0;--i)if(fa[i][x]!=fa[i][y])x=fa[i][x],y=fa[i][y];
return fa[0][x];
}
namespace easy{
int sm[2000005];
void build(int l,int r,int o){
if(l==r){
sm[o]=dep[l];
return;
}
int mid=(l+r)>>1;
build(l,mid,o<<1);
build(mid+1,r,o<<1|1);
sm[o]=max(sm[o<<1],sm[o<<1|1]);
}
int query(int l,int r,int o,int ll,int rr){
if(l>=ll&&r<=rr)return sm[o];
int mid=(l+r)>>1,ans=0;
if(mid>=ll)ans=max(ans,query(l,mid,o<<1,ll,rr));
if(mid<rr)ans=max(ans,query(mid+1,r,o<<1|1,ll,rr));
return ans;
}
}
int ans[500005],l[500005],r[500005],k[500005];
int u[500005],v[500005],st[500005],tt=0;
int A[500005],B[500005],C[500005];
vector<int>gg[500005],g2[500005];
struct apple{
int x,z;
apple(int x=0,int z=0):x(x),z(z){}
bool operator<(const apple &other)const{
return x<other.x;
}
};
vector<apple>kj;
int c[500005];
inline void add(int x,int s){
while(x<=n&&c[x]<s)c[x]=s,x+=x&-x;
}
inline void clr(int x){
if(c[x])while(x<=n)c[x]=0,x+=x&-x;
}
inline int query(int x){
int ans=0;
while(x)ans=max(ans,c[x]),x-=x&-x;
return ans;
}
pair<vector<apple>,vector<apple>>solve(int l,int r){
if(l==r){
vector<apple>gl(g2[l].size()),gr(gg[l].size());
int sl=0,sr=0;
for(auto z:gg[l]){
gr[sr++]=apple(B[z],z);
}
for(auto z:g2[l]){
gl[sl++]=apple(v[z],z);
}
sort(gl.begin(),gl.end());
sort(gr.begin(),gr.end());
return make_pair(gl,gr);
}
int mid=(l+r)>>1;
auto a1=solve(l,mid),a2=solve(mid+1,r);
auto gl1=a1.first,gl2=a2.first;
int sl1=gl1.size(),sl2=gl2.size();
auto g1=a1.second,g2=a2.second;
int s1=g1.size(),s2=g2.size();
int w=0;
for(int i=0;i<sl1;++i){
while(w<s2&&g2[w].x<gl1[i].x){
int j=g2[w++].z;
ans[j]=max(ans[j],query(C[j]));
}
int j=gl1[i].z;
add(u[j]+v[j],d[j]);
}
while(w<s2){
int j=g2[w++].z;
ans[j]=max(ans[j],query(C[j]));
}
for(int i=0;i<sl1;++i){
int j=gl1[i].z;
clr(u[j]+v[j]);
}
if(l==1&&r==n)return make_pair(kj,kj);
vector<apple>gl(sl1+sl2),gr(s1+s2);
merge(gl1.begin(),gl1.end(),gl2.begin(),gl2.end(),gl.begin());
merge(g1.begin(),g1.end(),g2.begin(),g2.end(),gr.begin());
return make_pair(gl,gr);
}
int main(){
freopen("query.in","r",stdin);
freopen("query.out","w",stdout);
n=read();//1~5e5
for(int i=1;i<n;++i){
int u=read(),v=read();
g[u].emplace_back(v);
g[v].emplace_back(u);
}
dep[1]=1;dfs(1,0);
for(int i=1;i<=18;++i)for(int j=1;j<=n;++j)
fa[i][j]=fa[i-1][fa[i-1][j]];
for(int i=1;i<n;++i)d[i]=dep[lca(i,i+1)];
for(int i=1;i<n;++i){
while(tt&&d[st[tt]]>=d[i])--tt;
if(tt)u[i]=st[tt]+1;
else u[i]=1;
st[++tt]=i;
}
tt=0;
for(int i=n-1;i>=1;--i){
while(tt&&d[st[tt]]>=d[i])--tt;
if(tt)v[i]=st[tt]-1;
else v[i]=n-1;
v[i]=n-v[i];
g2[u[i]].emplace_back(i);
st[++tt]=i;
}
easy::build(1,n,1);
int q=read();
for(int i=1;i<=q;++i){
l[i]=read(),r[i]=read(),k[i]=read();
if(k[i]==1){
ans[i]=easy::query(1,n,1,l[i],r[i]);
continue;
}
--r[i],--k[i];
A[i]=max(1,r[i]-k[i]+1)+1;
B[i]=n-min(n,k[i]+l[i])+1;
C[i]=n-k[i]+1;
gg[A[i]].emplace_back(i);
/*
for(int j=l[i];j<=r[i];++j){
if(u[j]<A[i]&&v[j]<=B[i]&&u[j]+v[j]<=C[i]){
ans[i]=max(ans[i],d[j]);
}
}
*/
}
solve(1,n);
for(int i=1;i<=q;++i)write(ans[i]),pc('\n');
final_write();
return 0;
}
能不能再给力一点啊?
把三维偏序拆成二维偏序是一个很有趣的思路,而且可以证明这种形式的限制一定可以拆。
我们不妨数形结合地去考虑。将 \((u,v)\) 视为一个点,则限制是三个半平面的交,应该很容易画出来(图中可能有 \(\pm 1\) 的误差,只是说个大概形状):

可以得到的是粉色区域里所有的点。直接做是个平行六边形,就成了三维偏序。
不过我们把这个区域拆开,以 \(x=l\) 那条红线分成左右两部分,左边就是个矩形,二维偏序很容易解决,右边可以理解成一个没有上边的平行四边形,也是二维偏序,不过有一维既有上界也有下界。
代数形式就是如下两部分(可能有 \(\pm 1\) 的误差):
第一个直接扫描线+树状数组,第二个扫描线+线段树即可,复杂度是优秀的 \(O(n\log n)\)。
一种可能的实现方式:
#include<bits/stdc++.h>
using namespace std;
char ib[1<<24],*ip1=ib,*ip2=ib;
#define gc() (ip1==ip2&&(ip2=(ip1=ib)+fread(ib,1,1<<24,stdin)),ip1==ip2?EOF:*ip1++)
inline int read(){
int x=0;char c=gc();
while(c<'0'||c>'9')c=gc();
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^'0'),c=gc();
return x;
}
char ob[1<<24],*op=ob;
inline void pc(char c){
*op++=c;
}
void write(int x){
if(x>=10)write(x/10);
pc(x%10+'0');
}
void final_write(){
fwrite(ob,op-ob,1,stdout);
}
int n;
vector<int>g[500005];
int fa[19][500005],dep[500005],d[500005];
void dfs(int x,int la){
for(auto cu:g[x]){
if(cu==la)continue;
fa[0][cu]=x,dep[cu]=dep[x]+1;
dfs(cu,x);
}
}
int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=18;i>=0;--i)if(dep[x]-dep[y]>=(1<<i))x=fa[i][x];
if(x==y)return x;
for(int i=18;i>=0;--i)if(fa[i][x]!=fa[i][y])x=fa[i][x],y=fa[i][y];
return fa[0][x];
}
namespace easy{
int sm[2000005];
void build(int l,int r,int o){
if(l==r){
sm[o]=dep[l];
return;
}
int mid=(l+r)>>1;
build(l,mid,o<<1);
build(mid+1,r,o<<1|1);
sm[o]=max(sm[o<<1],sm[o<<1|1]);
}
int query(int l,int r,int o,int ll,int rr){
if(l>=ll&&r<=rr)return sm[o];
int mid=(l+r)>>1,ans=0;
if(mid>=ll)ans=max(ans,query(l,mid,o<<1,ll,rr));
if(mid<rr)ans=max(ans,query(mid+1,r,o<<1|1,ll,rr));
return ans;
}
}
int ans[500005],l[500005],r[500005],k[500005];
int u[500005],v[500005],st[500005],tt=0;
int A[500005],B[500005],C[500005];
vector<int>g1[500005],g2[500005],g3[500005],g4[500005];
int c[500005];
void add(int x,int s){
while(x<=n){
c[x]=max(c[x],s);
x+=x&-x;
}
}
int query(int x){
int ans=0;
while(x){
ans=max(ans,c[x]);
x-=x&-x;
}
return ans;
}
namespace hard{
int sm[2000005];
void add(int l,int r,int o,int x,int y){
if(l==r){
sm[o]=max(sm[o],y);
return;
}
int mid=(l+r)>>1;
if(x<=mid)add(l,mid,o<<1,x,y);
else add(mid+1,r,o<<1|1,x,y);
sm[o]=max(sm[o<<1],sm[o<<1|1]);
}
int query(int l,int r,int o,int ll,int rr){
if(l>=ll&&r<=rr)return sm[o];
int mid=(l+r)>>1,ans=0;
if(mid>=ll)ans=max(ans,query(l,mid,o<<1,ll,rr));
if(mid<rr)ans=max(ans,query(mid+1,r,o<<1|1,ll,rr));
return ans;
}
}
int main(){
freopen("query.in","r",stdin);
freopen("query.out","w",stdout);
n=read();//1~5e5
for(int i=1;i<n;++i){
int u=read(),v=read();
g[u].emplace_back(v);
g[v].emplace_back(u);
}
dep[1]=1;dfs(1,0);
for(int i=1;i<=18;++i)for(int j=1;j<=n;++j)
fa[i][j]=fa[i-1][fa[i-1][j]];
for(int i=1;i<n;++i)d[i]=dep[lca(i,i+1)];
for(int i=1;i<n;++i){
while(tt&&d[st[tt]]>=d[i])--tt;
if(tt)u[i]=st[tt]+1;
else u[i]=1;
st[++tt]=i;
}
tt=0;
for(int i=n-1;i>=1;--i){
while(tt&&d[st[tt]]>=d[i])--tt;
if(tt)v[i]=st[tt]-1;
else v[i]=n-1;
v[i]=n-v[i];
g3[u[i]+v[i]].emplace_back(i);
g4[v[i]].emplace_back(i);
st[++tt]=i;
}
easy::build(1,n,1);
int q=read();
for(int i=1;i<=q;++i){
l[i]=read(),r[i]=read(),k[i]=read();
if(k[i]==1){
ans[i]=easy::query(1,n,1,l[i],r[i]);
continue;
}
--r[i],--k[i];
A[i]=r[i]-k[i]+1,B[i]=n-l[i]-k[i]+1;
C[i]=n-k[i]+1;
g1[C[i]].emplace_back(i);
g2[B[i]].emplace_back(i);
}
for(int i=1;i<=n;++i){
for(auto j:g4[i]){
add(u[j],d[j]);
}
for(auto j:g3[i]){
hard::add(1,n,1,u[j],d[j]);
}
for(auto j:g2[i]){
ans[j]=max(ans[j],query(l[j]));
}
for(auto j:g1[i]){
ans[j]=max(ans[j],hard::query(1,n,1,l[j],A[j]));
}
}
for(int i=1;i<=q;++i)write(ans[i]),pc('\n');
final_write();
return 0;
}

浙公网安备 33010602011771号