Loading

题解:CCPC2023 秦皇岛 M. Inverted

题面

题意

给定一棵 \(n\) 个节点的树,每次选择一个节点复制,每次操作后求生成树个数。(\(n \le 5000\)

做法

首先我们考虑一个树操作过后会变成什么。

这是原图:

进行第 \(1\) 次操作后:

同理,进行前 \(3\) 次操作后:

我们可以发现,操作后即为将原树复制一遍,黑边为原树上的边(及复制产生的),红边为新产生的。

我们将红边的定义变为连接两图的边,红点为连接两图的点(即未被操作过的点),如下图:

这样稍微重新排布一下,我们可以发现操作后的图是完全对称的!

我当时的思路:这实际上是一个分层图,我们在空间中来看,原图中的黑点是第一层,红点是第二层,复制后的黑点是第三层,第一层和第三层通过若干的第二层红色的边和点相连。整个图是完全上下关于第二层对称的。

现在,我们要删去若干条边,使得其中没有环,即该图的一个生成树。

为了思考方便,我们先假设复制出的图联通。

我们先考虑最简单的情况,只删红边:

  • 对于只删红边的情况,上下有两个黑色的联通块,中间有且仅有一条红边将两图连接在一起。多了会成环,少了不联通。

现在思考如何删去黑边:我们假设删去第三层(即复制的图)中的黑边。

  • 删去一条黑边后,图断为两个联通块,故这两个联通块需要分别有且仅有一条红边与下面的图相连

删去更多的黑边同理,复制的图不联通的情况也理解成若干个联通块。

考虑完如何删边,接下来我们想如何计数。

在之前我们把图看成三层立体时,上下是完全关于中间对称的,于是我们可以想到把三层图合并成一个,这样方便我们后续处理。合并时保留原先红点黑点红边黑边的定义,合并后如图:

我们将上述删边的条件整合转换到新的合并图上,得到删边的总条件:

  • 删去若干条红/黑边,使得每一个黑点构成的联通块都有且仅有通过一条红边与其他的红点相连

问题即为求这样删边的总方案数。

需要注意的是,我们是把一个多层图“压”成的一个二维图,每条合并图上的边实际上都代表上下两条对称的边,故删去一条黑边只是删掉其中一层图中的一条黑边,删去一条红边也只是断开一层图与红点的连接,删去一条边有删掉上下两层图的两种情况,需要将方案数乘 \(2\)

现在考虑 DP。

定义 \(dp_{u, 0/1}\)\(u\) 号点所在的联通块是否有红边相连的断边的方案数。考虑使用合并子树的方法转移:

  • \(dp_{u, 0} = (dp_{u, 0} \times dp_{v, 0} + 2 \times dp_{u, 0} \times dp_{v, 1})\)

  • \(dp_{u, 1} = (dp_{u, 1} \times dp_{v, 0} + dp_{u, 0} \times dp_{v, 1} + 2 \times dp_{u, 1} \times dp_{v, 1})\)

需要注意一个细节,DP 的边界条件:\(u\) 节点是红点时,将其初始化为 \(dp_{u, 0} = 0, dp_{u, 1} = 1\) ,可以理解为其无没有与红点相连的方案数与有 \(1\) 个与一个红点相连的方案数。

最后枚举每个联通块,分别 DP,再把其 \(dp_{u, 1}\) 的值加起来即可。

别忘了高精度和取模。

代码部分:

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int mod = 998244353;
int n, dp[5010][2], ans;
bool blk[5010], vis[5010];
vector<int> e[5010];
void dfs(int u, int fa = 0) {
  if (!blk[u]) return;
  vis[u] = 1;
  dp[u][0] = 1;
  int cnt = 0, s = 0;
  for (int v : e[u]) {
    if (v == fa) continue;
    dfs(v, u);
    if (!blk[v]) cnt++;
  }
  for (int v : e[u]) {
    if (v == fa) continue;
    if (!blk[v]) dp[v][1] = 1;
    dp[u][1] = (dp[u][1]*dp[v][0] + dp[u][0]*dp[v][1] + 2ll*dp[u][1]*dp[v][1]) % mod;
    dp[u][0] = (dp[u][0]*dp[v][0] + 2ll*dp[u][0]*dp[v][1]) % mod;
  }
  return;
}
signed main() {
  cin >> n;
  for (int i = 1; i < n; i++) {
    int u, v;
    cin >> u >> v;
    e[u].push_back(v);
    e[v].push_back(u);
  }
  for (int var = 1; var < n; var++) {
    int x;
    cin >> x;
    ans = 1;
    blk[x] = 1;
    for (int i = 1; i <= n; i++)
      vis[i] = 0, dp[i][0] = dp[i][1] = 0;
    for (int i = 1; i <= n; i++) {
      if (blk[i] && !vis[i]) {
        dfs(i);
        (ans *= dp[i][1]) %= mod;
      }
    }
    cout << ans << endl;
  }
  return 0;
}
posted @ 2025-10-07 15:05  UserJCY  阅读(20)  评论(0)    收藏  举报
Title