# 「新年的追逐战」题解

设$G_1, G_2$是两个至少含一条边的连通图，考虑它们的乘积（张量积）$G_1 \times G_2$：$G_1, G_2$都不是二分图的时候，形成一个非二分图的连通块。
$G_1, G_2$有恰好一个二分图的时候，形成一个二分图的连通块。
$G_1, G_2$全是二分图的时候，形成两个二分图的连通块。

令$M = \max_{i = 1}^{n} m_i$，$F(x) = \sum_{i = 0}^{M} \frac{2^{\frac{i(i - 1)}{2}}}{i!} x^i$（所有带标号图个数的EGF）。则$\ln F(x)$就是连通图个数的EGF，即其$x^i$项系数乘以$i!$等于$i$阶连通图的个数。$F(x) \ln F(x)$就是所有图的连通块总数的EGF（相当于选出一个连通块，其余的点任意连边），于是$A_i$（所有$i$阶图的连通块数之和）就求出来了。

设$G(x) = \sum_{i = 0}^{M} \sum_{j = 0}^{i} \frac{2^{j(i - j)}}{j!(i - j)!} x^i$，它的$x^i$项系数乘以$i!$表示把$i$阶图的每个点二染色、再在异色点之间任意连边的方案数。由于每个连通二分图恰好有两种合法的二染色，$\frac{\ln G(x)}{2}$就是连通二分图个数的EGF。那么$\frac{F(x) \ln G(x)}{2}$就是所有图中二分图连通块总数的EGF。

#include <bits/stdc++.h>
#define debug(x) cerr << #x << " " << (x) << endl
using namespace std;

// Capacity of the input array a[].
const int N = 100005;
// Base capacity for per-size tables; polynomial buffers are allocated with M << 2 entries.
const int M = 100005;
// Number of rows in the root-of-unity table w (row j serves NTT lengths up to 2^j).
const int K = 21;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1 ...
const long long mod = 998244353ll;
// ... and its primitive root, used to generate the roots of unity.
const long long R = 3ll;

template <class T>
int sgn = 1;
char ch;
x = 0;
for (ch = getchar(); (ch < '0' || ch > '9') && ch != '-'; ch = getchar()) ;
if (ch == '-') ch = getchar(), sgn = -1;
for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
x *= sgn;
}
// Writes x to stdout in decimal via putchar, prefixed with '-' when negative.
// No separator is written. (-x must be representable for negative x.)
template <class T>
void write (T x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    // Collect digits least-significant first, then emit them in reverse.
    char digits[40];
    int count = 0;
    do {
        digits[count++] = (char)(x % 10 + '0');
        x /= 10;
    } while (x > 0);
    while (count > 0) putchar(digits[--count]);
}

