CF1097G Vladislav and a Great Legend(斯特林数+树上背包)
题目:洛谷CF1097G、CF1097G
题目描述:
给了一棵\(n\)个点的树,对于树上的一个点集\(X\),令\(f(X)\)为点集\(X\)所形成的最小联通子树的边数,求这棵树所有\(X\)的\((f(X))^k\)之和
\(n \leq 10^5\),\(k \leq 200\)
蒟蒻题解:
看到\(k\)次方,我们考虑斯特林数:
\[\sum_{X} (f(X))^k = \sum_{X} \sum_{i=0}^{k} \begin{Bmatrix}k\\i\end{Bmatrix} (f(X))^{\underline{i}} = \sum_{X} \sum_{i=0}^{k} \begin{Bmatrix}k\\i\end{Bmatrix} i! \binom{f(x)}{i} = \sum_{i = 0}^k \begin{Bmatrix}k\\i\end{Bmatrix} i! \sum_{X} \binom{f(x)}{i}
\]
对于\(\begin{Bmatrix}k\\i\end{Bmatrix}\),我们可以直接\(k^2\)暴力求出来
观察到\(\binom{f(x)}{i}\)表示在最小联通子树的\(f(x)\)条边中任选\(i\)条的方案数,看起来可以用树上\(dp\)解决
设\(f_{x,i}\)表示在以\(x\)为根的子树内,所有的非空最小连通块选出其中\(i\)条边的方案数,令\(ans_i\)表示\(\sum_{X} \binom{f(x)}{i}\)
令当前的根节点为\(x\),\(u\)为\(x\)的一个儿子节点,要把\(u\)合并上来,\(f_{x,i}\)有以下几种情况:
- 合并前以\(x\)为根的子树
- 合并前以\(u\)为根的子树
- 将以\(x\)为根的子树和以\(u\)为根的子树合并起来,要注意考虑是否把\(x\)->\(u\)这条边合并上来的情况\((f_{x,i}+=f_{x,j}f_{u,i-j}+f_{x,j}f_{u,i-j-1})\),并把这个答案加进\(ans_i\)里
打完发现样例都过不了,想想漏了什么情况
由于一棵最小联通子树的根节点不一定要选,所以要把没有连到根的部分也传递上去,所以第\(2\)中情况应该改为:合并前以\(u\)为根的子树,并加上选\(x\)->\(u\)这条边(这部分不能加进\(ans_i\)内)
这样就可以用树上背包解决了,树上背包的时间复杂度是\(\mathcal O(nk)\)的
总的时间复杂度\(\mathcal O(m^2 + nk)\)
参考程序:
#include<bits/stdc++.h>
using namespace std;
#define Re register int
typedef long long ll;
const int N = 100005, M = 205, p = 1e9 + 7;
int n, k, cnt, s, hea[N], nxt[N << 1], to[N << 1], S[M][M], siz[N], f[N][M], g[M], ans[M];
inline int read()
{
char c = getchar();
int ans = 0;
while (c < 48 || c > 57) c = getchar();
while (c >= 48 && c <= 57) ans = (ans << 3) + (ans << 1) + (c ^ 48), c = getchar();
return ans;
}
inline void add(int x, int y)
{
nxt[++cnt] = hea[x], to[cnt] = y, hea[x] = cnt;
}
inline int inc(int x, int y)
{
x += y;
return x < p ? x : x - p;
}
inline void dfs(int x, int fa)
{
siz[x] = f[x][0] = 1;
for (Re i = hea[x]; i; i = nxt[i])
{
int u = to[i];
if (u == fa) continue;
dfs(u, x);
g[0] = inc(f[x][0], f[u][0]);
for (Re j = 1; j <= k && j < siz[x] + siz[u]; ++j) g[j] = inc(f[x][j], inc(f[u][j], f[u][j - 1]));
for (Re j = 0; j < siz[x] && j <= k; ++j)
if (f[x][j])
for (Re t = 0; t < siz[u] && j + t <= k; ++t)
if (f[u][t])
{
int v = 1ll * f[x][j] * f[u][t] % p;
g[j + t] = inc(g[j + t], v), ans[j + t] = inc(ans[j + t], v);
g[j + t + 1] = inc(g[j + t + 1], v), ans[j + t + 1] = inc(ans[j + t + 1], v);
}
siz[x] += siz[u];
for (Re j = 0; j <= k && j < siz[x]; ++j) f[x][j] = g[j];
}
}
int main()
{
n = read(), k = read(), S[0][0] = 1;
for (Re i = 1; i < n; ++i)
{
int u = read(), v = read();
add(u, v), add(v, u);
}
for (Re i = 1; i <= k; ++i)
for (Re j = 1; j <= i; ++j) S[i][j] = (1ll * S[i - 1][j] * j + S[i - 1][j - 1]) % p;
dfs(1, 0);
for (Re i = 1, j = 1; i <= k; j = 1ll * j * (++i) % p) s = (1ll * S[k][i] * j % p * ans[i] + s) % p;
printf("%d", s);
return 0;
}