QOJ4811 Be Careful

考虑 \(dp_{i,v}\) 表示以 \(i\) 为根的子树,\(i\) 的权值为 \(v\) 的方案数。

但是这个 dp 不好转移,因为 MEX 不是很好表示。但是如果只考虑子节点中的非叶子节点,那么 MEX 是 \(O(\sqrt n)\) 的,这个可以接受。

考虑阈值分治,把 \(deg\le B\)\(deg>B\) 的分开考虑。特殊地,叶子节点也放在 \(deg>B\) 考虑。其中 \(deg\) 为子节点个数。

对于 \(deg\le B\) 的部分,因为其权值也 \(\le B\),考虑直接设 \(dp1_{i,msk}\) 表示考虑前 \(i\)\(deg\le B\) 的子树,每种权值出现状态为 \(msk\)。这一部分转移是 \(O(2^B\times \operatorname{poly}(n))\) 的。

对于 \(deg>B\) 的部分,显然只有 \(\le \dfrac{n}{B}\) 棵子树,叶子节点等价,用一维记录使用情况即可。即 \(dp2_{v,c,msk}\) 表示权值填满了 \(0\sim v\),有 \(c\) 种权值使用了叶子,\(deg>B\) 的子树使用状态为 \(msk\) 的方案数。

考虑这个的转移,不妨设填完 \(v+1\) 的子节点占用状态为 \(S\),填之前状态为 \(T\),那么贡献的系数就是 \(\prod\limits_{i\in S\backslash T}dp_{i,v+1}\)。这个可以用高维前缀和优化。注意这个不能直接用两个子集的积除,因为这样分母可能为 \(0\)。一个解决方法是,对第 \(i\) 维前缀和时,把 \(dp_{i,v+1}\) 乘上。\(c\) 能转移到 \(c\)\(c+1\),倒着枚举 \(c\),直接 \(dp2_{c+1,msk}\leftarrow dp2_{c,msk}\) 即可。

注意到第二个 dp 需要知道 \(deg\le B\) 的子树的权值状态,所以需要枚举第一个 dp 的 \(msk\),对每个 \(msk\) 跑一遍第二个 dp。同时需要减掉 \(v\) 中没有选子树或叶子 且 第一个 dp 也没有占用 \(v\) 的情况,这个就是 \(dp2_v\) 减去对应位置 \(dp2_{v-1}\) 的值。

考虑怎么求答案,不妨求钦定 \(0\sim v\) 都出现过的方案数,最后再差分。枚举第二个 dp 的 \(c,msk\),对于叶子节点,可以枚举已经被选中的个数,再分配到 \(c\) 类中,这个可以用斯特林数算。其它的叶子每个的方案是 \(n-v\),乘起来即可。对于剩余的 \(deg>B\) 的子树,权值选 \(v+1\sim n\) 相当于 dp 的后缀和,那么总方案数就是 dp 的后缀和的积。

\(C\)\(deg>B\) 的子树数量,那么总复杂度就是 \(O(2^{B+C}\times\operatorname{poly}(n))\) 的。

看起来平衡一下指数是 \(2\sqrt n\),实际上我们可以分析到 \(\sqrt{2n}\)

不妨假设我们当前的 \(B+C\) 是最小的。那么有:

  • 对于任意 \(k\),子节点中 \(deg\in [B-k+1,B]\) 的不少于 \(k\) 个。

  • 对于任意 \(k\),子节点中 \(deg\in [B+1,B+k]\) 的不多于 \(k\) 个。

这里的子节点都不包含叶子。证明调整即可。

也就是说,子节点的 \(deg\) 之和不小于 \(1+2+\cdots+(B+C)\),但是 \(deg\) 的上界是 \(n-1\),因此 \(B+C\) 的最小值也不超过 \(\sqrt{2n}\)

于是总复杂度 \(O(2^{\sqrt{2n}}\times \operatorname{poly}(n))\),实际上完全卡不满,可以通过。

#include <bits/stdc++.h>
#define ALL(x) begin(x), end(x)
using namespace std;
void file() {
  freopen("1.in", "r", stdin);
  freopen("1.out", "w", stdout);
}
using ll = long long;

const int kMod = 998244353;
void Add(int& x, int y) { ((x += y) >= kMod) && (x -= kMod); }
void Sub(int& x, int y) { ((x -= y) < 0) && (x += kMod); }
int Sum(int x, int y) { return Add(x, y), x; }
int Dif(int x, int y) { return Sub(x, y), x; }
int Pow(int x, int y) {
  int b = x, r = 1;
  for(; y; b = (ll)b * b % kMod, y /= 2) {
    if(y & 1) r = (ll)r * b % kMod;
  }
  return r;
}
int Inv(int x) { return Pow(x, kMod - 2); }

const int kN = 205;
int n;
array<int, kN> deg;
array<array<int, kN>, kN> C, S, dp, suf;
array<vector<int>, kN> g;

