P4426 毒瘤笔记
前置知识点:虚树,dp。
题意
给定一个 \(n\) 个点 \(m\) 条边的无向简单连通图,满足 \(n - 1 \le m \le n + 10\)。求图的独立集个数,对 \(998244353\) 取模。
题解
首先,注意到 \(m \le n + 10\),也就是说(取一棵生成树后)非树边最多只有 \(11\) 条。将这些非树边的端点(至多 \(s = 22\) 个,下面称为关键点)找出来,接着 \(\mathcal{O}(2^s)\) 枚举每个关键点选或不选,再对整棵树做一遍树形 dp,就可以在 \(\mathcal{O}(n 2^s)\) 的复杂度内解决这个问题,可以得到 70+pts 的好成绩。
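于是沿着这一条思路,可以想办法优化。我们先写出朴素的树形 dp:设 \(f_{x, 0 / 1}\) 表示以 \(x\) 为根的子树中,\(x\) 不选(\(0\))或选(\(1\))时的方案数。那么有:
\[
f_{x, 0} = \prod_{v \in \mathrm{son}(x)} \left(f_{v, 0} + f_{v, 1}\right), \qquad f_{x, 1} = \prod_{v \in \mathrm{son}(x)} f_{v, 0}
\]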
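接着我们建出关键点的虚树。对于虚树上的一条边 \((x, v)\)(\(x\) 是 \(v\) 在虚树上的父亲),我们发现 \(f_{v, 0 / 1}\) 对 \(x\) 的贡献竟然可以表示成 \(f_{x, i} \leftarrow f_{x, i} \cdot (k_{i, 0} \cdot f_{v, 0} + k_{i, 1} \cdot f_{v, 1})\),其中 \(i \in \{0, 1\}\) 是 \(x\) 的状态。并且由于在枚举关键点状态时,这条虚边对应的原树链以及链上挂着的子树(其中没有关键点)都不会改变,所以 \(k_{i, 0}, k_{i, 1}\) 是定值!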
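这样我们就可以在枚举之前预处理出所有系数(每条虚边只需沿原树链跳一遍,所有链互不相交,总共 \(\mathcal{O}(n)\) 步),之后每次枚举只需在虚树(\(\mathcal{O}(s)\) 个点)上 dp 就行了,总复杂度 \(\mathcal{O}(n + s 2^s)\)。实际上枚举量还可以更少:对于每条非树边 \((u, v)\),只需枚举"\(u\) 不选"和"\(u\) 选且 \(v\) 不选"两种情况(二者互斥,且恰好覆盖所有合法情况),所以枚举次数只有 \(2^{m - n + 1} \le 2^{11}\),下面的代码就是这样实现的。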
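下面详细讲一下系数是怎么推出来的。对于虚树上的一条边 \((x, v)\),在原树上从 \(v\) 一步一步往上跳到 \(x\)。不同虚边对应的原树链互不相交,所以所有虚边跳的总步数显然是 \(\mathcal{O}(n)\)。记 \(p_i\) 表示 \(v\) 的 \(i\) 级祖先(\(p_0 = v\))。记 \(k_{p_i, a, b}\) 表示:当 \(p_{i+1}\) 的状态为 \(a\)、\(v\) 的状态为 \(b\) 时,\(f_{v, b}\) 经过 \(p_0, \dots, p_i\) 这一段对 \(p_{i+1}\) 的贡献系数。由于 \(p_1\) 选时 \(p_0 = v\) 不能选,有:
\[
k_{p_0, 0, 0} = k_{p_0, 0, 1} = 1, \qquad k_{p_0, 1, 0} = 1, \qquad k_{p_0, 1, 1} = 0
\]
再记 \(f'_{p_{i+1}, 0/1}\) 表示 \(p_{i+1}\) 去掉链上儿子 \(p_i\) 的贡献后的 dp 值(即只考虑它的其他儿子,这些子树里没有关键点)。那么 \(p_{i+1}\) 在状态 \(a\) 下的实际 dp 值为 \(f'_{p_{i+1}, a} \sum_b k_{p_i, a, b} f_{v, b}\),所以 \(k_{p_i, 0/1, 0/1}\) 和 \(k_{p_{i + 1}, 0/1, 0/1}\) 的关系式只需暴力展开得到:
\[
\sum_b k_{p_{i+1}, 0, b} f_{v, b} = f'_{p_{i+1}, 0} \sum_b k_{p_i, 0, b} f_{v, b} + f'_{p_{i+1}, 1} \sum_b k_{p_i, 1, b} f_{v, b}, \qquad \sum_b k_{p_{i+1}, 1, b} f_{v, b} = f'_{p_{i+1}, 0} \sum_b k_{p_i, 0, b} f_{v, b}
\]
整理得(对 \(b \in \{0, 1\}\)):
\[
k_{p_{i+1}, 0, b} = f'_{p_{i+1}, 0} \cdot k_{p_i, 0, b} + f'_{p_{i+1}, 1} \cdot k_{p_i, 1, b}, \qquad k_{p_{i+1}, 1, b} = f'_{p_{i+1}, 0} \cdot k_{p_i, 0, b}
\]
一直跳到 \(p_{i+1} = x\) 为止,此时的 \(k_{p_i, 0/1, 0/1}\) 就是这条虚边的系数。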
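当然,直接以点 \(x\) 为下标记下 \(k_{x, 0/1, 0/1}\) 是不行的,因为虚树上的 \(x\) 可能有多个儿子(对应多条虚边)。所以给每条虚边编号,开一个数组按编号记录这一条虚边的系数就行了。另外,\(f_x\) 中本来已经乘上了这条链最顶端的点(\(x\) 在原树上的儿子)的贡献,需要先把它除掉,之后在虚树上 dp 时再通过系数乘回来。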
具体细节看代码实现:
代码
#include <bits/stdc++.h>
using namespace std;
// Binary exponentiation: returns a^b using O(log b) multiplications of T.
// T must be constructible from the integer 1 and support `*=`.
// b must be non-negative: a negative b would never reach 0 under `>>=`, so it is rejected
// up front instead of looping forever.
template <class T>
T power(T a, long long b) {
    assert(b >= 0);
    T res = 1;
    while (b > 0) {
        if (b & 1) res *= a;
        b >>= 1;
        // Square only if another bit remains: the final squaring is never used, and for
        // built-in integer T it could overflow (undefined behavior) even when a^b itself fits.
        if (b > 0) a *= a;
    }
    return res;
}
// Integer modulo `mod`, stored as a long long in [0, getmod()). When the template argument
// `mod` is 0, the modulus is read from the runtime static `Mod` instead.
// Addition, subtraction and multiplication never overflow for any positive modulus that
// fits in long long. Products go through mulmod: a plain 64-bit product overflows once the
// modulus exceeds roughly 3e9.
// inv() and operator/ use Fermat's little theorem, so they assume a prime modulus.
template <long long mod>
class ModLL {
public:
    long long n;
    static long long Mod;  // runtime modulus; only read when mod == 0

    constexpr ModLL() : n{} {}
    constexpr ModLL(long long x) : n(norm(x % getmod())) {}

    // Reduces x into [0, getmod()). Static so that const members (e.g. unary minus) can use it.
    static constexpr long long norm(long long x) {
        if (x >= getmod() || x <= -getmod()) x %= getmod();
        if (x < 0) x += getmod();
        return x;
    }

    // `if constexpr` keeps `Mod` out of the instantiation when mod > 0. `Mod` is only
    // defined for the mod == 0 specialization, so it must not be odr-used otherwise.
    static constexpr long long getmod() {
        if constexpr (mod > 0) {
            return mod;
        } else {
            return Mod;
        }
    }

    // (a * b) mod getmod() for a, b in [0, getmod()), without signed overflow.
    static constexpr long long mulmod(long long a, long long b) {
#ifdef __SIZEOF_INT128__
        return static_cast<long long>(static_cast<unsigned __int128>(a) *
                                      static_cast<unsigned long long>(b) %
                                      static_cast<unsigned long long>(getmod()));
#else
        // Double-and-add fallback. Each intermediate value is < 2 * modulus < 2^64.
        const unsigned long long md = static_cast<unsigned long long>(getmod());
        unsigned long long res = 0;
        unsigned long long x = static_cast<unsigned long long>(a);
        unsigned long long y = static_cast<unsigned long long>(b);
        for (; y; y >>= 1) {
            if (y & 1) res = (res + x) % md;
            x = (x + x) % md;
        }
        return static_cast<long long>(res);
#endif
    }

    explicit constexpr operator long long() const { return n; }
    constexpr ModLL operator-() const {
        ModLL a;
        a.n = norm(getmod() - n);
        return a;
    }
    // Multiplicative inverse; requires n != 0 and a prime modulus.
    constexpr ModLL inv() const {
        assert(n != 0);
        return power(*this, getmod() - 2);
    }
    constexpr ModLL &operator+=(ModLL w) & {
        n -= getmod() - w.n;  // n + w.n - mod, evaluated without overflow
        if (n < 0) n += getmod();
        return *this;
    }
    constexpr ModLL &operator-=(ModLL w) & {
        n -= w.n;
        if (n < 0) n += getmod();
        return *this;
    }
    constexpr ModLL &operator*=(ModLL w) & {
        n = mulmod(n, w.n);
        return *this;
    }
    constexpr ModLL &operator/=(ModLL w) & { return *this *= w.inv(); }
    friend constexpr ModLL operator+(ModLL a, ModLL b) { a += b; return a; }
    friend constexpr ModLL operator-(ModLL a, ModLL b) { a -= b; return a; }
    friend constexpr ModLL operator*(ModLL a, ModLL b) { a *= b; return a; }
    friend constexpr ModLL operator/(ModLL a, ModLL b) { a /= b; return a; }
    friend constexpr bool operator==(ModLL a, ModLL b) { return a.n == b.n; }
    friend constexpr bool operator!=(ModLL a, ModLL b) { return a.n != b.n; }
    friend std::istream &operator>>(std::istream &is, ModLL &a) {
        long long x = 0;
        is >> x;
        a = ModLL(x);
        return is;
    }
    friend std::ostream &operator<<(std::ostream &os, const ModLL &a) { return os << a.n; }
};
// Integer modulo `mod`, stored as an int in [0, getmod()). When the template argument `mod`
// is 0, the modulus is read from the runtime static `Mod` instead.
// inv() and operator/ use Fermat's little theorem, so they assume a prime modulus.
template <int mod>
class ModInt {
public:
    int n;
    static int Mod;  // runtime modulus; only read when mod == 0

    constexpr ModInt() : n{} {}
    // Takes long long so that wide values are reduced instead of being truncated to int.
    // int arguments still convert implicitly, so existing callers are unaffected.
    constexpr ModInt(long long x) : n(static_cast<int>(x % getmod())) {
        if (n < 0) n += getmod();
    }

    // Reduces x into [0, getmod()). Static so that const members (e.g. unary minus) can use it.
    static constexpr int norm(int x) {
        if (x >= getmod() || x <= -getmod()) x %= getmod();
        if (x < 0) x += getmod();
        return x;
    }

    // `if constexpr` keeps `Mod` out of the instantiation when mod > 0. `Mod` is only
    // defined for the mod == 0 specialization, so it must not be odr-used otherwise.
    static constexpr int getmod() {
        if constexpr (mod > 0) {
            return mod;
        } else {
            return Mod;
        }
    }

    explicit constexpr operator int() const { return n; }
    constexpr ModInt operator-() const {
        ModInt a;
        a.n = norm(getmod() - n);
        return a;
    }
    // Multiplicative inverse; requires n != 0 and a prime modulus.
    constexpr ModInt inv() const {
        assert(n != 0);
        return power(*this, getmod() - 2);
    }
    constexpr ModInt &operator+=(ModInt w) & {
        n -= getmod() - w.n;  // n + w.n - mod, evaluated without int overflow
        if (n < 0) n += getmod();
        return *this;
    }
    constexpr ModInt &operator-=(ModInt w) & {
        n -= w.n;
        if (n < 0) n += getmod();
        return *this;
    }
    constexpr ModInt &operator*=(ModInt w) & {
        n = static_cast<int>(1LL * n * w.n % getmod());
        return *this;
    }
    constexpr ModInt &operator/=(ModInt w) & { return *this *= w.inv(); }
    friend constexpr ModInt operator+(ModInt a, ModInt b) { a += b; return a; }
    friend constexpr ModInt operator-(ModInt a, ModInt b) { a -= b; return a; }
    friend constexpr ModInt operator*(ModInt a, ModInt b) { a *= b; return a; }
    friend constexpr ModInt operator/(ModInt a, ModInt b) { a /= b; return a; }
    friend constexpr bool operator==(ModInt a, ModInt b) { return a.n == b.n; }
    friend constexpr bool operator!=(ModInt a, ModInt b) { return a.n != b.n; }
    friend std::istream &operator>>(std::istream &is, ModInt &a) {
        long long x = 0;
        is >> x;
        a = ModInt(x);
        return is;
    }
    friend std::ostream &operator<<(std::ostream &os, const ModInt &a) { return os << a.n; }
};
// Runtime moduli consulted by the mod == 0 instantiations of the two modular-int templates.
template <>
long long ModLL<0>::Mod = 1000000000000000009LL;  // 10^18 + 9
template <>
int ModInt<0>::Mod = 998244353;

// All counting below is done modulo 998244353.
using P = ModInt<998244353>;
using i64 = long long;

// Capacity of the vertex-indexed arrays.
const int N = 200005;

int n, m;      // number of vertices / edges in the input graph
int bs;        // number of non-tree edges found by find()
int h[N];      // key vertices (1-based), sorted by dfn in build()
int len;       // number of entries in h
int ll;        // number of virtual-tree vertices in a[]
int dfn[N];    // DFS entry time of each vertex (0 = not visited yet)
int ccnt;      // DFS timestamp counter
vector<pair<int, int>> vs;  // non-tree edges as (ancestor endpoint, descendant endpoint)
int dep[N];    // depth in the DFS tree (the root, vertex 1, has depth 1)
int S = 19;    // highest binary-lifting level stored in yf
int yf[N][20]; // yf[v][i]: the 2^i-th ancestor of v (the root maps to itself)
int a[N << 1]; // virtual-tree vertices (1-based), sorted by dfn and deduplicated in build()
P f[N][2];     // tree dp over DFS-tree edges: [0] vertex unchosen, [1] vertex chosen
P fp[N][2];    // dp of a path vertex with its on-path child left out (filled by sx)
P g[N][2];     // f of virtual-tree vertices with the virtual-edge parts divided out
P re[N][2];    // dp values on the virtual tree for the current enumeration mask
bool key[N];   // true for endpoints of non-tree edges
vector<int> G[N];  // adjacency lists of the input graph
vector<int> E[N];  // adjacency lists of the virtual tree
vector<int> ks;    // virtual-tree vertices in the order xs() visits them
map<pair<int, int>, int> rid;  // id of each non-tree edge, stored for both orientations
map<pair<int, int>, int> dir;  // id of each virtual-tree edge, stored for both orientations
void find(int x, int fa) {
dfn[x] = ++ccnt;
f[x][0] = f[x][1] = 1;
for (auto v : G[x]) {
if (v == fa) continue;
if (dfn[v] && dfn[x] < dfn[v]) {
rid[make_pair(x, v)] = rid[make_pair(v, x)] = ++bs;
vs.emplace_back(x, v);
key[x] = key[v] = 1;
continue;
} else if (dfn[v]) continue;
dep[v] = dep[x] + 1; yf[v][0] = x;
for (int i = 1; i <= S; ++i)
yf[v][i] = yf[yf[v][i - 1]][i - 1];
find(v, x);
f[x][0] *= f[v][0] + f[v][1];
f[x][1] *= f[v][0];
}
}
// Lowest common ancestor of u and v, using the binary-lifting table yf and depths dep.
int glca(int u, int v) {
    // Make u the deeper endpoint.
    if (dep[u] < dep[v]) swap(u, v);
    // Lift u up to v's depth, trying the largest jumps first.
    for (int step = S; step >= 0; --step) {
        if (dep[u] - (1 << step) >= dep[v]) u = yf[u][step];
    }
    if (u == v) return u;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int step = S; step >= 0; --step) {
        if (yf[u][step] == yf[v][step]) continue;
        u = yf[u][step];
        v = yf[v][step];
    }
    return yf[u][0];
}
// Adds the undirected virtual-tree edge {x, y} to the adjacency lists E.
void conn(int x, int y) {
    E[x].push_back(y);
    E[y].push_back(x);
}
// Builds the virtual tree (edges in E) over all key vertices plus the root 1.
// Candidate vertices are:
//   - the key vertices;
//   - the LCAs of dfn-adjacent key vertices;
//   - the root.
// After sorting by dfn, each candidate's virtual parent is its LCA with its dfn-predecessor.
void build() {
    auto byDfn = [](int lhs, int rhs) { return dfn[lhs] < dfn[rhs]; };

    // Key vertices, 1-based in h, in DFS order.
    for (int v = 1; v <= n; ++v) {
        if (key[v]) h[++len] = v;
    }
    sort(h + 1, h + 1 + len, byDfn);

    // Candidate set in a[]: keys, LCAs of dfn-neighbours, and the root.
    for (int i = 1; i <= len; ++i) a[++ll] = h[i];
    for (int i = 1; i + 1 <= len; ++i) a[++ll] = glca(h[i], h[i + 1]);
    a[++ll] = 1;
    sort(a + 1, a + 1 + ll, byDfn);
    ll = static_cast<int>(unique(a + 1, a + 1 + ll) - (a + 1));

    for (int i = 2; i <= ll; ++i) conn(glca(a[i - 1], a[i]), a[i]);
}
P k[40][2][2];
void sx(int x, int v) {
int d = dir[make_pair(x, v)], lt = x;
x = yf[x][0];
k[d][0][0] = k[d][0][1] = 1;
k[d][1][0] = 1; k[d][1][1] = 0;
P pre[2][2];
while (x != v) {
for (int i : {0, 1}) for (int j : {0, 1})
pre[i][j] = k[d][i][j];
fp[x][0] = f[x][0] / (f[lt][0] + f[lt][1]);
fp[x][1] = f[x][1] / f[lt][0];
k[d][0][0] = fp[x][0] * pre[0][0] + fp[x][1] * pre[1][0];
k[d][0][1] = fp[x][0] * pre[0][1] + fp[x][1] * pre[1][1];
k[d][1][0] = fp[x][0] * pre[0][0];
k[d][1][1] = fp[x][0] * pre[0][1];
lt = x;
x = yf[x][0];
}
g[v][0] /= (f[lt][0] + f[lt][1]);
g[v][1] /= f[lt][0];
}
// Walks the virtual tree from x (fa = parent in that walk) and does three things:
//   - appends each vertex to ks;
//   - numbers each virtual edge 1, 2, ... in discovery order, stored in dir under both
//     orientations;
//   - precomputes that edge's coefficients with sx before descending.
void xs(int x, int fa) {
    ks.push_back(x);
    for (int child : E[x]) {
        if (child == fa) continue;
        // Every numbered edge occupies two map entries, hence the halving.
        const int id = static_cast<int>(dir.size() / 2) + 1;
        dir[make_pair(x, child)] = id;
        dir[make_pair(child, x)] = id;
        sx(child, x);
        xs(child, x);
    }
}
void DP(int x, int fa) {
for (auto v : E[x]) {
if (v == fa) continue;
int d = dir[make_pair(v, x)];
DP(v, x);
re[x][0] *= k[d][0][0] * re[v][0] + k[d][0][1] * re[v][1];
re[x][1] *= k[d][1][0] * re[v][0] + k[d][1][1] * re[v][1];
}
}
// Reads the graph and builds the virtual tree and its edge coefficients once. It then sums
// the virtual-tree dp over every assignment of states to the non-tree edges and prints the
// count modulo 998244353.
signed main(void) {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m;
    for (int e = 0; e < m; ++e) {
        int u, v;
        cin >> u >> v;
        G[u].emplace_back(v);
        G[v].emplace_back(u);
    }

    // Root the DFS tree at vertex 1, whose ancestors all point at itself.
    dep[1] = 1;
    for (int lvl = 0; lvl <= S; ++lvl) yf[1][lvl] = 1;
    find(1, 0);
    build();

    // g starts as the plain tree dp; xs/sx divide out the parts carried by virtual edges.
    for (int v = 1; v <= n; ++v) {
        g[v][0] = f[v][0];
        g[v][1] = f[v][1];
    }
    xs(1, 0);

    P ans = 0;
    for (int v = 1; v <= n; ++v) {
        re[v][0] = g[v][0];
        re[v][1] = g[v][1];
    }
    const int masks = 1 << bs;
    for (int mask = 0; mask < masks; ++mask) {
        for (int x : ks) {
            re[x][0] = g[x][0];
            re[x][1] = g[x][1];
        }
        // Non-tree edge j = (x, y):
        //   bit set   -> x unchosen;
        //   bit clear -> x chosen and y unchosen.
        for (int j = 0; j < bs; ++j) {
            const int x = vs[j].first;
            const int y = vs[j].second;
            if ((mask >> j) & 1) {
                re[x][1] = 0;
            } else {
                re[x][0] = 0;
                re[y][1] = 0;
            }
        }
        DP(1, 0);
        ans += re[1][0] + re[1][1];
    }
    cout << ans << '\n';
    return 0;
}

浙公网安备 33010602011771号