CF1528F AmShZ Farm

太神秘了。

考虑这个题目的限制有一个经典的转化:第 \(i\) 个人要坐到位置 \(a_i\),如果 \(a_i\) 有人就继续查看 \(a_i + 1\)\(n\) 个人之后前 \(n\) 个位置坐满就合法。

先考虑如何对合法的序列计数。这个模型的套路是在最后加一个位置,把序列首位相接形成一个长度为 \(n+1\) 的环,每次如果 \(a_i\) 有人就往环上下一个走。这样最后空出来的位置是 \(n+1\) 才合法,而显然每个位置被空出来的概率是相等的,所以合法的序列数为 \((n+1)^{n}\over n+1\)

再考虑这题,我们发现和上面的计数一样:可以统计出所有序列的答案再除以 \(n+1\)。我们发现,共 \(n+1\) 种数的贡献是相同的,我们只算一种数的贡献就是答案。一种数的贡献容易列出式子:

\[\sum_{i=1}^n \binom{n}{i}i^kn^{n-i} \]

然后这个 \(i^k\) 你一看就会想到第二类斯特林数。后面的过程不难,转下降幂推一推就好了。

#include <bits/stdc++.h>

using namespace std;

const int N = 4e5 + 5, mod = 998244353, G = 3;

int rev[N];

inline int power(int a, int b) {
	int k = b, y = a, t = 1;
	while (k) {
		if (k & 1) t = (1ll * y * t) % mod;
		y = (1ll * y * y) % mod; k >>= 1;
	} return t;
}

const int Gi = power(G, mod - 2);

struct poly {
	int len;
	vector<int> x;
	
	inline void NTT(int flag) {
		for (int i = 0; i <= len; ++i)
			if (i < rev[i]) swap(x[i], x[rev[i]]);
		for (int mid = 1; mid < len; mid <<= 1) {
			const int Wn = power(flag == 1 ? G : Gi, (mod - 1) / (mid << 1));
			for (int l = 0; l < len; l += mid << 1) {
				int w = 1;
				for (int t = 0; t < mid; ++t, w = (1ll * w * Wn) % mod) {
					int a = x[l + t], b = (1ll * w * x[l + mid + t]) % mod;
					x[l + t] = (a + b) % mod;
					x[l + mid + t] = ((a - b) % mod + mod) % mod;
				}
			}
		}
	}
};

inline poly mul(poly a, poly b) {
	poly c; int len = a.len + b.len;
	int tmp = 1, T = 0;
	while (tmp <= len) tmp <<= 1, ++T;
	for (int i = 1; i <= tmp; ++i)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << T - 1);
	c.x.resize(N); c.len = tmp;
	a.len = tmp; b.len = tmp;
	const int inv = power(tmp, mod - 2);
	a.NTT(1); b.NTT(1);
	for (int i = 0; i <= a.len; ++i) c.x[i] = (1ll * a.x[i] * b.x[i]) % mod;
	c.NTT(-1);
	for (int i = 0; i <= a.len; ++i) c.x[i] = (1ll * c.x[i] * inv) % mod;
	c.len = len;
	return c;
}

poly a, b, c;

int fac[N], ifac[N], n, k;

inline void init() {
	a.x.resize(N); b.x.resize(N); a.len = b.len = k;
	fac[0] = 1; for (int i = 1; i <= k; ++i) fac[i] = (1ll * fac[i - 1] * i) % mod;
	ifac[k] = power(fac[k], mod - 2);
	for (int i = k - 1; ~i; --i) ifac[i] = (1ll * (i + 1) * ifac[i + 1]) % mod;
	for (int i = 0; i <= k; ++i) {
		a.x[i] = (1ll * power(i, k) * ifac[i]) % mod;
		b.x[i] = ifac[i]; if (i & 1) b.x[i] = -b.x[i] + mod;
		if (b.x[i] >= mod) b.x[i] -= mod;
	} c = mul(a, b);
}

int main() {
	scanf("%d%d", &n, &k); init();
	int res = 0;
	for (int i = 0, C = 1; i <= k && i <= n; C = 1ll * C * (n - i) % mod, ++i) {
		int del = 1ll * C * c.x[i] % mod;
		del = 1ll * del * power(n + 1, n - i) % mod;
		res += del; if (res >= mod) res -= mod;
	} printf("%d\n", res);
	return 0;
}
posted @ 2023-03-16 16:17  Smallbasic  阅读(26)  评论(0)    收藏  举报