【NOI P模拟赛】最短路(树形DP,树的直径)

题面

给定一棵 n n n 个结点的无根树,每条边的边权均为 1 1 1

树上标记有 m m m 个互不相同的关键点,小 A \tt A A 会在这 m m m 个点中等概率随机地选择 k k k 个不同的点放上小饼干。你想知道,经过有小饼干的 k k k 个点的最短路径长度的期望是多少。注意,你可以任意选取起点和终点,路径也可以经过重复的点或重复的边。

2 ≤ k ≤ m ≤ n ≤ 2000 2\leq k\leq m\leq n\leq2000 2kmn2000

m ≤ 300 m\leq 300 m300 ,但是没必要。

题解

我们在脑海中建出这 k k k 个点构成的虚树,任意地改变树的形态,会发现,起点和终点一定是某条直径的两端。我们令别的点的 d f s \rm dfs dfs 序都排在终点前面,然后根据 d f s \rm dfs dfs 序访问结点。这样,除了直径上的边只经过一次,虚树上别的边都经过了两次。所以“最短路径”等于 边权和×2直径长度

因此,最短路径长度的期望可以分开来求,先求边权和的期望,再减去直径长度的期望。

边权和的期望很简单,可以树上背包DP,也可以枚举每条边的贡献,前者 O ( n 2 ) O(n^2) O(n2) ,后者 O ( n ) O(n) O(n)

考虑到总方案数是已知的,我们求所有方案下直径长度的和,可得直径长度的期望。

我们知道,边权为 1 的树的直径有些特点:

  • 若直径长度为偶数,则所有直径的中心点一定相同。
  • 若直径长度为奇数,则所有直径的中心边一定相同。

因此,我们枚举中心点/边,再枚举直径长度计算,以中心点为例:

令中心点为 x x x ,我们从 x x x 的每个儿子出发,预处理出每个儿子为根的子树中深度为 i ( i ≥ 1 ) i(i\geq 1) i(i1) 的关键点个数 c n t [ i ] cnt[i] cnt[i] ,每个子树分开来。 c n t [ 0 ] cnt[0] cnt[0] 表示 x x x 是否为关键点。

若要保证直径长度为 2 l 2l 2l ,则所有 k k k 个放上小饼干的点的深度都得小于等于 l l l ,同时深度等于 l l l 的点至少能选出两个属于不同的子树。我们可以统计所有深度小于 l l l 的点数 c n t p cntp cntp (包括 c n t [ 0 ] cnt[0] cnt[0]),第 i i i 个子树深度等于 l l l 的点数为 c n t i [ l ] cnt_i[l] cnti[l] ,用一点小容斥,总数减去 l l l 层无点方案再减去 l l l 层只有一个子树有点方案。那么直径长度为 2 l 2l 2l 的方案数为
( c n t p + ∑ c n t i [ l ] k ) − ( c n t p k ) − ∑ ( ( c n t p + c n t i [ l ] k ) − ( c n t p k ) ) {cntp+\sum cnt_i[l]\choose k}-{cntp\choose k}-\sum\left({cntp+cnt_i[l]\choose k}-{cntp\choose k}\right) (kcntp+cnti[l])(kcntp)((kcntp+cnti[l])(kcntp))

中心边的算法类似,相当于只有两个子树,还要简单许多。

时间复杂度 O ( n 2 ) O(n^2) O(n2)

CODE

#include<set>
#include<map>
#include<cmath>
#include<stack>
#include<random>
#include<vector>
#include<bitset>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 2005
#define LL long long
#define ULL unsigned long long
#define DB double
#define lowbit(x) (-(x) & (x))
#define ENDL putchar('\n')
#define FI first
#define SE second
LL read() {
    LL f=1,x=0;int s = getchar(); 
    while(s < '0' || s > '9') {if(s<0)return -1;if(s=='-')f=-f;s = getchar();}
    while(s >= '0' && s <= '9') {x = (x<<3) + (x<<1) + (s^48); s = getchar();}
    return f*x;
}
void putpos(LL x) {if(!x)return ;putpos(x/10);putchar('0'+(x%10));}
void putnum(LL x) {
    if(!x) {putchar('0');return ;}
    if(x<0) {putchar('-');x = -x;}
    return putpos(x);
}
void AIput(LL x,int c) {putnum(x);putchar(c);}

