[NOIP2024] 树上查询
原题传送门
Solution 1
考虑 \(O(nq)\) 的解法。
性质 \(1\):选择长度恰好等于 \(k\) 的区间 \(Lca\) 一定最优。
证明:新加入一个点,\(dep_{lca}\) 不增,证毕。
所以对于每次询问,可以枚举左端点,就可以同时固定右端点。使用 \(st\) 表可以维护区间 \(Lca\),使用 \(dfn\) 序求 \(Lca\),可以将时间复杂度变为预处理 \(O(n \log n)\),查询 \(O(1)\)。
发现这样做可以顺便通过特殊性质 \(B\),因为该性质每次询问的左端点只有 \(1\) 个,时间复杂度变为 \(O(q)\)。
期望得分 \(32\)。
Solution 2
考虑特殊性质 \(A\) 的解法。
此时 \(dep_{Lca^*[l,r]}=\min_{i=l}^r dep[i]\),则原询问转化为序列问题。
考虑一个点 \(x\) 对答案有贡献的条件是什么。首先是在当前区间中必须为最小值,所以先用单调栈预处理点 \(x\) 左边和右边第一个小于其的数,注意不是小于等于。由于区间长度越长,其能造成贡献的询问个数就越多,所以应找到一个小于而非小于等于其的数。
设左边第一个大于其的位置为 \(l\),右边为 \(r\)。则其能支配的区间长度为 \(r-l-1\),设其为 \(len\)。
此时其能造成贡献的条件即为 \(ql \le x \le qr\) 且询问的区间与其能支配的区间交长度大于等于 \(k\)。
此时分为两种情况进行讨论:
- \(l < ql\),此时要想交集长度大于等于 \(k\),则 \(r \ge ql+k-1\)。但此时不一定保证 \(ql \le x \le qr\),可若 \(x\) 不在询问区间中,且其能支配的区间与询问区间的交集长度大于等于 \(k\),则考虑其支配的区间与询问区间的交集,其长度大于等于 \(k\),且其中的每个数都大于等于 \(dep_x\),故答案一定比 \(dep_x\) 大,故可以扩展解域。
此时可以对 \(l\) 扫描线,线段树维护 \(r\),点 \(x\) 造成的贡献即为 \(dep_x\)。将序列倒过来,对 \(r\) 扫描线,线段树维护 \(l\),即可以处理 \(r > qr\) 的贡献。
- \(l \ge ql\),此时如果仍然对 \(l\) 扫描线,线段树维护 \(r\) 是不可行的。上一种情况之所以可行,是因为固定了一个点,若 \(l < ql\),则固定交集区间的左端点为 \(ql\);若 \(r > qr\),则固定了交集区间的右端点为 \(qr\),此时便可以线段树维护。但此时我们需要维护交集的两个点,考虑进一步的加强限制。
首先,若点 \(x\) 支配的区间长度 \(len\) 小于 \(qk\),肯定不能对该询问产生贡献。所以只需考虑 \(len \ge qk\) 的区间即可。此时如果 \(ql \le l \le r-k+1\),则一定能产生贡献。可以这样想,如果 \(ql \le l \le r \le qr\),则由于 \(len \ge k\),一定可行;若 \(ql \le l \le r \le r\),则交集右端点固定,当 \(l \le r-k+1\) 时,依旧可以产生贡献。
发现当 \(len \ge k\) 的时候,右端点可以直接调到 \(qr\) 的位置,因为此时决定区间是否有贡献的因素不是 \(r\);当然你也可以将左端点调到 \(ql\) 的位置,本质上都是固定了一个端点。
所以对 \(k\) 进行扫描线,线段树维护 \(l\)。发现此时维护的情况也包括了 \(r > qr\),所以不需要对序列倒过来求解一遍了。
原来的问题就变为了两个简单的问题。
期望时间复杂度 \(O((n+q) \log n)\),期望得分 \(64\)。
Solution 3
由于性质 \(B\) 的缘故,我们将其转化为了一个序列问题,此时就可以考虑将一般情况下的 \(Lca\) 也转化为序列问题。
性质 \(2\):\(dep_{Lca^*[l,r]}=\min_{i=l}^{r-1} dep_{Lca(i,i+1)}\)。
证明:考虑原先的 \(Lca^*[l,r]\),若想产生其,则必定有两个点来自于该点的不同子树,则这两个点之间也必定存在两点相邻且来自原 \(Lca\) 的不同子树,从而这两个点的 \(Lca\) 即为原 \(Lca\)。由于 \(dep_{Lca}\) 是不增的,所以这两个点的 \(dep_{Lca}\) 不可能小于原 \(Lca\) 的深度。证毕。
则将 \(k=1\) 特判掉之后,可以将区间 \(Lca\) 转化为区间求 \(min\) 的问题即 Solution 2。
期望时间复杂度 \(O((n+q) \log n)\),期望得分 \(100\)。
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INT_MAX (int)(1e18)
#define mid (l+r>>1)
const int N=5e5+10;
int n,q,idx,dn;
int dfn[N],dep[N],lg[N];
int head[N],nxt[2*N],ver[2*N];
int st[20][N],st1[20][N];
int ru[N],ans[N],a[N];
struct node{
int l,r,k,id;
}ask[N];
inline int read(){
int t=0,f=1;
register char c=getchar();
while(c<'0'||c>'9') f=(c=='-')?(-1):(f),c=getchar();
while(c>='0'&&c<='9') t=(t<<3)+(t<<1)+(c^48),c=getchar();
return t*f;
}
void add(int u,int v){
nxt[++idx]=head[u];
head[u]=idx;
ver[idx]=v;
ru[v]++;
}
void dfs(int u,int v){
dep[u]=dep[v]+1,dfn[u]=++dn,st1[0][dfn[u]]=v;
for(int i=head[u];i;i=nxt[i]){
int dao=ver[i];
if(dao==v) continue;
dfs(dao,u);
}
}
int get(int x,int y){
return dfn[x]<dfn[y]?x:y;
}
int calc(int l,int r){
if(l==r) return l;
l=dfn[l],r=dfn[r];
if(l>r) swap(l,r);
int x=lg[r-l];l++;
return get(st1[x][l],st1[x][r-(1<<x)+1]);
}
int calc1(int l,int r){
int x=lg[r-l+1];
return calc(st[x][l],st[x][r-(1<<x)+1]);
}
void init(){
for(int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
for(int i=1;i<=lg[n];i++)
for(int j=1;j+(1<<i)-1<=n;j++)
st1[i][j]=get(st1[i-1][j],st1[i-1][j+(1<<i-1)]);
for(int i=1;i<=n;i++) st[0][i]=i;
for(int i=1;i<=lg[n];i++)
for(int j=1;j+(1<<i)-1<=n;j++)
st[i][j]=calc(st[i-1][j],st[i-1][j+(1<<i-1)]);
}
void solve1(){
init();
for(int i=1;i<=q;i++){
int l=ask[i].l,r=ask[i].r,k=ask[i].k;int ans=0;
for(int i=l;i+k-1<=r;i++) ans=max(ans,dep[calc1(i,i+k-1)]);
cout<<ans<<"\n";
}
}
stack<pair<int,int> > s;
int lef[N],rig[N],tr[N<<2];
vector<pair<int,int> > mes[N];
bool cmp(node x,node y){return x.k>y.k;}
bool cmp1(node x,node y){return x.l<y.l;}
void init1(){
for(int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
for(int i=1;i<=lg[n];i++)
for(int j=1;j+(1<<i)-1<=n;j++)
st1[i][j]=get(st1[i-1][j],st1[i-1][j+(1<<i-1)]);
for(int i=1;i<=n;i++) st[0][i]=dep[i];
for(int i=1;i<=lg[n];i++)
for(int j=1;j+(1<<i)-1<=n;j++)
st[i][j]=max(st[i-1][j],st[i-1][j+(1<<i-1)]);
for(int i=1;i<n;i++) a[i]=dep[calc(i,i+1)];
// cout<<"dep:\n";
// for(int i=1;i<n;i++) cout<<a[i]<<" ";
// cout<<"\n";
s.push({0,-INT_MAX});
for(int i=1;i<n;i++){
while(!s.empty()&&s.top().second>=a[i]) s.pop();
lef[i]=s.top().first;s.push({i,a[i]});
}
while(!s.empty()) s.pop();
s.push({n,-INT_MAX});
for(int i=n-1;i>=1;i--){
while(!s.empty()&&s.top().second>=a[i]) s.pop();
rig[i]=s.top().first;s.push({i,a[i]});
}
// for(int i=1;i<=n;i++) cout<<lef[i]<<" "<<rig[i]<<"\n";
}
int calc2(int l,int r){
int x=lg[r-l+1];
return max(st[x][l],st[x][r-(1<<x)+1]);
}
void build(int bian,int l,int r){
tr[bian]=0;
if(l==r) return;
build(bian<<1,l,mid);
build(bian<<1|1,mid+1,r);
}
void update(int bian,int l,int r,int x,int y){
tr[bian]=max(tr[bian],y);
if(l==r) return;
if(x<=mid) update(bian<<1,l,mid,x,y);
else update(bian<<1|1,mid+1,r,x,y);
}
int query(int bian,int l,int r,int L,int R){
if(L<=l&&R>=r) return tr[bian];
int maxx=0;
if(L<=mid) maxx=max(maxx,query(bian<<1,l,mid,L,R));
if(R>mid) maxx=max(maxx,query(bian<<1|1,mid+1,r,L,R));
return maxx;
}
void solve2(){
for(int i=1;i<n;i++) mes[rig[i]-lef[i]-1].push_back({lef[i]+1,a[i]});
// cout<<"mes:\n";
// for(int i=1;i<n;i++)
// for(pair<int,int> j:mes[i]) cout<<j.first<
for(int i=1;i<=q;i++) ask[i].r--,ask[i].k--;
sort(ask+1,ask+1+q,cmp);ask[0].k=n;
for(int i=1;i<=q;i++){
if(ask[i].k==0){ans[ask[i].id]=calc2(ask[i].l,ask[i].r+1);continue;}
for(int j=ask[i-1].k-1;j>=ask[i].k;j--)
for(pair<int,int> k:mes[j])
update(1,1,n,k.first,k.second);
ans[ask[i].id]=query(1,1,n,ask[i].l,ask[i].r-ask[i].k+1);
}//求出 l>=ql 的情况
for(int i=1;i<=n;i++) while(!mes[i].empty()) mes[i].pop_back();
for(int i=1;i<=n;i++) mes[lef[i]+1].push_back({rig[i]-1,a[i]});
sort(ask+1,ask+1+q,cmp1);build(1,1,n);ask[0].l=1;
for(int i=1;i<=q;i++){
for(int j=ask[i-1].l;j<ask[i].l;j++)
for(pair<int,int> k:mes[j])
update(1,1,n,k.first,k.second);
if(ask[i].k==0) continue;
ans[ask[i].id]=max(ans[ask[i].id],query(1,1,n,ask[i].l+ask[i].k-1,n));
}
for(int i=1;i<=q;i++) cout<<ans[i]<<"\n";
}
void input(){
n=read();
for(int i=1;i<n;i++){
int u=read(),v=read();
add(u,v);add(v,u);
}
bool flag1=false,flag2=true;
int sum1=0,sum2=0;
for(int i=1;i<=n;i++)
if(ru[i]==1) sum1++;
else if(ru[i]==2) sum2++;
if(sum1==2&&sum2==n-2) flag1=true;
q=read();
for(int i=1;i<=q;i++){
ask[i].l=read(),ask[i].r=read(),ask[i].k=read(),ask[i].id=i;
if(ask[i].k!=ask[i].r-ask[i].l+1) flag2=false;
}
dfs(1,0);init1();
solve2();
// if(n<=5000||flag2) solve1();
// else if(flag1) solve2();
}
signed main(){
input();
return 0;
}

浙公网安备 33010602011771号