AtCoder Beginner Contest 269 Ex Antichain
和 CF1010F Tree 基本一致。
考虑经典树形背包,设 \(f_{u, i}\) 为 \(u\) 子树内选了 \(i\) 个点的方案数。初始有 \(f_{u, 0} = 1\)。每次考虑合并儿子 \(v\),有转移:
最后有 \(f_{u, 1} \gets f_{u, 1} + 1\) 表示只选 \(u\)。
写成生成函数的形式,就是:
你发现这玩意直接做优化不了,因为 \(F_u(x)\) 的次数是 \(sz_u\) 级别的。这启发我们想到重链剖分。
具体地,考虑在重链顶处计算重链顶的多项式 \(F_u(x)\)。设重链上的点从浅到深依次为 \(a_1, a_2, \ldots, a_n\),\(a_i\) 的所有轻儿子的 \(F_u(x)\) 的积为 \(b_i\)(为了方便若没有轻儿子则 \(b_i = 1\)),那么 \(b_i\) 可以分治 NTT 计算。然后有:
以此类推,可以得到 \(F_u(x) = b_1 (b_2(\ldots (b_n + x) \ldots) + x) + x = (\sum\limits_{i = 1}^{n - 1} x \prod\limits_{j = 1}^i b_j) + x\)。
这个东西可以分治 NTT 计算。具体就是每次递归 \([l, r]\) 返回一个二元组 \((\sum\limits_{i = l}^r \prod\limits_{j = l}^i b_j, \prod\limits_{i = l}^r b_i)\),那么 \([l, mid]\) 和 \([mid + 1, r]\) 的信息就可以合并了。
考虑每次计算的 \(b_i\) 次数之和为一棵树所有轻儿子的子树大小 \(= O(n \log n)\),分治 NTT 再带两个 \(\log\),总时间复杂度就是 \(O(n \log^3 n)\)。可过。
code
// Problem: Ex - Antichain
// Contest: AtCoder - UNICORN Programming Contest 2022(AtCoder Beginner Contest 269)
// URL: https://atcoder.jp/contests/abc269/tasks/abc269_h
// Memory Limit: 1024 MB
// Time Limit: 8000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 200100;
const ll mod = 998244353, gg = 3;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
int n, r[maxn * 5];
vector<int> G[maxn];
typedef vector<ll> poly;
inline poly NTT(poly a, int op) {
int n = (int)a.size();
for (int i = 0; i < n; ++i) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < n; k <<= 1) {
ll wn = qpow(op == 1 ? gg : qpow(gg, mod - 2), (mod - 1) / (k << 1));
for (int i = 0; i < n; i += (k << 1)) {
ll w = 1;
for (int j = 0; j < k; ++j, w = w * wn % mod) {
ll x = a[i + j], y = w * a[i + j + k] % mod;
a[i + j] = (x + y) % mod;
a[i + j + k] = (x - y + mod) % mod;
}
}
}
if (op == -1) {
ll inv = qpow(n, mod - 2);
for (int i = 0; i < n; ++i) {
a[i] = a[i] * inv % mod;
}
}
return a;
}
inline poly operator * (poly a, poly b) {
a = NTT(a, 1);
b = NTT(b, 1);
int n = (int)a.size();
for (int i = 0; i < n; ++i) {
a[i] = a[i] * b[i] % mod;
}
a = NTT(a, -1);
return a;
}
inline poly operator + (poly a, poly b) {
int n = (int)a.size() - 1, m = (int)b.size() - 1;
poly res(max(n, m) + 1);
for (int i = 0; i <= n; ++i) {
res[i] = a[i];
}
for (int i = 0; i <= m; ++i) {
res[i] = (res[i] + b[i]) % mod;
}
return res;
}
inline poly mul(poly a, poly b) {
int n = (int)a.size() - 1, m = (int)b.size() - 1, k = 0;
while ((1 << k) < n + m + 1) {
++k;
}
for (int i = 1; i < (1 << k); ++i) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
}
poly A(1 << k), B(1 << k);
for (int i = 0; i <= n; ++i) {
A[i] = a[i];
}
for (int i = 0; i <= m; ++i) {
B[i] = b[i];
}
poly res = A * B;
res.resize(n + m + 1);
return res;
}
int sz[maxn], son[maxn], top[maxn];
void dfs(int u) {
sz[u] = 1;
int mx = -1;
for (int v : G[u]) {
dfs(v);
sz[u] += sz[v];
if (sz[v] > mx) {
son[u] = v;
mx = sz[v];
}
}
}
void dfs2(int u, int tp) {
top[u] = tp;
if (!son[u]) {
return;
}
dfs2(son[u], tp);
for (int v : G[u]) {
if (!top[v]) {
dfs2(v, v);
}
}
}
poly F[maxn], a[maxn], b[maxn];
pair<poly, poly> calc(int l, int r) {
if (l == r) {
return mkp(a[l], a[l]);
}
int mid = (l + r) >> 1;
auto L = calc(l, mid), R = calc(mid + 1, r);
return mkp(L.fst + mul(L.scd, R.fst), mul(L.scd, R.scd));
}
poly calc2(int l, int r) {
if (l == r) {
return b[l];
}
int mid = (l + r) >> 1;
return mul(calc2(l, mid), calc2(mid + 1, r));
}
void dfs3(int u) {
for (int v : G[u]) {
dfs3(v);
}
if (u == top[u]) {
int K = 0;
for (int v = u; v; v = son[v]) {
++K;
if ((int)G[v].size() <= 1) {
a[K] = poly(1, 1);
continue;
}
int tot = 0;
for (int w : G[v]) {
if (w != son[v]) {
b[++tot] = F[w];
}
}
a[K] = calc2(1, tot);
}
auto res = calc(1, K);
F[u].pb(0);
for (ll x : res.fst) {
F[u].pb(x);
}
for (int i = 0; i < (int)res.scd.size(); ++i) {
F[u][i + 1] = (F[u][i + 1] - res.scd[i] + mod) % mod;
F[u][i] = (F[u][i] + res.scd[i]) % mod;
}
F[u][1] = (F[u][1] + 1) % mod;
}
}
void solve() {
scanf("%d", &n);
for (int i = 2, p; i <= n; ++i) {
scanf("%d", &p);
G[p].pb(i);
}
dfs(1);
dfs2(1, 1);
F[0].pb(1);
dfs3(1);
for (int i = 1; i <= n; ++i) {
printf("%lld\n", i < (int)F[1].size() ? F[1][i] : 0LL);
}
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}

浙公网安备 33010602011771号