题解:[湖南集训] 更为厉害

题意分析

\(q\) 次询问,每次询问给出树上节点 \(a\) 和常数 \(k\),求满足 \(a,b\) 均为 \(c\) 的祖先且 \(\operatorname{dist}(a,b)\leq k\) 的方案数。

可以发现,\(c\) 一定在 \(a\) 的子树内;那么 \(b\) 相对于 \(a\) 的位置关系为:

  • \(b\)\(a\) 的祖先。

    \(c\) 可以在子树 \(a\) 内随便取(除了 \(a\)),答案即 \((\textit{size}_a-1)\min(\textit{depth}_a-1,k)\)

  • \(b\)\(a\) 的子树内。

    那么 \(c\)\(b\) 的子树内随便取(除了 \(b\)),答案即:

    \(B\) 为子树 \(a\) 内所有深度在 \(\left[\textit{depth}_a+1,\textit{depth}_a+k\right]\) 范围内的点的集合。

    则这一部分的答案为:

    \[\sum_{b\in B}(\textit{size}_b-1) \]

考虑快速求出 \(b\)\(a\) 的子树内的情况的答案,需要维护子树信息。

又因为需要区间查询深度范围内的点,考虑线段树,因此想到线段树维护点的贡献(\(\textit{size}_b-1\))和。对于子树信息,DFS 序前缀和即可。

那么给每一个点都开一个可持久化权值线段树存起来即可,时间轴为 DFS 序。

AC 代码

//#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<cstdio>
#include<string>
#include<vector>
#include<cmath>
#include<ctime>
#include<deque>
#include<queue>
#include<stack>
#include<list>
using namespace std;
typedef long long ll;
constexpr const int N=3e5;
int n,father[N+1],dfn[N+1],size[N+1],depth[N+1];
vector<int>g[N+1];
struct PersistentSegTree{
	int size,root[N+1];
	struct node{
		int l,r;
		int lChild,rChild;
		ll sum;
	}t[N*80+1];
	
	int create(node x){
		t[++size]=x;
		return size;
	}
	int clone(int p){
		t[++size]=t[p];
		return size;
	}
	int build(int l,int r){
		int p=create({l,r});
		if(l==r){
			return p;
		}
		int mid=l+r>>1;
		t[p].lChild=build(l,mid);
		t[p].rChild=build(mid+1,r);
		return p;
	}
	void up(int p){
		t[p].sum=t[t[p].lChild].sum+t[t[p].rChild].sum;
	}
	int add(int p,int x,int k){
		p=clone(p);
		if(t[p].l==t[p].r){
			t[p].sum+=k;
			return p;
		}
		if(x<=t[t[p].lChild].r){
			t[p].lChild=add(t[p].lChild,x,k);
		}else{
			t[p].rChild=add(t[p].rChild,x,k);
		}
		up(p);
		return p;
	}
	void add(int v,int i,int x,int k){
		root[i]=add(root[v],x,k);
	}
	ll query0(int p,int q,int l,int r){
		if(l<=t[p].l&&t[p].r<=r){
			return t[p].sum-t[q].sum;
		}
		ll ans=0;
		if(l<=t[t[p].lChild].r){
			ans=query0(t[p].lChild,t[q].lChild,l,r);
		}
		if(t[t[p].rChild].l<=r){
			ans+=query0(t[p].rChild,t[q].rChild,l,r);
		}
		return ans;
	}
	ll query(int v,int i,int l,int r){
		if(r<1||n<l){
			return 0;
		}
		return query0(root[i],root[v-1],l,r);
	}
}t;
void dfs(int x,int fx){
	father[x]=fx;
	depth[x]=depth[fx]+1;
	size[x]=1;
	static int cnt;
	dfn[x]=++cnt;
	for(int i:g[x]){
		if(i==fx){
			continue;
		}
		dfs(i,x);
		size[x]+=size[i];
	}
}
void dfs2(int x){
	t.add(dfn[x]-1,dfn[x],depth[x],size[x]-1);
	for(int i:g[x]){
		if(i==father[x]){
			continue;
		}
		dfs2(i);
	}
}
void pre(){
	t.root[0]=t.build(1,n);
	dfs(1,0);
	dfs2(1);
}
ll query(int x,int k){
	ll ans=(size[x]-1ll)*min(depth[x]-1,k);
	ans+=t.query(dfn[x],dfn[x]+size[x]-1,depth[x]+1,depth[x]+k);
	return ans;
}
int main(){
	/*freopen("test.in","r",stdin);
	freopen("test.out","w",stdout);*/
	
	ios::sync_with_stdio(false);
	cin.tie(0);cout.tie(0);
	
	int q;
	cin>>n>>q;
	for(int i=1;i<n;i++){
		int u,v;
		cin>>u>>v;
		g[u].push_back(v);
		g[v].push_back(u);
	}
	pre();
	while(q--){
		int p,k;
		cin>>p>>k;
		cout<<query(p,k)<<'\n';
	}
	
	cout.flush();
	
	/*fclose(stdin);
	fclose(stdout);*/
	return 0;
}
posted @ 2026-02-08 15:51  TH911  阅读(6)  评论(0)    收藏  举报