[题解]P8935 [JRKSJ R7] 茎
思路
首先思考 \(x = 1\) 的做法。有一个很简单的 dp,定义 \(dp_{u,i}\) 表示在 \(u\) 子树中操作了 \(i\) 次的方案数,需要注意的是这里并不强制要把 \(u\) 子树全部切掉。
考虑转移。因为 \(u\) 被切掉过后其任何子树内都将无法进行操作,所以如果要操作 \(u\) 则这一步必定是最后一步。于是先合并子树信息,有:
\[dp'_{u,i + j} \leftarrow \binom{i + j}{j}dp_{u,i} \times dp_{v,j}
\]
然后再考虑 \(u\) 是否被操作,有:
\[dp'_{u,i} = dp_{u,i - 1} + dp_{u,i}
\]
接下来考虑 \(x \neq 1\) 的做法。注意到从 \(1 \leadsto x\) 路径上的任意选一个点操作都会让 \(x\) 被切掉,不妨把这条链单独拎出来进行 dp。定义 \(f_{u,i}\) 表示在 \(1 \leadsto x\) 这条路径上,只选择在 \(1 \leadsto u\) 上的点上进行操作,并将整棵树切完,同时切下 \(u\) 之前操作了 \(i\) 步的方案数。考虑 \((u,v)\) 边的转移:
- \(v\) 不在 \(1 \leadsto x\) 路径上,枚举 \(v\) 子树操作了 \(j\) 次,有:\(f'_{u,i} \leftarrow \binom{i}{j}dp_{v,j} \times f_{v,i - j}\)。
- \(v\) 在 \(1 \leadsto x\) 路径上,讨论操不操作 \(v\)(为了方便转移这里钦定操作 \(v\) 算在切下 \(u\) 之后):
- 不操作:\(f'_{v,i} \leftarrow f_{u,i}\)。
- 操作:因为操作 \(v\) 之前 \(u\) 一定没被操作,因此可以直接转移所有 \(j \geq i\) 的 \(f_{u,j}\),有 \(f'_{v,i} \leftarrow \sum_{j \geq i} f_{u,j}\)。
因为第 \(k\) 次操作必须操作 \(u\),所以在转移的时候不能转移不选择 \(u\) 的情况。转移用后缀和优化可以做到 \(\Theta(n^2)\)。
Code
#include <bits/stdc++.h>
#define re register
#define int long long
#define Add(a,b) (((a) + (b)) % mod)
#define Mul(a,b) ((a) * (b) % mod)
#define chAdd(a,b) (a = Add(a,b))
#define chMul(a,b) (a = Mul(a,b))
using namespace std;
const int N = 510;
const int mod = 1e9 + 7;
int n,k,p;
int fp[N],sz[N],son[N],tmp[N];
int fac[N],infac[N],dp[N][N],f[N][N];
vector<int> g[N];
inline int read(){
int r = 0,w = 1;
char c = getchar();
while (c < '0' || c > '9'){
if (c == '-') w = -1;
c = getchar();
}
while (c >= '0' && c <= '9'){
r = (r << 3) + (r << 1) + (c ^ 48);
c = getchar();
}
return r * w;
}
inline int qmi(int a,int b){
int res = 1;
while (b){
if (b & 1) chMul(res,a);
chMul(a,a); b >>= 1;
} return res;
}
inline void init(int n){
fac[0] = 1;
for (re int i = 1;i <= n;i++) fac[i] = Mul(fac[i - 1],i);
infac[n] = qmi(fac[n],mod - 2);
for (re int i = n - 1;~i;i--) infac[i] = Mul(infac[i + 1],i + 1);
}
inline int C(int n,int m){
if (n < m) return 0;
else return Mul(fac[n],Mul(infac[m],infac[n - m]));
}
inline void dfs(int u,int fa){
fp[u] = fa,dp[u][0] = 1;
if (u == p) son[fa] = u;
for (int v:g[u]){
if (v == fa) continue;
dfs(v,u);
if (son[v]) son[u] = v;
for (re int i = 0;i <= sz[u] + sz[v];i++) tmp[i] = 0;
for (re int i = 0;i <= sz[u];i++){
for (re int j = 0;j <= sz[v];j++) chAdd(tmp[i + j],Mul(C(i + j,j),Mul(dp[u][i],dp[v][j])));
} sz[u] += sz[v];
for (re int i = 0;i <= sz[u];i++) dp[u][i] = tmp[i];
} sz[u]++;
for (re int i = sz[u];i;i--) chAdd(dp[u][i],dp[u][i - 1]);
}
signed main(){
n = read(),k = read(),p = read();
init(n);
for (re int i = 1,a,b;i < n;i++){
a = read(),b = read();
g[a].push_back(b);
g[b].push_back(a);
} dfs(1,0);
int u = 1,num = 0;
f[1][0] = 1;
while (u){
num++;
for (re int i = 0;i <= num + 2;i++) tmp[i] = 0;
for (re int i = num;~i;i--) tmp[i] = Add(tmp[i + 1],f[fp[u]][i]);
if (u != 1){
for (re int i = 0;i <= num;i++){
f[u][i] = tmp[i];
if (u != p) chAdd(f[u][i],f[fp[u]][i]);
}
}
for (int v:g[u]){
if (v == fp[u] || v == son[u]) continue;
for (re int i = 0;i <= num + sz[v];i++) tmp[i] = 0;
for (re int i = 0;i <= num;i++){
for (re int j = 0;j <= sz[v];j++) chAdd(tmp[i + j],Mul(C(i + j,j),Mul(dp[v][j],f[u][i])));
} num += sz[v];
for (re int i = 0;i <= num;i++) f[u][i] = tmp[i];
} u = son[u];
} printf("%lld",f[p][k - 1]);
return 0;
}

浙公网安备 33010602011771号