题解:AT_agc058_f [AGC058F] Authentic Tree DP

神题。

题意:给出一棵树,定义树 \(T\) 的权值为 \(f(T)\),其满足:

  • 对于 \(|T|=1\)\(f(T)=1\)

  • 否则,考虑其中每一条边 \((x,y)\),记断掉这条边后两棵树分别为 \(T_x,T_y\),这里不区分顺序,\(f(T) = \sum\limits_{(x,y)}\frac{f(T_x)f(T_y)}{|T|}\)

\(f(T)\)\(n\le 5000\)

做法:

感觉下手完全没啥想法,但是我们观察到一个事情,如果把代价函数里的 \(|T|\) 换成 \(|T|-1\) 那么答案就是 \(1\),因为如果你把边视为点,按照这个选边顺序写成一棵树的结构,发现就是 \(\prod \frac 1{sz}\),也就是这样一棵树的拓扑序个数,那么对于所有的边就是 \(1\) 了。

上面这个东西其实跟做法没啥特别大的关系,但是这告诉我们一点启发,就是我们不能对着这个树硬做,一个比较套路的就是类似上面的,我们给树撒一个排列,然后满足性质的排列的概率这样的。

我们考虑一个事情,如果我们给边加一个点,对于这样的所有点撒一个排列统计合法的,是不是就可以了。我们考虑从最大值删删删,那么我们就要求每次都删在边上而不在点上,一个点不能比他的邻边先删。但是这样会影响我们的子树大小,原本是 \(\frac1{sz}\) 变成 \(\frac 1{2sz-1}\) 了。

如果我们现在代表边的这个点权为 \(0\),那么我们这样撒点就比较好了,但是直接令这个点的点权为 \(0\) 并不是很好理解,我们让这个边的点挂 \(P-1\) 个点在下面,\(P\) 是本题模数,我们惊奇的发现,这样从最大值考虑,第一个删掉他的概率是 \(\frac{1}{n+P(n-1)}=\frac 1 n\),达到了我们想要的效果。

那么我们就转化成了这样一个问题:

给出一颗树,给每个点新标号,要求某些点的新标号大于相邻点的概率是多少少。

这个就很经典了,我们会做森林的情况,所以我们考虑把 \(u\to fa\) 的边容斥成没有限制减去 \(fa\to u\) 的情况,每个位置乘上子树大小的逆元即可。

但是还有个问题,这样我们的子树大小很大,怎么解决?发现这 \(P-1\) 个点,其实是我们为了让边代表点补成权为 \(0\) 才开的,所以直接让边代表的点子树大小初始为 \(0\) 即可,其余跟正常背包没区别。复杂度 \(O(n^2)\)

代码:

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 10005, mod = 998244353;
vector<int> e[maxn];
int n, dp[maxn][maxn], sz[maxn], t[maxn], inv[maxn];
int qpow(int x, int k, int p) {
	int res = 1;
	while(k) {
		if(k & 1)
			res = res * x % p;
		x = x * x % p; k >>= 1;
	}
	return res;
}
void dfs(int u, int fa) {
	sz[u] = (u <= n);
	dp[u][sz[u]] = 1;
	for (int i = 0; i < e[u].size(); i++) {
		int v = e[u][i];
		if(v == fa)
			continue;
		dfs(v, u);
		for (int x = 0; x <= sz[u]; x++)
			t[x] = dp[u][x], dp[u][x] = 0;
		for (int x = 0; x <= sz[u]; x++)
			for (int y = 0; y <= sz[v]; y++) {
				if(u <= n) {
					dp[u][x + y] = (dp[u][x + y] - t[x] * dp[v][y] % mod + mod) % mod;
					dp[u][x] = (dp[u][x] + t[x] * dp[v][y] % mod) % mod;
				}
				else 
					dp[u][x + y] = (dp[u][x + y] + t[x] * dp[v][y] % mod + mod) % mod;
			}
		sz[u] += sz[v];
	}
	for (int i = 0; i <= sz[u]; i++)
		dp[u][i] = dp[u][i] * inv[i] % mod;
}
signed main() {
	cin >> n;
	for (int i = 1; i < n; i++) {
		int x, y; cin >> x >> y;
		e[x].push_back(i + n);
		e[y].push_back(i + n);
		e[i + n].push_back(x);
		e[i + n].push_back(y);
	}
	inv[0] = inv[1] = 1;
	for (int i = 2; i <= 2 * n; i++)
		inv[i] = (mod - mod / i) * inv[mod % i] % mod;
	dfs(1, 0);
	int ans = 0;
	for (int i = 0; i <= sz[1]; i++)
		ans = (ans + dp[1][i]) % mod;
	cout << ans << endl;
	return 0;
}
posted @ 2026-01-23 20:48  LUlululu1616  阅读(5)  评论(1)    收藏  举报