CodeForces 1010F Tree
educational 的。另一道类似的题是 [ABC269Ex] Antichain。
考虑令 \(b_u = a_u - \sum\limits_{v \in son_u} a_v\)。那么 \(\sum\limits_{i = 1}^n b_i = a_1 = x\),且 \(\forall i \in [1, n], b_i \ge 0\)。所以最后连通块内有 \(y\) 个点,那么贡献系数为 \(\binom{x + y - 1}{y - 1}\)。所以转为计算包含 \(1\) 的连通块有 \(i\) 个点的方案数。
考虑经典树形背包,设 \(f_{u, i}\) 为 \(u\) 子树内包含 \(u\) 的连通块点数为 \(i\) 的方案数。特别地 \(f_{u, 0} = 1\) 表示转移上去断掉这条边。记 \(L, R\) 分别为 \(u\) 的左右儿子,有:
写成生成函数的形式,就是:
你发现这玩意直接做优化不了,因为 \(F_u(x)\) 的次数是 \(sz_u\) 级别的。这启发我们想到重链剖分。
具体地,考虑在重链顶处计算重链顶的多项式 \(F_u(x)\)。设重链上的点从浅到深依次为 \(a_1, a_2, \ldots, a_n\),\(a_i\) 的轻儿子为 \(b_i\)(为了方便若没有轻儿子则 \(b_i = 0\),\(F_0(x) = 1\)),我们有:
以此类推,设 \(G_i = x F_{b_i}(x)\),那么 \(F_u(x) = G_1 (G_2(\ldots (G_n + 1)) \ldots + 1) + 1 = (\sum\limits_{i = 1}^n \prod\limits_{j = 1}^i G_j) + 1\)。
这个东西可以分治 NTT 计算。具体就是每次递归 \([l, r]\) 返回一个二元组 \((\sum\limits_{i = l}^r \prod\limits_{j = l}^i G_j, \prod\limits_{i = l}^r G_i)\),那么 \([l, mid]\) 和 \([mid + 1, r]\) 的信息就可以合并了。
考虑每次计算的 \(G_i\) 次数之和为一棵树所有轻儿子的子树大小 \(= O(n \log n)\),分治 NTT 再带两个 \(\log\),总时间复杂度就是 \(O(n \log^3 n)\)。实际运行效率还可以。
code
// Problem: F. Tree
// Contest: Codeforces - Codeforces Round 499 (Div. 1)
// URL: https://codeforces.com/contest/1010/problem/F
// Memory Limit: 256 MB
// Time Limit: 7000 ms
// 
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 500100;
const ll mod = 998244353, gg = 3;
inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}
typedef vector<ll> poly;
int r[maxn];
inline poly NTT(poly a, int op) {
	int n = (int)a.size();
	for (int i = 0; i < n; ++i) {
		if (i < r[i]) {
			swap(a[i], a[r[i]]);
		}
	}
	for (int k = 1; k < n; k <<= 1) {
		ll wn = qpow(op == 1 ? gg : qpow(gg, mod - 2), (mod - 1) / (k << 1));
		for (int i = 0; i < n; i += (k << 1)) {
			ll w = 1;
			for (int j = 0; j < k; ++j, w = w * wn % mod) {
				ll x = a[i + j], y = w * a[i + j + k] % mod;
				a[i + j] = (x + y) % mod;
				a[i + j + k] = (x - y + mod) % mod;
			}
		}
	}
	if (op == -1) {
		ll inv = qpow(n, mod - 2);
		for (int i = 0; i < n; ++i) {
			a[i] = a[i] * inv % mod;
		}
	}
	return a;
}
inline poly operator * (poly a, poly b) {
	a = NTT(a, 1);
	b = NTT(b, 1);
	int n = (int)a.size();
	for (int i = 0; i < n; ++i) {
		a[i] = a[i] * b[i] % mod;
	}
	a = NTT(a, -1);
	return a;
}
inline poly operator + (poly a, poly b) {
	int n = (int)a.size() - 1, m = (int)b.size() - 1;
	poly res(max(n, m) + 1);
	for (int i = 0; i <= n; ++i) {
		res[i] = (res[i] + a[i]) % mod;
	}
	for (int i = 0; i <= m; ++i) {
		res[i] = (res[i] + b[i]) % mod;
	}
	return res;
}
inline poly mul(poly a, poly b) {
	int n = (int)a.size() - 1, m = (int)b.size() - 1, k = 0;
	while ((1 << k) < n + m + 1) {
		++k;
	}
	for (int i = 1; i < (1 << k); ++i) {
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
	}
	poly A(1 << k), B(1 << k);
	for (int i = 0; i <= n; ++i) {
		A[i] = a[i];
	}
	for (int i = 0; i <= m; ++i) {
		B[i] = b[i];
	}
	poly res = A * B;
	res.resize(n + m + 1);
	return res;
}
ll n, m;
vector<int> G[maxn];
int sz[maxn], son[maxn], b[maxn], top[maxn];
poly F[maxn], a[maxn];
void dfs(int u, int fa) {
	sz[u] = 1;
	int mx = -1;
	vector<int> S;
	for (int v : G[u]) {
		if (v == fa) {
			continue;
		}
		S.pb(v);
		dfs(v, u);
		sz[u] += sz[v];
		if (sz[v] > mx) {
			son[u] = v;
			mx = sz[v];
		}
	}
	for (int v : S) {
		if (son[u] != v) {
			b[u] = v;
		}
	}
}
void dfs2(int u, int tp) {
	top[u] = tp;
	if (!son[u]) {
		return;
	}
	dfs2(son[u], tp);
	for (int v : G[u]) {
		if (!top[v]) {
			dfs2(v, v);
		}
	}
}
pair<poly, poly> calc(int l, int r) {
	if (l == r) {
		return mkp(a[l], a[l]);
	}
	int mid = (l + r) >> 1;
	auto L = calc(l, mid), R = calc(mid + 1, r);
	return mkp(L.fst + mul(L.scd, R.fst), mul(L.scd, R.scd));
}
void dfs3(int u, int fa) {
	for (int v : G[u]) {
		if (v == fa) {
			continue;
		}
		dfs3(v, u);
	}
	if (u == top[u]) {
		int K = 0;
		for (int v = u; v; v = son[v]) {
			a[++K] = poly(1, 0);
			for (ll x : F[b[v]]) {
				a[K].pb(x);
			}
		}
		auto res = calc(1, K);
		F[u] = res.fst;
		F[u][0] = 1;
	}
}
void solve() {
	scanf("%lld%lld", &n, &m);
	for (int i = 1, u, v; i < n; ++i) {
		scanf("%d%d", &u, &v);
		G[u].pb(v);
		G[v].pb(u);
	}
	dfs(1, -1);
	dfs2(1, 1);
	F[0].pb(1);
	dfs3(1, -1);
	ll ans = 0, fac = 1, mul = 1;
	for (int i = 1; i <= n; ++i) {
		ans = (ans + mul * qpow(fac, mod - 2) % mod * F[1][i]) % mod;
		fac = fac * i % mod;
		mul = mul * ((m + i) % mod) % mod;
	}
	printf("%lld\n", ans);
}
int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}
                    
                
                
            
        
浙公网安备 33010602011771号