目录
题目
解法
\(\text{Sol 1}\):\(\rm dfs\) 序
用 \(0\) 号点将森林连成一棵树。问题可以转化成 "某个点有多少个 \(k\) 级子孙",可以以深度为下标将点以 \(\rm dfs\) 序塞进 \(\rm vector\) 里,查询时在对应深度的 \(\rm vector\) 里二分此点对应子树的 \(\rm dfs\) 序区间。
\(\text{Sol 2}\):\(\text{dsu on tree}\) + 线段树合并
类似 \(\text{CodeForces - 600E Lomsat gelral}\),直接在子树中修改答案即可,时间复杂度是一样的。
代码
\(\text{dsu on tree}\)
#include <cstdio>
#include <vector>
using namespace std;
const int N = 1e5 + 5;
int n, m, f[N][20], lg[N], siz[N], dep[N], sum[N], son[N], len[N], be[N];
bool vis[N];
vector <int> vec[N], ans[N], q[N];
int read() {
    int x = 0, f = 1; char s;
    while((s = getchar()) < '0' || s > '9') if(s == '-') f = -1;
    while(s >= '0' && s <= '9') {x = (x << 1) + (x << 3) + (s ^ 48); s = getchar();}
    return x * f;
}
void init() {
    for(int i = 1; i <= 18; ++ i)
        for(int j = 1; j <= n; ++ j)
            f[j][i] = f[f[j][i - 1]][i - 1];
    for(int i = 2; i <= n; ++ i) lg[i] = lg[i >> 1] + 1;
}
void dfs(const int u) {
    siz[u] = 1;
    for(int i = 0, sz = vec[u].size(); i < sz; ++ i) {
        int v = vec[u][i];
        dep[v] = dep[u] + 1;
        dfs(v);
        siz[u] += siz[v];
        if(siz[son[u]] < siz[v]) son[u] = v;
    }
}
void cal(const int u, const int k) {
    sum[dep[u]] += k;
    for(int i = 0, sz = vec[u].size(); i < sz; ++ i) {
        int v = vec[u][i];
        if(vis[v]) continue;
        cal(v, k);
    }
}
void solve(const int u, const int fa, const int keep) {
    for(int i = 0, sz = vec[u].size(); i < sz; ++ i) {
        int v = vec[u][i];
        if(v == son[u]) continue;
        solve(v, u, 0);
    }
    if(son[u]) solve(son[u], u, 1), vis[son[u]] = 1;
    cal(u, 1);
    if(son[u]) vis[son[u]] = 0;
    for(int i = 0, sz = q[u].size(); i < sz; ++ i) ans[u].push_back(sum[q[u][i]]);
    if(! keep) cal(u, -1);
}
int main() {
    int x, y, z;
    n = read();
    for(int i = 1; i <= n; ++ i) f[i][0] = read(), vec[f[i][0]].push_back(i);
    init(); dfs(0);
    m = read();
    for(int i = 1; i <= m; ++ i) {
        x = read(), y = read();
        if(y >= dep[x]) be[i] = 0; // 因为有超级源点的存在,所以正式的 dep 从一开始
        else {
            z = dep[x];
            while(y) x = f[x][lg[y]], y -= (1 << lg[y]);
            q[x].push_back(z); be[i] = x;
        }
    }
    solve(0, -1, 1);
    for(int i = 1; i <= m; ++ i)
        if(be[i]) printf("%d ", ans[be[i]][len[be[i]] ++] - 1); // 除去自己
        else printf("0 ");
	return 0;
}
线段树合并
#include <cstdio>
#define print(x,y) write(x),putchar(y)
template <class T>
inline T read(const T sample) {
	T x=0; char s; bool f=0;
	while((s=getchar())>'9' or s<'0')
		f|=(s=='-');
	while(s>='0' and s<='9')
		x=(x<<1)+(x<<3)+(s^48),
		s=getchar();
	return f?-x:x;
}
template <class T>
inline void write(const T x) {
	if(x<0) {
		putchar('-');
		write(-x);
		return;
	}
	if(x>9) write(x/10);
	putchar(x%10^48);
}
#include <vector>
using namespace std;
const int maxn=1e5+5;
int n,f[maxn][20],dep[maxn],idx;
int rt[maxn],ans[maxn];
struct node {
	int ls,rs,v;
} t[maxn*80];
vector <int> e[maxn],Q[maxn];
void init(int u) {
	for(int i=1;i<=18;++i)
		f[u][i]=f[f[u][i-1]][i-1];
	for(auto v:e[u]) {
		dep[v]=dep[u]+1;
		init(v);
	}
}
int jump(int x,int k) {
	int o=x;
	for(int i=18;i>=0;--i)
		if(dep[x]-dep[f[o][i]]<=k)
			o=f[o][i];
	return o;
}
void pushUp(int o) {
	t[o].v=t[t[o].ls].v+t[t[o].rs].v;
}
int merge(int x,int y,int l,int r) {
	if(!x or !y) return x|y;
	if(l==r) return t[x].v+=t[y].v,x;
	int mid=l+r>>1;
	t[x].ls=merge(t[x].ls,t[y].ls,l,mid);
	t[x].rs=merge(t[x].rs,t[y].rs,mid+1,r);
	pushUp(x);
	return x;
}
void ins(int &o,int l,int r,int p,int k) {
	if(!o) o=++idx;
	if(l==r) return t[o].v+=k,void();
	int mid=l+r>>1;
	if(p<=mid) ins(t[o].ls,l,mid,p,k);
	else ins(t[o].rs,mid+1,r,p,k);
	pushUp(o);
} 
int ask(int o,int l,int r,int p) {
	if(l==r) return t[o].v;
	int mid=l+r>>1;
	if(p<=mid) return ask(t[o].ls,l,mid,p);
	else return ask(t[o].rs,mid+1,r,p);
}
void dfs(int u) {
	for(auto v:e[u]) {
		dfs(v);
		rt[u]=merge(rt[u],rt[v],0,n);
	}
	ins(rt[u],0,n,dep[u],1);
	for(auto i:Q[u])
		if(u)
			ans[i]=ask(rt[u],0,n,ans[i])-1;
		else ans[i]=0;
}
int main() {
	n=read(9);
	for(int i=1;i<=n;++i)
		e[f[i][0]=read(9)].push_back(i);
	init(0);
	int q=read(9);
	for(int i=1;i<=q;++i) {
		int x,y;
		x=read(9),y=read(9);
		int fa=jump(x,y);
		Q[fa].push_back(i);
		ans[i]=y+dep[fa];
	}
	dfs(0);
	for(int i=1;i<=q;++i)
		print(ans[i],' ');
	puts("");
	return 0;
}
 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号