// n: number of input graphs; a[i]: vertex count of the i-th graph; m: the largest a[i].
int n;
int a[N];
int m = 0;
// Tables filled by init(): w[j][k] holds the k-th power of a primitive 2^j-th root of unity;
// inv / fac / inv_fac hold modular inverses, factorials and inverse factorials.
long long w[K][M << 2];
long long inv[M << 2];
long long fac[M << 2];
long long inv_fac[M << 2];
// Returns x^y modulo `mod` by square-and-multiply.
// y must be non-negative (a negative y would never terminate), and x is assumed to be
// already reduced modulo `mod` so the products cannot overflow.
long long qpow (long long x, long long y) {
    long long result = 1ll;
    while (y != 0) {
        if (y & 1) result = result * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return result;
}
// Precomputes every table for indices up to lim = (m + 2) * 4:
//   inv[i]     - modular inverse of i, via inv[i] = -(mod / i) * inv[mod % i];
//   fac[i]     - i! mod `mod`;
//   inv_fac[i] - (i!)^{-1} mod `mod`;
//   w[j][k]    - (R^((mod - 1) / 2^j))^k for k in [0, 2^j], for every 2^j <= lim.
// Reads the global m, so it must run after the input has been read.
void init () {
    const int lim = (m + 2) << 2;

    inv[1] = 1ll;
    fac[0] = 1ll;
    inv_fac[0] = 1ll;
    for (int i = 1; i <= lim; ++i) {
        // mod % i < i, so inv[mod % i] is already known here.
        if (i >= 2) inv[i] = (mod / i) * (mod - inv[mod % i]) % mod;
        fac[i] = fac[i - 1] * i % mod;
        inv_fac[i] = inv_fac[i - 1] * inv[i] % mod;
    }

    for (int j = 0; (1 << j) <= lim; ++j) {
        const int span = 1 << j;
        const long long step = qpow(R, (mod - 1) >> j);
        w[j][0] = 1ll;
        w[j][1] = step;
        for (int k = 2; k <= span; ++k) w[j][k] = w[j][k - 1] * step % mod;
    }
}

// Scratch buffer for the bit-reversal permutation.
int rev[M << 2];
// In-place iterative radix-2 number-theoretic transform of f[0..len); len must be a power
// of two whose roots are present in w (see init). ty > 0 uses the forward roots w[j][k];
// otherwise the inverse roots w[j][2^j - k] are used, and for ty < 0 the result is also
// scaled by 1 / len. All outputs are reduced into [0, mod).
void ntt (long long *f, int len, int ty) {
    // rev[i] is i with its log2(len) low bits reversed.
    const int half = len >> 1;
    rev[0] = 0;
    for (int i = 1; i < len; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? half : 0);
    for (int i = 0; i < len; ++i)
        if (i < rev[i]) std::swap(f[i], f[rev[i]]);

    // Level `lvl` merges pairs of blocks of size h into blocks of size 2h using w[lvl].
    for (int h = 1, lvl = 1; h < len; h <<= 1, ++lvl) {
        const long long *root = w[lvl];
        for (int base = 0; base < len; base += h << 1) {
            for (int off = 0; off < h; ++off) {
                const long long omega = (ty > 0) ? root[off] : root[(h << 1) - off];
                const long long u = f[base + off];
                const long long t = omega * f[base + off + h] % mod;
                f[base + off] = (u + t) % mod;
                f[base + off + h] = (u + mod - t) % mod;
            }
        }
    }

    if (ty < 0) {
        const long long scale = inv[len];
        for (int i = 0; i < len; ++i) f[i] = f[i] * scale % mod;
    }
}

// Working polynomial buffers: poly is the input series, poly_inv the output of get_inv,
// tmp1 / tmp2 are NTT scratch space shared with get_ln and main.
long long poly[M << 2], poly_inv[M << 2], tmp1[M << 2], tmp2[M << 2];
// Computes poly_inv = 1 / poly modulo x^len by Newton iteration g <- 2g - f g^2,
// doubling the precision each round. Assumes poly[0] == 1 (the start value is g = 1).
// Coefficients are valid up to the first power of two >= len.
void get_inv (int len) {
    std::fill(poly_inv, poly_inv + (len << 2), 0ll);
    poly_inv[0] = 1ll;
    for (int cur = 1; cur < len; cur <<= 1) {
        const int prec = cur << 1;  // precision after this round
        const int sz = cur << 2;    // f * g * g has degree < 4 * cur, so no wrap-around
        std::fill(tmp1, tmp1 + sz, 0ll);
        std::fill(tmp2, tmp2 + sz, 0ll);
        std::copy(poly, poly + prec, tmp1);
        std::copy(poly_inv, poly_inv + prec, tmp2);
        ntt(tmp1, sz, 1);
        ntt(tmp2, sz, 1);
        for (int j = 0; j < sz; ++j) {
            const long long g = tmp2[j];
            const long long fgg = tmp1[j] * g % mod * g % mod;
            tmp1[j] = (2ll * g + mod - fgg) % mod;
        }
        ntt(tmp1, sz, -1);
        std::copy(tmp1, tmp1 + prec, poly_inv);
    }
}

// Output buffer of get_ln.
long long poly_ln[M << 2];
// Computes poly_ln = ln(poly) modulo x^len as the integral of poly' / poly; poly_ln[0] = 0.
// Assumes poly[0] == 1. Overwrites poly_inv, tmp1 and tmp2.
void get_ln (int len) {
    get_inv(len + 1);
    std::fill(tmp1, tmp1 + (len << 2), 0ll);
    std::fill(tmp2, tmp2 + (len << 2), 0ll);
    std::fill(poly_ln, poly_ln + (len << 2), 0ll);

    // tmp1 = derivative of poly, tmp2 = first len coefficients of 1 / poly.
    for (int i = 1; i < len; ++i) tmp1[i - 1] = poly[i] * i % mod;
    std::copy(poly_inv, poly_inv + len, tmp2);

    int sz = 1;
    while (sz < (len << 1)) sz <<= 1;
    ntt(tmp1, sz, 1);
    ntt(tmp2, sz, 1);
    for (int i = 0; i < sz; ++i) tmp1[i] = tmp1[i] * tmp2[i] % mod;
    ntt(tmp1, sz, -1);

    // Integrate: coefficient i comes from coefficient i - 1 of the product divided by i.
    for (int i = 1; i <= len; ++i) poly_ln[i] = tmp1[i - 1] * inv[i] % mod;
}

// Per-size tables (index = number of vertices i), filled in main:
//   total_g[i] - 2^{i(i-1)/2}, the number of labeled graphs on i vertices;
//   ans1[i]    - connected components summed over all graphs on i vertices, minus isolated vertices;
//   ans2[i]    - bipartite components summed the same way, minus isolated vertices;
//   ans3[i]    - (graph, vertex) pairs on i vertices whose vertex is not isolated.
long long total_g[M], ans1[M], ans2[M], ans3[M], mul = 1ll, ans = 0ll;
int main () {
    // Input: n, then the vertex counts a[0..n-1]. n must be read first; without it n stays 0,
    // nothing is read and every product below is empty.
    read(n);
    for (int i = 0; i < n; i++) read(a[i]), m = max(m, a[i]);
    init();

    // poly = F(x) = sum_i total_g[i] x^i / i!, the EGF of all labeled graphs.
    for (int i = 0; i <= m; i++) total_g[i] = qpow(2ll, 1ll * i * (i - 1) / 2);
    for (int i = 0; i <= m; i++) poly[i] = total_g[i] * inv_fac[i] % mod;
    get_ln(m + 1);

    // ans1[i] = i! [x^i] F(x) ln F(x): components summed over all graphs on i vertices.
    int L = 1;
    while (L <= ((m + 1) << 1)) L <<= 1;
    for (int i = 0; i < L; i++) tmp1[i] = tmp2[i] = 0ll;
    for (int i = 0; i <= m; i++) tmp1[i] = poly[i], tmp2[i] = poly_ln[i];
    ntt(tmp1, L, 1), ntt(tmp2, L, 1);
    for (int i = 0; i < L; i++) tmp1[i] = tmp1[i] * tmp2[i] % mod;
    ntt(tmp1, L, -1);
    for (int i = 1; i <= m; i++) ans1[i] = tmp1[i] * fac[i] % mod;

    // poly = G(x) = sum_i x^i sum_j 2^{j(i-j)} / (j! (i-j)!), computed as
    // 2^{C(i,2)} * [x^i] H(x)^2 with H(x) = sum_i 2^{-C(i,2)} x^i / i!,
    // since j(i-j) = C(i,2) - C(j,2) - C(i-j,2).
    for (int i = 0; i <= m; i++) poly[i] = qpow(inv[2], 1ll * i * (i - 1) / 2) * inv_fac[i] % mod;
    for (int i = 0; i < L; i++) tmp1[i] = 0ll;
    for (int i = 0; i <= m; i++) tmp1[i] = poly[i];
    ntt(tmp1, L, 1);
    for (int i = 0; i < L; i++) tmp1[i] = tmp1[i] * tmp1[i] % mod;
    ntt(tmp1, L, -1);
    for (int i = 0; i <= m; i++) poly[i] = tmp1[i] * total_g[i] % mod;
    get_ln(m + 1);

    // ans2[i] = i! [x^i] F(x) ln G(x) / 2: bipartite components summed over all graphs.
    for (int i = 0; i < L; i++) tmp1[i] = tmp2[i] = 0ll;
    for (int i = 0; i <= m; i++) {
        tmp1[i] = poly_ln[i] * inv[2] % mod;
        tmp2[i] = total_g[i] * inv_fac[i] % mod;
    }
    ntt(tmp1, L, 1), ntt(tmp2, L, 1);
    for (int i = 0; i < L; i++) tmp1[i] = tmp1[i] * tmp2[i] % mod;
    ntt(tmp1, L, -1);
    for (int i = 0; i <= m; i++) {
        ans2[i] = tmp1[i] * fac[i] % mod;
        poly[i] = total_g[i];
    }

    // There are i * g_{i-1} (graph, isolated vertex) pairs on i vertices. Isolated vertices
    // are removed from both component counts here and handled separately below.
    for (int i = 1; i <= m; i++) {
        ans1[i] = (ans1[i] + mod - poly[i - 1] * i % mod) % mod;
        ans2[i] = (ans2[i] + mod - poly[i - 1] * i % mod) % mod;
        ans3[i] = (poly[i] + mod - poly[i - 1]) * i % mod;
    }

    // Vertex tuples with at least one isolated coordinate: all tuples minus the tuples
    // with no isolated coordinate. Each such tuple is a singleton component.
    for (int i = 0; i < n; i++) mul = mul * poly[a[i]] % mod * a[i] % mod;
    ans = mul, mul = 1ll;
    for (int i = 0; i < n; i++) mul = mul * ans3[a[i]] % mod;
    ans = (ans + mod - mul) % mod, mul = 1ll;

    // Pick one non-trivial component per graph. If t of them are bipartite, their product has
    // 1 component when t == 0 and 2^{t-1} components otherwise. Summed over all choices this
    // equals (prod(ans1 + ans2) + prod(ans1 - ans2)) / 2.
    for (int i = 0; i < n; i++) mul = mul * (ans1[a[i]] + ans2[a[i]]) % mod;
    ans = (ans + mul * inv[2]) % mod, mul = 1ll;
    for (int i = 0; i < n; i++) mul = mul * (ans1[a[i]] + mod - ans2[a[i]]) % mod;
    ans = (ans + mul * inv[2]) % mod;
    write(ans), putchar('\n');
    return 0;
}

posted @ 2020-06-19 11:00  unzcjouhi  阅读(278)  评论(1)  编辑  收藏