P4426 毒瘤笔记

前置知识点:虚树,dp。

题意

给定一个 \(n\) 个点 \(m\) 条边的无向简单连通图,满足 \(n - 1 \le m \le n + 10\)。求图的独立集个数,对 \(998244353\) 取模。

题解

首先,注意到 \(m \le n + 10\),也就是说取一棵生成树后,非树边最多只有 \(11\) 条。将这些非树边的端点(最多 \(s = 22\) 个,下面称为关键点)找出来,接着 \(2^s\) 枚举每个关键点的状态,最后对整棵树做一遍树形 dp,就可以在 \(\mathcal{O}(n \cdot 2^s)\) 的复杂度下解决这个问题,可以得到 70+pts 的好成绩。

于是沿着这一条思路,可以想办法优化。我们先推出朴素的树形 dp:设 \(f_{x, 0 / 1}\) 表示以 \(x\) 为根的子树中,\(x\) 不选(\(0\))/ 选(\(1\))时的独立集方案数。那么有:

\[\begin{cases} f_{x, 0} = \prod\limits_{v \in x.\text{son}} (f_{v, 0} + f_{v, 1})\\ f_{x, 1} = \prod\limits_{v \in x.\text{son}} f_{v, 0} \end{cases} \]

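接着我们建出关键点的虚树。对于虚树上的一条边 \((x, v)\)(\(x\) 是 \(v\) 在虚树上的父亲),我们发现 \(f_{v, 0 / 1}\) 对 \(x\) 的贡献竟然可以这么表示:\(f_{x, j} \mathrel{*}= (k_{j, 0} \cdot f_{v, 0} + k_{j, 1} \cdot f_{v, 1})\),其中 \(j \in \{0, 1\}\)。并且由于这条虚边对应的原树路径上的点(以及挂在这些点下面的其他子树)中都没有关键点,在枚举关键点状态时它们的 dp 值不会改变,所以系数 \(k_{j, 0}, k_{j, 1}\) 是定值!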

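这样我们就可以在 \(2^s\) 的枚举开始之前,\(\mathcal{O}(n)\) 预处理出所有系数,之后每种状态只需在虚树上做一遍 \(\mathcal{O}(s)\) 的 dp 就行了。总复杂度 \(\mathcal{O}(n + s \cdot 2^s)\)。(下面的代码实际上是按非树边枚举的:对每条非树边 \((u, w)\)(\(u\) 是 \(w\) 的祖先),只需枚举「\(u\) 不选」和「\(u\) 选且 \(w\) 不选」两种情况,因此只有 \(2^{11}\) 种状态。)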

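下面详细讲一下系数是怎么推出来的。对于虚树上的一条边 \((x, v)\)(\(x\) 是 \(v\) 在虚树上的父亲),在原树上从 \(v\) 一步一步跳到 \(x\)。由于不同虚边对应的原树路径除端点外互不相交,所有虚边加起来的复杂度显然是 \(\mathcal{O}(n)\)。记 \(p_i\) 表示 \(v\) 的 \(i\) 级祖先,\(f_{p_i, 0/1}'\) 表示 \(p_i\) 除路径上那个儿子 \(p_{i - 1}\) 以外其余所有儿子的贡献之积,\(k\) 表示系数,有: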

\[\begin{cases} f_{p_i, 0} = f_{p_i, 0}' \times (k_{p_i, 0, 0} \times f_{v, 0} + k_{p_i, 0, 1} \times f_{v, 1}) \\ f_{p_i, 1} = f_{p_i, 1}' \times (k_{p_i, 1, 0} \times f_{v, 0} + k_{p_i, 1, 1} \times f_{v, 1}) \end{cases} \]

所以 \(k_{p_i, 0/1, 0/1}\) 与 \(k_{p_{i + 1}, 0/1, 0/1}\) 之间的关系式只需暴力展开即可得到:

\[\begin{cases} \begin{aligned} f_{p_{i + 1}, 0} &= f_{p_{i + 1}, 0}' \times (f_{p_i, 0} + f_{p_i, 1})\\ &=f_{p_{i + 1}, 0}' \times \left[f_{p_i, 0}' \times (k_{p_i, 0, 0} \times f_{v, 0} + k_{p_i, 0, 1} \times f_{v, 1}) + f_{p_i, 1}' \times (k_{p_i, 1, 0} \times f_{v, 0} + k_{p_i, 1, 1} \times f_{v, 1})\right]\\ &=f_{p_{i + 1}, 0}'\times\left[(f_{p_i, 0}' \times k_{p_i, 0, 0} + f_{p_i, 1}' \times k_{p_i, 1, 0})\times f_{v, 0} + (f_{p_i, 0}' \times k_{p_i, 0, 1} + f_{p_i, 1}' \times k_{p_i, 1, 1}) \times f_{v, 1}\right] \\ &\Rightarrow \begin{cases} k_{p_{i + 1}, 0, 0} = f_{p_i, 0}' \times k_{p_i, 0, 0} + f_{p_i, 1}' \times k_{p_i, 1, 0} \\ k_{p_{i + 1}, 0, 1} = f_{p_i, 0}' \times k_{p_i, 0, 1} + f_{p_i, 1}' \times k_{p_i, 1, 1} \end{cases} \end{aligned} \\ \begin{aligned} f_{p_{i + 1}, 1} &= f_{p_{i + 1}, 1}' \times f_{p_i, 0} \\ &= f_{p_{i + 1}, 1}' \times \left[ f_{p_i, 0}' \times (k_{p_i, 0, 0} \times f_{v, 0} + k_{p_i, 0, 1} \times f_{v, 1}) \right] \\ &= f_{p_{i + 1}, 1}' \times \left[(f_{p_i, 0}' \times k_{p_i, 0, 0}) \times f_{v, 0} + (f_{p_i, 0}' \times k_{p_i, 0, 1}) \times f_{v, 1}\right]\\ &\Rightarrow \begin{cases} k_{p_{i + 1}, 1, 0} = f_{p_i, 0}' \times k_{p_i, 0, 0} \\ k_{p_{i + 1}, 1, 1} = f_{p_i, 0}' \times k_{p_i, 0, 1} \end{cases} \end{aligned} \end{cases} \]

整理得:

\[\begin{cases} k_{p_{i + 1}, 0, 0} = f_{p_i, 0}' \times k_{p_i, 0, 0} + f_{p_i, 1}' \times k_{p_i, 1, 0} \\ k_{p_{i + 1}, 0, 1} = f_{p_i, 0}' \times k_{p_i, 0, 1} + f_{p_i, 1}' \times k_{p_i, 1, 1} \\ k_{p_{i + 1}, 1, 0} = f_{p_i, 0}' \times k_{p_i, 0, 0} \\ k_{p_{i + 1}, 1, 1} = f_{p_i, 0}' \times k_{p_i, 0, 1} \end{cases} \]

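当然,不能直接以点为下标记录 \(k_{x, 0/1, 0/1}\),因为虚树上的 \(x\) 可能有多个儿子,对应多条虚边。所以直接按虚边开一个数组,记录每一条虚边的系数就行了。递推的初值可以取 \(p_0 = v\)、\(f_{p_0, 0/1}' = 1\)、\(k_{p_0} = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}\),代入一次即得代码中的 \(k_{p_1, 0, 0} = k_{p_1, 0, 1} = k_{p_1, 1, 0} = 1,\ k_{p_1, 1, 1} = 0\);一直递推到 \(p_i = x\) 为止。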

具体细节看代码实现:

代码
#include <bits/stdc++.h>
using namespace std;
// Binary exponentiation: returns a^b using T's `*=` and construction from 1.
// b is expected to be non-negative; for b <= 0 the result is T(1).
// The base is squared only while higher bits of b remain, so no extra (unused)
// squaring happens after the last step -- for plain integer T that extra squaring
// could overflow even when the result itself fits.
template <class T>
T power(T a, long long b) {
  T res = 1;
  while (b > 0) {
    if (b & 1)
      res *= a;
    b >>= 1;
    if (b)
      a *= a;
  }
  return res;
}
// Modular integer backed by long long.
// A positive template argument `mod` fixes the modulus at compile time; mod == 0 reads
// the modulus at run time from the static member Mod (which must then be defined, as is
// done for ModLL<0>). Sums of two residues must fit in long long, so the modulus may be
// up to about 4.6e18; products go through a 128-bit intermediate (GCC/Clang __int128).
template <long long mod>
class ModLL {
  public:
    long long n;           // canonical residue in [0, getmod())
    static long long Mod;  // runtime modulus, only consulted when mod == 0
    constexpr ModLL() : n{} {}
    constexpr ModLL(long long x) : n(norm(x % getmod())) {}
    // Reduce x into [0, getmod()). Static so const members (e.g. unary minus) can call it.
    static constexpr long long norm(long long x) {
      if (x >= getmod() || x <= -getmod()) x %= getmod();
      if (x < 0) x += getmod();
      return x;
    }
    // `if constexpr` keeps Mod from being odr-used when the modulus is a template argument.
    static constexpr long long getmod() {
      if constexpr (mod > 0) return mod;
      else return Mod;
    }
    explicit constexpr operator long long() const {return n;}
    constexpr ModLL operator -() const {ModLL a; a.n = norm(getmod() - n); return a;}
    // Fermat inverse: only meaningful for a prime modulus and n != 0.
    constexpr ModLL inv() const {assert(n != 0); return power((*this), getmod() - 2);}
    constexpr ModLL &operator += (ModLL w) & {n = norm(n + w.n); return (*this);}
    constexpr ModLL &operator -= (ModLL w) & {n = norm(n - w.n); return (*this);}
    // n * w.n can reach ~getmod()^2, which overflows long long for moduli above ~3e9
    // (e.g. the 1e18 + 9 used for ModLL<0>), so multiply in 128 bits before reducing.
    constexpr ModLL &operator *= (ModLL w) & {
      n = static_cast<long long>(static_cast<__int128>(n) * w.n % getmod());
      return (*this);
    }
    constexpr ModLL &operator /= (ModLL w) & {return (*this) *= w.inv();}
    friend constexpr ModLL operator + (ModLL a, ModLL b) {ModLL res = a; res += b; return res;}
    friend constexpr ModLL operator - (ModLL a, ModLL b) {ModLL res = a; res -= b; return res;}
    friend constexpr ModLL operator * (ModLL a, ModLL b) {ModLL res = a; res *= b; return res;}
    friend constexpr ModLL operator / (ModLL a, ModLL b) {ModLL res = a; res /= b; return res;}
    friend constexpr bool operator == (ModLL a, ModLL b) {return (a.n == b.n);}
    friend constexpr bool operator != (ModLL a, ModLL b) {return (a.n != b.n);}
    // Stream I/O cannot be constant-evaluated, so these are not constexpr.
    friend std::istream &operator >> (std::istream &is, ModLL &a) {
      long long x = 0; is >> x;
      a = ModLL(x); return is;
    }
    friend std::ostream &operator << (std::ostream &os, const ModLL &a) {return (os << (a.n));}
};
// Modular integer backed by int.
// A positive template argument `mod` fixes the modulus at compile time; mod == 0 reads
// the modulus at run time from the static member Mod (which must then be defined, as is
// done for ModInt<0>). Intermediate sums and products are formed in long long, so any
// modulus that fits in int is supported.
template <int mod>
class ModInt {
  public:
    int n;           // canonical residue in [0, getmod())
    static int Mod;  // runtime modulus, only consulted when mod == 0
    constexpr ModInt() : n{} {}
    constexpr ModInt(int x) : n(norm(x % getmod())) {}
    // Reduce x into [0, getmod()). Static so const members (e.g. unary minus) can call it;
    // takes long long so that n + w.n cannot overflow for moduli above INT_MAX / 2.
    static constexpr int norm(long long x) {
      if (x >= getmod() || x <= -getmod()) x %= getmod();
      if (x < 0) x += getmod();
      return static_cast<int>(x);
    }
    // `if constexpr` keeps Mod from being odr-used when the modulus is a template argument.
    static constexpr int getmod() {
      if constexpr (mod > 0) return mod;
      else return Mod;
    }
    explicit constexpr operator int() const {return n;}
    constexpr ModInt operator -() const {ModInt a; a.n = norm(getmod() - n); return a;}
    // Fermat inverse: only meaningful for a prime modulus and n != 0.
    constexpr ModInt inv() const {assert(n != 0); return power((*this), getmod() - 2);}
    constexpr ModInt &operator += (ModInt w) & {n = norm(1LL * n + w.n); return (*this);}
    constexpr ModInt &operator -= (ModInt w) & {n = norm(1LL * n - w.n); return (*this);}
    constexpr ModInt &operator *= (ModInt w) & {n = norm(1LL * n * w.n % getmod()); return (*this);}
    constexpr ModInt &operator /= (ModInt w) & {return (*this) *= w.inv();}
    friend constexpr ModInt operator + (ModInt a, ModInt b) {ModInt res = a; res += b; return res;}
    friend constexpr ModInt operator - (ModInt a, ModInt b) {ModInt res = a; res -= b; return res;}
    friend constexpr ModInt operator * (ModInt a, ModInt b) {ModInt res = a; res *= b; return res;}
    friend constexpr ModInt operator / (ModInt a, ModInt b) {ModInt res = a; res /= b; return res;}
    friend constexpr bool operator == (ModInt a, ModInt b) {return (a.n == b.n);}
    friend constexpr bool operator != (ModInt a, ModInt b) {return (a.n != b.n);}
    // Stream I/O cannot be constant-evaluated, so these are not constexpr.
    friend std::istream &operator >> (std::istream &is, ModInt &a) {
      int x = 0; is >> x;
      a = ModInt(x); return is;
    }
    friend std::ostream &operator << (std::ostream &os, const ModInt &a) {return (os << (a.n));}
};
// Runtime moduli for the mod == 0 ("dynamic modulus") instantiations.
template <>
long long ModLL<0>::Mod = static_cast<long long>(1E18) + 9;
template <>
int ModInt<0>::Mod = 998244353;

using P = ModInt<998244353>;  // DP values are kept modulo 998244353
using i64 = long long;
const int N = 2E5 + 5;  // capacity of vertex-indexed arrays

int n, m;
int bs;           // number of non-tree edges recorded by find()
int h[N], len;    // key vertices (filled and sorted by dfn in build())
int ll;           // number of entries in a[] (virtual-tree vertices after unique)
int dfn[N], ccnt; // DFS order and its counter
vector<pair<int, int>> vs;  // non-tree edges, recorded from the smaller-dfn endpoint
int dep[N];
int S = 19;       // highest binary-lifting level used by yf
int yf[N][20];    // yf[v][i]: 2^i-th ancestor of v (the root maps to itself)
int a[N << 1];    // scratch list of virtual-tree vertices
P f[N][2], fp[N][2], g[N][2], re[N][2];
bool key[N];      // endpoint of some non-tree edge
vector<int> G[N], E[N], ks;  // input graph, virtual tree, virtual-tree visit order
map<pair<int, int>, int> rid, dir;  // non-tree edge ids, virtual edge ids
void find(int x, int fa) {
  dfn[x] = ++ccnt;
  f[x][0] = f[x][1] = 1;
  for (auto v : G[x]) {
    if (v == fa) continue;
    if (dfn[v] && dfn[x] < dfn[v]) {
      rid[make_pair(x, v)] = rid[make_pair(v, x)] = ++bs;
      vs.emplace_back(x, v);
      key[x] = key[v] = 1;
      continue;
    } else if (dfn[v]) continue;
    dep[v] = dep[x] + 1; yf[v][0] = x;
    for (int i = 1; i <= S; ++i) 
      yf[v][i] = yf[yf[v][i - 1]][i - 1];
    find(v, x);
    f[x][0] *= f[v][0] + f[v][1];
    f[x][1] *= f[v][0];
  }
}
// Lowest common ancestor of u and v using the binary-lifting table yf.
int glca(int u, int v) {
  if (dep[u] < dep[v]) swap(u, v);
  // Lift the deeper vertex u up to v's depth.
  for (int lvl = S; lvl >= 0; --lvl)
    if (dep[u] - (1 << lvl) >= dep[v])
      u = yf[u][lvl];
  if (u == v) return u;
  // Lift both while their ancestors differ; they end as children of the LCA.
  for (int lvl = S; lvl >= 0; --lvl) {
    if (yf[u][lvl] == yf[v][lvl]) continue;
    u = yf[u][lvl];
    v = yf[v][lvl];
  }
  return yf[u][0];
}
// Adds the undirected edge x -- y to the virtual tree E.
void conn(int x, int y) {
  E[x].push_back(y);
  E[y].push_back(x);
}
void build() {
  for (int i = 1; i <= n; ++i) if (key[i])
    h[++len] = i;
  sort(h + 1, h + len + 1, [&](int x, int y) {return dfn[x] < dfn[y];});
  for (int i = 1; i <= len; ++i) a[++ll] = h[i];
  for (int i = 1; i < len; ++i) a[++ll] = glca(h[i], h[i + 1]);
  a[++ll] = 1;
  sort(a + 1, a + 1 + ll, [&](int x, int y) {return dfn[x] < dfn[y];});
  ll = unique(a + 1, a + 1 + ll) - a - 1;
  for (int i = 1; i < ll; ++i) {
    int lc = glca(a[i], a[i + 1]);
    conn(lc, a[i + 1]);
  }
}
P k[40][2][2];
void sx(int x, int v) {
  int d = dir[make_pair(x, v)], lt = x;
  x = yf[x][0];
  k[d][0][0] = k[d][0][1] = 1;
  k[d][1][0] = 1; k[d][1][1] = 0;
  P pre[2][2];  
  while (x != v) {
    for (int i : {0, 1}) for (int j : {0, 1})
      pre[i][j] = k[d][i][j];
    fp[x][0] = f[x][0] / (f[lt][0] + f[lt][1]);
    fp[x][1] = f[x][1] / f[lt][0];
    k[d][0][0] = fp[x][0] * pre[0][0] + fp[x][1] * pre[1][0];
    k[d][0][1] = fp[x][0] * pre[0][1] + fp[x][1] * pre[1][1];
    k[d][1][0] = fp[x][0] * pre[0][0];
    k[d][1][1] = fp[x][0] * pre[0][1];
    lt = x;
    x = yf[x][0];
  }
  g[v][0] /= (f[lt][0] + f[lt][1]);
  g[v][1] /= f[lt][0];
}
// DFS over the virtual tree from x (virtual parent fa). Appends each visited vertex to
// ks, gives every virtual edge an id (stored under both orientations in dir, counting
// from 1), and precomputes that edge's coefficients with sx before descending into it.
void xs(int x, int fa) {
  ks.push_back(x);
  for (int child : E[x]) {
    if (child == fa) continue;
    const int id = static_cast<int>(dir.size()) / 2 + 1;
    dir[make_pair(x, child)] = id;
    dir[make_pair(child, x)] = id;
    sx(child, x);
    xs(child, x);
  }
}
void DP(int x, int fa) {
  for (auto v : E[x]) {
    if (v == fa) continue;
    int d = dir[make_pair(v, x)];
    DP(v, x);
    re[x][0] *= k[d][0][0] * re[v][0] + k[d][0][1] * re[v][1];
    re[x][1] *= k[d][1][0] * re[v][0] + k[d][1][1] * re[v][1];
  }
}
signed main(void) {
  ios :: sync_with_stdio(false);
  cin.tie(0); cout.tie(0);
  cin >> n >> m;
  for (int i = 1; i <= m; ++i) {
    int u, v; cin >> u >> v;
    G[u].emplace_back(v);
    G[v].emplace_back(u);
  }
  dep[1] = 1; for (int i = 0; i <= S; ++i) yf[1][i] = 1;
  find(1, 0); build();
  for (int i = 1; i <= n; ++i) g[i][0] = f[i][0], g[i][1] = f[i][1];
  xs(1, 0);
  P ans = 0; 
  for (int i = 1; i <= n; ++i) re[i][0] = g[i][0], re[i][1] = g[i][1];
  for (int i = 0; i < (1 << bs); ++i) {
    for (auto x : ks) re[x][0] = g[x][0], re[x][1] = g[x][1];
    for (int j = 0; j < bs; ++j) {
      auto [x, y] = vs[j];
      if (i >> j & 1) re[x][1] = 0;
      else re[x][0] = 0, re[y][1] = 0;
    }
    DP(1, 0);
    ans += re[1][0] + re[1][1];
  } cout << ans << '\n';
  return 0;
}
posted @ 2024-01-31 20:49  CTHOOH  阅读(27)  评论(0)    收藏  举报