题解:uoj214 合唱队形

题意:给出 \(n\) 个元素可解锁的字符集,同时给出一个长为 \(m\) 的字符串 \(t\),每次随机解锁一个元素的一个字符,可以多次解锁同一个,问期望多少次后,存在一种方式让这 \(n\) 个元素选择一个字符按顺序构成一个字符串,其中有子串为 \(t\)\(n,m\le 30\)

做法:

首先看到这个题是要 \(t\) 第一次出现的时间,这个不会做,但是可以直接 min-max 容斥,转化成我枚举一个串的开头集合 \(S\),然后要求这个集合里的元素都要能构成 \(t\)。就需要要求某些元素一定需要解锁一些字符集。我们现在就需要一个集合的元素全部出现的时间的期望。

这个好像直接也不是很好做,我们考虑再容斥一步,容斥哪些元素没有出现。同时我们把出现的期望,改成在 \([0,\cdots]\) 的时间内有元素没有出现的概率。那么设总和为 \(sum\)\(v(S)\) 是我总共需要多少个元素,那么柿子为:

\[\sum_{t=0}\sum_{i=1}^{v(S)}(-1)^{i+1}(\frac{sum-i}{sum})^t\binom{v(S)}{i} \]

然后直接交换求和式再等比数列求和,可以得到为:

\[\sum_{i=1}^{v(S)}(-1)^{i+1}\frac{sum}i\binom{v(S)}{i} \]

直接枚举 \(S\) 然后对着这个柿子计算即可做到 \(2^{n-m}nm\)

接下来我们将给出另一个 \(O(2^mn^2m)\) 的做法,这两个做法拼起来可以得到满分。

我们考虑我们对于这个柿子其实只需要知道 \(v(S)\),注意到这个时候串长比较小,考虑对于第 \(i\) 个元素,我只需要知道前面 \(m\) 个位置是否是开头就可以计算 \(i\) 需要多少个了,所以可以考虑 \(dp_{i,s,x}\) 代表目前是第 \(i\) 个元素,\(s\)\(i\) 前面 \(m\) 位的开头集合,\(x\) 是只考虑前 \(i\) 个元素所需要的目前的 \(v(S)\)。直接转移即可。最后再按这个贡献计算即可。

记得特判答案为 \(0\) 需要输出 \(-1\)

代码:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 35, mod = 998244353;
int n, m, v[maxn], use[maxn], c[maxn], cnt[(1 << 20) + 5], inv[maxn * maxn], sum;
int C[maxn * maxn][maxn * maxn];
string s[maxn], t;
int get_v(int x) {
	int cnt = 0;
	while(x)
		cnt += x & 1, x /= 2;
	return cnt;
}
int T;
int dp[2][(1 << 14)][maxn * 15];
signed main() {
	cin >> T;
	while(T--) {
		cin >> n >> m;
		sum = 0;
		for (int i = 1; i <= n; i++) {
			cin >> s[i];
			c[i] = 0;
			for (int j = 0; j < s[i].size(); j++)
				c[i] |= (1 << s[i][j] - 'a');
			sum += s[i].size();
		}
		cin >> t; t = ' ' + t;
		if(n - m <= 16) {
			C[0][0] = 1;
			inv[0] = inv[1] = 1;
			for (int i = 2; i <= 26 * n; i++)
				inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
			for (int i = 1; i <= 26 * n; i++) {
				C[i][0] = 1;
				for (int j = 1; j <= i; j++)
					C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % mod;
			}
			for (int i = 1; i < (1 << n - m + 1); i++)
				cnt[i] = cnt[i >> 1] + (i & 1);
			int ans = 0;
			for (int s = 0; s < (1 << n - m + 1); s++) {
				for (int i = 1; i <= n; i++)
					use[i] = 0;
				for (int i = 1; i <= n - m + 1; i++) {
					if((s >> i - 1) & 1) {
						for (int j = 1; j <= m; j++)
							use[i + j - 1] |= (1 << t[j] - 'a');
					}
				}
				bool f = 1;
				for (int i = 1; i <= n; i++)
					if((c[i] & use[i]) != use[i])
						f = 0;
				if(!f)
					continue;
				int ct = 0, res = 0;
				for (int i = 1; i <= n; i++)
					ct += get_v(c[i] & use[i]);
				for (int i = 1; i <= ct; i++)
					res = (res + 1ll * (i % 2 ? 1 : mod - 1) * C[ct][i] % mod * sum % mod * inv[i] % mod) % mod;
				ans = (ans + 1ll * (cnt[s] % 2 ? 1 : mod - 1) * res % mod) % mod;
			}
			cout << (ans ? ans : -1) << endl;
		}
		else {
			C[0][0] = 1;
			inv[0] = inv[1] = 1;
			for (int i = 2; i <= 26 * n; i++)
				inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
			for (int i = 1; i <= 26 * n; i++) {
				C[i][0] = 1;
				for (int j = 1; j <= i; j++)
					C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % mod;
			}
			for (int i = 1; i < (1 << m); i++)
				cnt[i] = cnt[i >> 1] + (i & 1);
			memset(dp, 0, sizeof(dp));
			dp[1][1][0] = mod - 1, dp[1][0][0] = 1;
			int cur = 1;
			int all = (1 << m) - 1;
			for (int i = 1; i <= n; i++) {
			//	cout << i << endl;
				memset(dp[cur ^ 1], 0, sizeof(dp[cur ^ 1]));
				for (int s = 0; s < (1 << m); s++) {
					int use = 0;
					for (int j = 1; j <= m; j++)
						if((s >> j - 1) & 1)
							use |= (1 << t[j] - 'a');
					if((use & c[i]) != use)
						continue;
					int v = get_v(c[i] & use);
					for (int j = m * i; j >= v; j--) 
						dp[cur][s][j] = dp[cur][s][j - v];
					for (int j = 0; j < v; j++)
						dp[cur][s][j] = 0;
					for (int j = v; j <= m * i; j++) {
						if(!dp[cur][s][j])
							continue;
						dp[cur ^ 1][(s << 1) & all][j] = (dp[cur ^ 1][(s << 1) & all][j] + dp[cur][s][j]) % mod;
						if(i + 1 <= n - m + 1)
							dp[cur ^ 1][(s << 1 | 1) & all][j] = (dp[cur ^ 1][(s << 1 | 1) & all][j] - dp[cur][s][j] + mod) % mod;
					}
				}
				cur ^= 1;
			}
			cur ^= 1;
			int ans = 0;
			for (int s = 0; s < (1 << m); s += (1 << m - 1))
				for (int j = 0; j <= m * n; j++) {
					if(!dp[cur][s][j])
						continue;
					int res = 0;
					for (int k = 1; k <= j; k++)
						res = (res + 1ll * (k % 2 == 0 ? 1 : mod - 1) * C[j][k] % mod * sum % mod * inv[k] % mod) % mod;
					ans = (ans + 1ll * res * dp[cur][s][j] % mod) % mod;
				}
			cout << (ans ? ans : -1) << endl;
		}
	}
	return 0;
}
/*
2
3 2
acb
cb
ab
cb
3 2
acb
cb
ab
cb
*/
posted @ 2025-10-24 19:34  LUlululu1616  阅读(14)  评论(1)    收藏  举报