void Init() {
  for(int i = 0; i <= n; i++) {
    C[i][0] = 1;
    for(int j = 1; j <= i; j++) {
      C[i][j] = Sum(C[i - 1][j], C[i - 1][j - 1]);
    }
  }
  S[0][0] = 1;
  for(int i = 1; i <= n; i++) {
    for(int j = 1; j <= i; j++) {
      S[i][j] = ((ll)S[i - 1][j] * j + S[i - 1][j - 1]) % kMod;
    }
  }
  for(int i = 0; i <= n; i++) {
    for(int j = 0; j <= n; j++) {
      if(!S[i][j]) continue;
      for(int k = 2; k <= j; k++) {
        S[i][j] = (ll)S[i][j] * k % kMod;
      }
    }
  }
}

int Eval(int x, int fa, int B) {
  int ans = B;
  for(int to : g[x]) {
    if(to != fa) ans += (deg[to] > B);
  }
  return ans;
}

void Dfs(int x, int fa) {
  deg[x] = g[x].size() - !!fa;
  if(!deg[x]) return ;
  for(int to : g[x]) {
    if(to != fa) Dfs(to, x);
  }

  int B = -1, cnt = n + 1, leaf = 0;
  for(int b = 0; b <= n; b++) {
    int val = Eval(x, fa, b);
    if(cnt > val) cnt = val, B = b;
  }
  vector<int> small, large;
  for(int to : g[x]) {
    if(to != fa) {
      if(!deg[to]) leaf++;
      else {
        if(deg[to] <= B) small.push_back(to);
        else large.push_back(to);
      }
    }
  }
  sort(ALL(small), [&](int x, int y) -> bool { return deg[x] < deg[y]; });
  sort(ALL(large), [&](int x, int y) -> bool { return deg[x] < deg[y]; });

  vector<int> dp1 {1};
  
  auto Dp1 = [&]() -> void {
    vector<int> old {1};
    for(int i : small) {
      dp1.assign(1 << deg[i] + 1, 0);
      for(int msk = 0; msk < old.size(); msk++) {
        for(int v = 0; v <= deg[i]; v++) {
          Add(dp1[msk | (1 << v)], (ll)old[msk] * dp[i][v] % kMod);
        }
      }
      old = dp1;
    }
  };

  auto Dp2 = [&](int msk, int coe) -> void {
    int siz = large.size();
    vector<vector<int>> old, dp2;
    old.resize(leaf + 1, vector<int> (1 << siz, 0)), dp2 = old;
    old[0][0] = dp2[0][0] = 1;

    vector<int> pw (leaf + 1, 0);
    vector<int> prod (1 << siz, 0);

    auto GetAns = [&](int v) -> void {
      prod[0] = 1;
      for(int i = 0; i < siz; i++) prod[1 << i] = suf[large[i]][v];
      for(int msk = 1; msk < (1 << siz); msk++) {
        int lb = msk & -msk;
        prod[msk] = (ll)prod[msk ^ lb] * prod[lb] % kMod;
      }
      pw[0] = 1;
      for(int i = 1; i <= leaf; i++) {
        pw[i] = (ll)pw[i - 1] * (n - v + 1) % kMod;
      }
      int all = (1 << siz) - 1;
      for(int c = 0; c <= min(v, leaf); c++) {
        int coef = 0;
        for(int i = c; i <= leaf; i++) {
          Add(coef, (ll)S[i][c] * C[leaf][i] % kMod * pw[leaf - i] % kMod);
        }
        coef = (ll)coef * coe % kMod;
        for(int msk = 0; msk < (1 << siz); msk++) {
          Add(suf[x][v], (ll)prod[all ^ msk] * coef % kMod * old[c][msk] % kMod);
        }
      }
    };

    for(int v = 0; v <= deg[x]; v++, old = dp2) {
      GetAns(v);
      for(int c = 0; c <= min(v, leaf); c++) {
        for(int i = 0; i < siz; i++) {
          for(int msk = 0; msk < (1 << siz); msk++) {
            if(msk & (1 << i)) {
              Add(dp2[c][msk], (ll)dp2[c][msk ^ (1 << i)] * dp[large[i]][v] % kMod);
            }
          }
        }
      }
      for(int c = min(v + 1, leaf); c; c--) {
        for(int msk = 0; msk < (1 << siz); msk++) {
          Add(dp2[c][msk], dp2[c - 1][msk]);
        }
      }
      if((v > __lg(msk + 1)) || !(msk & (1 << v))) {
        for(int c = 0; c <= min(v, leaf); c++) {
          for(int msk = 0; msk < (1 << siz); msk++) {
            Sub(dp2[c][msk], old[c][msk]);
          }
        }
      }
    }
  };

  Dp1();
  for(int msk = 0; msk < dp1.size(); msk++) Dp2(msk, dp1[msk]);
  for(int i = 0; i <= deg[x]; i++) {
    dp[x][i] = Dif(suf[x][i], suf[x][i + 1]);
  }
}

int main() {
  // file();
  ios::sync_with_stdio(0), cin.tie(0);
  cin >> n, Init();
  for(int i = 1, u, v; i < n; i++) {
    cin >> u >> v;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  Dfs(1, 0);
  for(int i = 0; i <= n; i++) cout << dp[1][i] << "\n";
  return 0;
}
posted @ 2024-12-31 09:08  CJzdc  阅读(60)  评论(0)    收藏  举报