[题解]P8935 [JRKSJ R7] 茎

思路

首先思考 \(x = 1\) 的做法。有一个很简单的 dp,定义 \(dp_{u,i}\) 表示在 \(u\) 子树中操作了 \(i\) 次的方案数,需要注意的是这里并不强制要把 \(u\) 子树全部切掉。

考虑转移。因为 \(u\) 被切掉过后其任何子树内都将无法进行操作,所以如果要操作 \(u\) 则这一步必定是最后一步。于是先合并子树信息,有:

\[dp'_{u,i + j} \leftarrow \binom{i + j}{j}dp_{u,i} \times dp_{v,j} \]

然后再考虑 \(u\) 是否被操作,有:

\[dp'_{u,i} = dp_{u,i - 1} + dp_{u,i} \]

接下来考虑 \(x \neq 1\) 的做法。注意到从 \(1 \leadsto x\) 路径上的任意选一个点操作都会让 \(x\) 被切掉,不妨把这条链单独拎出来进行 dp。定义 \(f_{u,i}\) 表示在 \(1 \leadsto x\) 这条路径上,只选择在 \(1 \leadsto u\) 上的点上进行操作,并将整棵树切完,同时切下 \(u\) 之前操作了 \(i\) 步的方案数。考虑 \((u,v)\) 边的转移:

  • \(v\) 不在 \(1 \leadsto x\) 路径上,枚举 \(v\) 子树操作了 \(j\) 次,有:\(f'_{u,i} \leftarrow \binom{i}{j}dp_{v,j} \times f_{v,i - j}\)
  • \(v\)\(1 \leadsto x\) 路径上,讨论操不操作 \(v\)(为了方便转移这里钦定操作 \(v\) 算在切下 \(u\) 之后):
    • 不操作:\(f'_{v,i} \leftarrow f_{u,i}\)
    • 操作:因为操作 \(v\) 之前 \(u\) 一定没被操作,因此可以直接转移所有 \(j \geq i\)\(f_{u,j}\),有 \(f'_{v,i} \leftarrow \sum_{j \geq i} f_{u,j}\)

因为第 \(k\) 次操作必须操作 \(u\),所以在转移的时候不能转移不选择 \(u\) 的情况。转移用后缀和优化可以做到 \(\Theta(n^2)\)

Code

#include <bits/stdc++.h>
#define re register
#define int long long
#define Add(a,b) (((a) + (b)) % mod)
#define Mul(a,b) ((a) * (b) % mod)
#define chAdd(a,b) (a = Add(a,b))
#define chMul(a,b) (a = Mul(a,b))

using namespace std;

const int N = 510;
const int mod = 1e9 + 7;
int n,k,p;
int fp[N],sz[N],son[N],tmp[N];
int fac[N],infac[N],dp[N][N],f[N][N];
vector<int> g[N];

inline int read(){
    int r = 0,w = 1;
    char c = getchar();
    while (c < '0' || c > '9'){
        if (c == '-') w = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        r = (r << 3) + (r << 1) + (c ^ 48);
        c = getchar();
    }
    return r * w;
}

inline int qmi(int a,int b){
    int res = 1;
    while (b){
        if (b & 1) chMul(res,a);
        chMul(a,a); b >>= 1;
    } return res;
}

inline void init(int n){
    fac[0] = 1;
    for (re int i = 1;i <= n;i++) fac[i] = Mul(fac[i - 1],i);
    infac[n] = qmi(fac[n],mod - 2);
    for (re int i = n - 1;~i;i--) infac[i] = Mul(infac[i + 1],i + 1);
}

inline int C(int n,int m){
    if (n < m) return 0;
    else return Mul(fac[n],Mul(infac[m],infac[n - m]));
}

inline void dfs(int u,int fa){
    fp[u] = fa,dp[u][0] = 1;
    if (u == p) son[fa] = u;
    for (int v:g[u]){
        if (v == fa) continue;
        dfs(v,u);
        if (son[v]) son[u] = v;
        for (re int i = 0;i <= sz[u] + sz[v];i++) tmp[i] = 0;
        for (re int i = 0;i <= sz[u];i++){
            for (re int j = 0;j <= sz[v];j++) chAdd(tmp[i + j],Mul(C(i + j,j),Mul(dp[u][i],dp[v][j])));
        } sz[u] += sz[v];
        for (re int i = 0;i <= sz[u];i++) dp[u][i] = tmp[i];
    } sz[u]++;
    for (re int i = sz[u];i;i--) chAdd(dp[u][i],dp[u][i - 1]);
}

signed main(){
    n = read(),k = read(),p = read();
    init(n);
    for (re int i = 1,a,b;i < n;i++){
        a = read(),b = read();
        g[a].push_back(b);
        g[b].push_back(a);
    } dfs(1,0);
    int u = 1,num = 0;
    f[1][0] = 1;
    while (u){
        num++;
        for (re int i = 0;i <= num + 2;i++) tmp[i] = 0;
        for (re int i = num;~i;i--) tmp[i] = Add(tmp[i + 1],f[fp[u]][i]);
        if (u != 1){
            for (re int i = 0;i <= num;i++){
                f[u][i] = tmp[i];
                if (u != p) chAdd(f[u][i],f[fp[u]][i]);
            }
        }
        for (int v:g[u]){
            if (v == fp[u] || v == son[u]) continue;
            for (re int i = 0;i <= num + sz[v];i++) tmp[i] = 0;
            for (re int i = 0;i <= num;i++){
                for (re int j = 0;j <= sz[v];j++) chAdd(tmp[i + j],Mul(C(i + j,j),Mul(dp[v][j],f[u][i])));
            } num += sz[v];
            for (re int i = 0;i <= num;i++) f[u][i] = tmp[i];
        } u = son[u];
    } printf("%lld",f[p][k - 1]);
    return 0;
}
posted @ 2026-01-22 10:55  WBIKPS  阅读(0)  评论(0)    收藏  举报