Educational Codeforces Round 118 (Rated for Div. 2) - F. Tree Coloring 题解
题意
给定一棵 \(n\) 个节点、以节点 \(1\) 为根的树,用 \(1\sim n\) 的一个排列给节点染色(每种颜色恰好使用一次),要求每个非根节点 \(k\) 都满足 \(c_k \neq c_{p_k} - 1\),其中 \(p_k\) 为 \(k\) 的父节点。统计合法方案数,对 \(998\,244\,353\) 取模。
思路
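容斥:枚举钦定 \(i\) 个条件被破坏(即选出 \(i\) 对父子满足 \(c_k = c_{p_k} - 1\))的方案数。由于颜色互不相同,每个节点至多有一个儿子的颜色恰好比它小 \(1\)。因此节点 \(u\) 有 \(d_u\)(儿子个数)种方式贡献 \(1\) 个被破坏的条件,其生成函数为
\[g_u(x) = 1 + d_u \cdot x
\]
把所有 \(g_u\) 乘起来得到 \(G(x)=\prod_u g_u(x)\),\([x^i]G(x)\) 即为钦定 \(i\) 对父子的选法数。被钦定的父子对把节点连成 \(n-i\) 条颜色连续的链,剩余只需把这些链排列,共 \((n-i)!\) 种。故答案为 \(\sum_{i=0}^{n}(-1)^i\,[x^i]G(x)\,(n-i)!\)。
计算 \(G(x)\) 用启发式合并或分治 NTT 即可,复杂度 \(O(n\log^2 n)\)。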
代码
#include <bits/stdc++.h>
// Integer and polynomial type aliases used throughout the file. These are
// `using` aliases rather than macros so they respect scoping and cannot
// silently rewrite unrelated tokens.
using ll = long long;
using ull = unsigned long long;
using i64 = long long;
// Dense polynomial: coefficient of x^i is stored at index i (reduced mod MOD).
using poly = std::vector<int>;
// Usage notes carried over from the larger template this code was cut from
// (fastpow is NOT defined in this file):
//   Don't read a[m] when a.size() <= m -- results are trimmed, so resize first:
//   (a = fastpow(c,n-m+1,m+1)).resize(m+1);
//   i64 res = a[m] - b[m];
//   (b = fastpow(d,n-m+1,m+1)).resize(m+1);
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
constexpr int MOD = 998244353;
// NTT (number-theoretic transform) polynomial helpers over Z/MOD with
// primitive root g. Transform lengths are limited to N = 2^21.
// Results are trimmed of trailing zeros -- remember to resize before indexing.
namespace Poly { // remember to resize
const int N = (1 << 21), g = 3;

// Computes x^p mod MOD by binary exponentiation (assumes 0 <= x < MOD, p >= 0).
inline int power(int x, int p) {
    int res = 1;
    for (; p; p >>= 1, x = (ll)x * x % MOD)
        if (p & 1)
            res = (ll)res * x % MOD;
    return res;
}

// One conditional subtraction: maps [0, 2*MOD) onto [0, MOD).
inline int fix(const int x) { return x >= MOD ? x - MOD : x; }

// In-place forward NTT. A is resized to length 2^n (n <= 21); coefficients are
// expected in [0, MOD) and the outputs are reduced into [0, MOD).
void dft(poly& A, int n) {
    // Twiddle table, built lazily and shared by all calls: H[k] points at the
    // first 2^k powers of a primitive 2^(k+1)-th root of unity, which is
    // exactly what the butterflies of level k read. Total size is 2^n - 1.
    static ull W[N << 1];
    static ull* H[30];
    static ull* las = W;
    static int mx = 0;
    for (; mx < n; mx++) {
        H[mx] = las;
        ull w = 1, wn = power(g, (MOD - 1) >> (mx + 1));
        // Level mx only ever reads 2^mx entries, so write exactly that many.
        for (int i = 0; i < (1 << mx); ++i) *las++ = w, w = w * wn % MOD;
    }
    const int len = 1 << n;
    if (A.size() != static_cast<size_t>(len))
        A.resize(len);
    static ull a[N];
    // Bit-reversal permutation into the 64-bit working buffer.
    for (int i = 0, j = 0; i < len; ++i) {
        a[i] = A[j];
        for (int k = len >> 1; (j ^= k) < k; k >>= 1);
    }
    // Butterflies with lazy reduction. Entering level k (counted since the last
    // renormalisation) every a[i] is < (k + 1) * MOD, so (*r) * (*w) is
    // < (k + 1) * MOD^2, which only fits in 64 bits while k + 1 <= 18.
    // Renormalising every 16 levels keeps that bound for every n <= 21.
    for (int k = 0, d = 1; k < n; k++, d <<= 1) {
        if (k != 0 && k % 16 == 0)
            for (int i = 0; i < len; ++i) a[i] %= MOD;
        for (int i = 0; i < len; i += (d << 1)) {
            ull *l = a + i, *r = a + i + d, *w = H[k], t;
            for (int j = 0; j < d; j++, l++, r++) {
                t = (*r) * (*w++) % MOD;
                *r = *l + MOD - t, *l += t;
            }
        }
    }
    for (int i = 0; i < len; ++i) A[i] = a[i] % MOD;
}

// Inverse NTT: resizes a to 2^n, runs the forward transform on the
// index-reversed sequence (a[0] fixed, a[1..] reversed), then scales by 1/2^n.
void idft(poly& a, int n) {
    a.resize(1 << n);
    std::reverse(a.begin() + 1, a.end());
    dft(a, n);
    const int inv = power(1 << n, MOD - 2);
    for (int i = 0; i < (1 << n); ++i) a[i] = (ll)a[i] * inv % MOD;
}

// Strips trailing zero coefficients (the zero polynomial becomes empty).
poly FIX(poly a) {
    while (!a.empty() && !a.back()) a.pop_back();
    return a;
}

// Returns a*b when t == 1, otherwise a*a*b, with trailing zeros stripped.
// Small t == 1 products use the schoolbook method; everything else uses NTT.
// remember to resize
poly mul(poly a, poly b, int t = 1) {
    if (t == 1 && a.size() + b.size() <= 24) {
        poly c(a.size() + b.size(), 0);
        for (size_t i = 0; i < a.size(); ++i)
            for (size_t j = 0; j < b.size(); ++j)
                c[i + j] = (c[i + j] + (ll)a[i] * b[j]) % MOD;
        return FIX(c);
    }
    // Choose 2^n > aim so the cyclic convolution cannot wrap around.
    const int aim = static_cast<int>(a.size() * t + b.size());
    int n = 1;
    while ((1 << n) <= aim) n++;
    dft(a, n);
    dft(b, n);
    if (t == 1)
        for (int i = 0; i < (1 << n); ++i) a[i] = (ll)a[i] * b[i] % MOD;
    else
        for (int i = 0; i < (1 << n); ++i) a[i] = (ll)a[i] * a[i] % MOD * b[i] % MOD;
    idft(a, n);
    a.resize(aim);
    return FIX(a);
}
}  // namespace Poly
using namespace Poly; // remember to resize
// Pulls x back into [0, MOD) with a single correction. This assumes x is off
// by at most one multiple of MOD, i.e. x lies in [-MOD, 2*MOD).
void norm(int& x) {
    if (x < 0)
        x += MOD;
    else if (x >= MOD)
        x -= MOD;
}
int main(int argc, char const *argv[])
{
std::ios_base::sync_with_stdio(false);
std::cin.tie(nullptr); std::cout.tie(nullptr);
int n;
std::cin >> n;
std::vector<std::vector<int> > g(n, std::vector<int>());
for(int i=0;i<n-1;++i) {
int u,v;
std::cin >> u >> v;
--u; --v;
g[u].push_back(v);
g[v].push_back(u);
}
auto dnc = [&](auto dnc,int l,int r) {
if(r - l == 1) {
return (poly) {1, (int) g[l].size() - (l != 0)};
}
int mid = l + r >> 1;
return mul(dnc(dnc,l,mid), dnc(dnc,mid,r));
};
int res = 0;
poly ans = dnc(dnc,0,n);
ans.resize(n+1);
std::vector<int> fac(n+1);
fac[0] = fac[1] = 1;
for(int i=2;i<=n;++i) {
fac[i] = 1ll * fac[i-1] * i % MOD;
}
for(int i=0;i<=n;++i) {
int thiz = 1ll * fac[n - i] * ans[i] % MOD;
norm(
res += (i&1 ? MOD - thiz : thiz)
);
}
std::cout << res;
return 0;
}
Living with bustle, hearing of isolation.

浙公网安备 33010602011771号