const int MOD = 998244353;
int n,m,s,o,k;
int U[MAXN],V[MAXN];
int fac[MAXN],inv[MAXN],invf[MAXN];
int C(int n,int m) {
	if(m < 0 || m > n) return 0;
	return fac[n] *1ll* invf[n-m] % MOD * invf[m] % MOD;
}
int invC(int n,int m) {
	if(m < 0 || m > n) return 1;
	return invf[n] *1ll* fac[n-m] % MOD * fac[m] % MOD;
}
int hd[MAXN],v[MAXN<<1],nx[MAXN<<1],cne;
void ins(int x,int y) {
	nx[++ cne] = hd[x]; v[cne] = y; hd[x] = cne;
}
int ans = 0;
bool f[MAXN];
int dp[MAXN][MAXN],cc[MAXN][MAXN],sz[MAXN];
void dfs0(int x,int ff) {
	cc[x][0] = 1;
	sz[x] = f[x];
	if(f[x]) cc[x][1] = 1;
	int dpp = 0;
	for(int i = hd[x];i;i = nx[i]) {
		int y = v[i];
		if(y != ff) {
			dfs0(y,x);
			sz[x] += sz[y];
			(dpp += MOD-(dp[y][k]+cc[y][k])%MOD) %= MOD;
			for(int j = sz[x];j > 0;j --) {
				for(int s = 1;s <= sz[y] && s <= j;s ++) {
					(dp[x][j] += (dp[x][j-s]*1ll*cc[y][s] + dp[y][s]*1ll*cc[x][j-s] + cc[x][j-s]*1ll*cc[y][s]) % MOD) %= MOD;
					(cc[x][j] += cc[x][j-s]*1ll*cc[y][s] % MOD) %= MOD;
				}
			}
		}
	}
	(dpp += dp[x][k]) %= MOD;
	(ans += dpp) %= MOD;
	return ;
}
vector<int> bu[MAXN];
int ct[MAXN],mx;
void dfs(int x,int ff,int d) {
	if(f[x]) {
		mx = max(mx,d);
		ct[d] ++;
	}
	for(int i = hd[x];i;i = nx[i]) {
		int y = v[i];
		if(y != ff) {
			dfs(y,x,d+1);
		}
	}return ;
}
int main() {
	freopen("tree.in","r",stdin);
	freopen("tree.out","w",stdout);
	n = read(); m = read(); k = read();
	fac[0]=fac[1]=inv[0]=inv[1]=invf[0]=invf[1]=1;
	for(int i = 2;i <= n;i ++) {
		fac[i] = fac[i-1] *1ll* i % MOD;
		inv[i] = (MOD - inv[MOD%i]) *1ll* (MOD/i) % MOD;
		invf[i] = invf[i-1] *1ll* inv[i] % MOD;
	}
	for(int i = 1;i <= m;i ++) {
		f[read()] = 1;
	}
	for(int i = 1;i < n;i ++) {
		s = read();o = read();
		U[i] = s; V[i] = o;
		ins(s,o); ins(o,s);
	}
	dfs0(1,0);
	ans = ans*2ll % MOD;
	for(int i = 1;i < n;i ++) {
		s = U[i];o = V[i];
		for(int j = 0;j <= n;j ++) bu[j].clear();
		dfs(s,o,0);
		for(int j = 0;j <= n;j ++) {
			bu[j].push_back(ct[j]);
			ct[j] = 0;
		}
		dfs(o,s,0);
		for(int j = 0;j <= n;j ++) {
			bu[j].push_back(ct[j]);
			ct[j] = 0;
		}
		int cnt = 0;
		for(int j = 0;j <= n;j ++) {
			int le = (j<<1|1);
			int A = bu[j][0],B = bu[j][1];
			int nm = (0ll+ C(cnt+A+B,k) +MOD- C(cnt+A,k) +MOD- C(cnt+B,k) + C(cnt,k)) % MOD;
			(ans += MOD-nm*1ll*le%MOD) %= MOD;
			cnt += A+B;
		}
	}
	for(int i = 1;i <= n;i ++) {
		for(int j = 0;j <= n;j ++) bu[j].clear();
		for(int y = hd[i];y;y = nx[y]) {
			mx = 0;
			dfs(v[y],i,1);
			for(int j = 1;j <= mx;j ++) {
				if(ct[j]) {
					bu[j].push_back(ct[j]);
					ct[j] = 0;
				}
			}
		}
		int cnt = f[i];
		for(int j = 1;j <= n;j ++) {
			int le = (j<<1),dt = 0;
			for(int s = 0;s < (int)bu[j].size();s ++) {
				dt += bu[j][s];
			}
			int O = C(cnt,k);
			int nm = (C(cnt+dt,k) +MOD- O) % MOD;
			for(int s = 0;s < (int)bu[j].size();s ++) {
				(nm += MOD- (C(cnt+bu[j][s],k)+MOD-O)%MOD) %= MOD;
			}
			(ans += MOD-nm*1ll*le%MOD) %= MOD;
			cnt += dt;
		}
	}
	ans = ans *1ll* invC(m,k) % MOD;
	AIput(ans,'\n');
	return 0;
}

赛时样例太水了,导致我 c n t p cntp cntp 未更新的错误没被发现,只有 20 分。
在这里插入图片描述

posted @ 2021-11-02 16:31  DD_XYX  阅读(48)  评论(0编辑  收藏  举报