AtCoder Beginner Contest 269 Ex Antichain

洛谷传送门

AtCoder 传送门

CF1010F Tree 基本一致。

考虑经典树形背包,设 \(f_{u, i}\)\(u\) 子树内选了 \(i\) 个点的方案数。初始有 \(f_{u, 0} = 1\)。每次考虑合并儿子 \(v\),有转移:

\[f_{u, i + j} \gets f_{u, i} f_{v, j} \]

最后有 \(f_{u, 1} \gets f_{u, 1} + 1\) 表示只选 \(u\)

写成生成函数的形式,就是:

\[F_u(x) = x + \prod\limits_{v \in son_u} F_v(x) \]

你发现这玩意直接做优化不了,因为 \(F_u(x)\) 的次数是 \(sz_u\) 级别的。这启发我们想到重链剖分。

具体地,考虑在重链顶处计算重链顶的多项式 \(F_u(x)\)。设重链上的点从浅到深依次为 \(a_1, a_2, \ldots, a_n\)\(a_i\) 的所有轻儿子的 \(F_u(x)\) 的积为 \(b_i\)(为了方便若没有轻儿子则 \(b_i = 1\)),那么 \(b_i\) 可以分治 NTT 计算。然后有:

\[F_{a_n}(x) = b_n + x \]

\[F_{a_{n - 1}}(x) = b_{n - 1} F_{a_n}(x) + x = b_{n - 1} (b_n + x) + x \]

以此类推,可以得到 \(F_u(x) = b_1 (b_2(\ldots (b_n + x) \ldots) + x) + x = (\sum\limits_{i = 1}^{n - 1} x \prod\limits_{j = 1}^i b_j) + x\)

这个东西可以分治 NTT 计算。具体就是每次递归 \([l, r]\) 返回一个二元组 \((\sum\limits_{i = l}^r \prod\limits_{j = l}^i b_j, \prod\limits_{i = l}^r b_i)\),那么 \([l, mid]\)\([mid + 1, r]\) 的信息就可以合并了。

考虑每次计算的 \(b_i\) 次数之和为一棵树所有轻儿子的子树大小 \(= O(n \log n)\),分治 NTT 再带两个 \(\log\),总时间复杂度就是 \(O(n \log^3 n)\)。可过。

code
// Problem: Ex - Antichain
// Contest: AtCoder - UNICORN Programming Contest 2022(AtCoder Beginner Contest 269)
// URL: https://atcoder.jp/contests/abc269/tasks/abc269_h
// Memory Limit: 1024 MB
// Time Limit: 8000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 200100;
const ll mod = 998244353, gg = 3;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

int n, r[maxn * 5];
vector<int> G[maxn];

typedef vector<ll> poly;

inline poly NTT(poly a, int op) {
	int n = (int)a.size();
	for (int i = 0; i < n; ++i) {
		if (i < r[i]) {
			swap(a[i], a[r[i]]);
		}
	}
	for (int k = 1; k < n; k <<= 1) {
		ll wn = qpow(op == 1 ? gg : qpow(gg, mod - 2), (mod - 1) / (k << 1));
		for (int i = 0; i < n; i += (k << 1)) {
			ll w = 1;
			for (int j = 0; j < k; ++j, w = w * wn % mod) {
				ll x = a[i + j], y = w * a[i + j + k] % mod;
				a[i + j] = (x + y) % mod;
				a[i + j + k] = (x - y + mod) % mod;
			}
		}
	}
	if (op == -1) {
		ll inv = qpow(n, mod - 2);
		for (int i = 0; i < n; ++i) {
			a[i] = a[i] * inv % mod;
		}
	}
	return a;
}

inline poly operator * (poly a, poly b) {
	a = NTT(a, 1);
	b = NTT(b, 1);
	int n = (int)a.size();
	for (int i = 0; i < n; ++i) {
		a[i] = a[i] * b[i] % mod;
	}
	a = NTT(a, -1);
	return a;
}

inline poly operator + (poly a, poly b) {
	int n = (int)a.size() - 1, m = (int)b.size() - 1;
	poly res(max(n, m) + 1);
	for (int i = 0; i <= n; ++i) {
		res[i] = a[i];
	}
	for (int i = 0; i <= m; ++i) {
		res[i] = (res[i] + b[i]) % mod;
	}
	return res;
}

inline poly mul(poly a, poly b) {
	int n = (int)a.size() - 1, m = (int)b.size() - 1, k = 0;
	while ((1 << k) < n + m + 1) {
		++k;
	}
	for (int i = 1; i < (1 << k); ++i) {
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
	}
	poly A(1 << k), B(1 << k);
	for (int i = 0; i <= n; ++i) {
		A[i] = a[i];
	}
	for (int i = 0; i <= m; ++i) {
		B[i] = b[i];
	}
	poly res = A * B;
	res.resize(n + m + 1);
	return res;
}

int sz[maxn], son[maxn], top[maxn];

void dfs(int u) {
	sz[u] = 1;
	int mx = -1;
	for (int v : G[u]) {
		dfs(v);
		sz[u] += sz[v];
		if (sz[v] > mx) {
			son[u] = v;
			mx = sz[v];
		}
	}
}

void dfs2(int u, int tp) {
	top[u] = tp;
	if (!son[u]) {
		return;
	}
	dfs2(son[u], tp);
	for (int v : G[u]) {
		if (!top[v]) {
			dfs2(v, v);
		}
	}
}

poly F[maxn], a[maxn], b[maxn];

pair<poly, poly> calc(int l, int r) {
	if (l == r) {
		return mkp(a[l], a[l]);
	}
	int mid = (l + r) >> 1;
	auto L = calc(l, mid), R = calc(mid + 1, r);
	return mkp(L.fst + mul(L.scd, R.fst), mul(L.scd, R.scd));
}

poly calc2(int l, int r) {
	if (l == r) {
		return b[l];
	}
	int mid = (l + r) >> 1;
	return mul(calc2(l, mid), calc2(mid + 1, r));
}

void dfs3(int u) {
	for (int v : G[u]) {
		dfs3(v);
	}
	if (u == top[u]) {
		int K = 0;
		for (int v = u; v; v = son[v]) {
			++K;
			if ((int)G[v].size() <= 1) {
				a[K] = poly(1, 1);
				continue;
			}
			int tot = 0;
			for (int w : G[v]) {
				if (w != son[v]) {
					b[++tot] = F[w];
				}
			}
			a[K] = calc2(1, tot);
		}
		auto res = calc(1, K);
		F[u].pb(0);
		for (ll x : res.fst) {
			F[u].pb(x);
		}
		for (int i = 0; i < (int)res.scd.size(); ++i) {
			F[u][i + 1] = (F[u][i + 1] - res.scd[i] + mod) % mod;
			F[u][i] = (F[u][i] + res.scd[i]) % mod;
		}
		F[u][1] = (F[u][1] + 1) % mod;
	}
}

void solve() {
	scanf("%d", &n);
	for (int i = 2, p; i <= n; ++i) {
		scanf("%d", &p);
		G[p].pb(i);
	}
	dfs(1);
	dfs2(1, 1);
	F[0].pb(1);
	dfs3(1);
	for (int i = 1; i <= n; ++i) {
		printf("%lld\n", i < (int)F[1].size() ? F[1][i] : 0LL);
	}
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}

posted @ 2024-01-25 08:01  zltzlt  阅读(38)  评论(0)    收藏  举报