题解:P5417 [CTSC2016] 萨菲克斯·阿瑞

posted on 2025-01-08 01:06:56 | under | source

神仙计数题。

肯定考虑判定 SA 是否合法,先考虑判定 \(p\) 是否是原串 \(S\) 的 SA,有必要条件:

  • \(rk_{p_i+1}<rk_{p_{i+1}+1}\)\(S_{p_i}\le S_{p_{i+1}}\)
  • 反之:\(S_{p_i}<S_{p_{i+1}}\)

我们称之为不等式链,易证也是充分的。于是对于一个 SA,可以构建唯一的不等式链 \(l\)\(p\) 对于 \(S\) 合法,当且仅当 \(l\) 对于 \(S\) 合法。

然后可以构造地判断是否存在 \(S\) 满足 \(l\),贪心即可:将 \(l\) 视作若干 \(<\) 号划分的 \(\le\) 号段,从小到大考虑字符并填入段,能填满即填满,溢出就扔掉即可。而且这玩意是方便 dp 的,分讨上述两种情况即可,之后再细讲。

为了方便讨论,扔掉 \(c\) 的限制,认为每种字符有无限种,但是 \(S\) 的字符优先填写小的(不然算重了)。

串到 SA 是单射,SA 到 \(l\) 是单射。计数的话肯定是对串统计比较容易,但构不成双射就比较烦人。大眼观察,你可能注意到某些情况下是有双射的:例如当 \(l\) 划分的段数恰为字符集大小时,则 \(l\) 唯一确定一个 \(S\)。现在这个 \(S\) 容易计数,多重排列即可,但是会算重。

我们不妨枚举 \(l\),考虑有多少 SA 映射到它,记为 \(f(l)\)。记 \(a_1\dots a_k\)\(l\) 中极大 \(\le\) 号段的大小,延续上述想法记 \(g(l)=\frac{n!}{\prod a_i!}\),即一个多重排列。考虑统计情况:

  • 对于可以映射到 \(l\)\(p\),必然存在且唯一存在一个映射到它的 \(S\) 满足该多重排列的形式。可以构造地说明,考虑填写不等式链,易发现方案唯一。
  • 对于映射到其它不等式链 \(l_2\ne l\),可以发现 \(l_2\) 相较于 \(l\) 只可能是 \(<\) 号变为 \(\le\) 号。还是考虑填写 \(l_2\) 以判定是否会被统计,那么将 \(<\) 改为 \(\le\) 并不会影响填写方案唯一,但是改 \(\le\)\(<\) 却会使得字符集不够用。
  • 称这样的 \(l_2\subseteq l\),易发现所有映射到 \(l_2\)\(q\),必然存在且唯一存在一个映射到 \(q\)\(S\) 被统计,还是构造,不再赘述。

于是实际上有 \(g(l)=\sum\limits_{l_2\subseteq l} f(l_2)\)。考虑容斥,则 \(f(l)=\sum\limits_{l_2\subseteq l} (-1)^{|l|-|l_2|}g(l_2)\)

回到原题,我们试着贪心构造出可能的不等式链,然后统计对应的 SA。原先的统计方法仍然适用。因为假如一个 \(\le\) 段是由多种字符构成的,可以将它们合并为同一种字符,然后回归上述证明即可。

可以 dp 解决,回到先前提到的 dp,无非是多了一种转移情况:容斥,允许不填满一段结束。过程中维护容斥系数、多重排列系数即可。转移是平凡的。复杂度 \(O(n^4)\),前缀和优化为 \(O(n^3)\)

代码

#include<bits/stdc++.h>
using namespace std;

#define int long long
#define ADD(a, b) a = (a + b) % mod
const int N = 5e2 + 5, mod = 1e9 + 7;
int n, m, c, f[N][N], g[N][N], h[N][N], jc[N], jcinv[N], ans;

inline int qstp(int a, int k) {int res = 1; for(; k; a = a * a % mod, k >>= 1) if(k & 1) res = res * a % mod; return res;}
signed main(){
	jc[0] = jcinv[0] = 1;
	for(int i = 1; i < N; ++i) jcinv[i] = qstp(jc[i] = jc[i - 1] * i % mod, mod - 2);
	cin >> n >> m;
	f[0][0] = jc[n];
	for(int i = 1; i <= m; ++i){
		scanf("%lld", &c); 
		if(!c) continue;
		swap(f, g), memset(f, 0, sizeof f);
		for(int j = 0; j <= n; ++j)
			for(int k = 0; k <= j; ++k){
				if(j + c <= n) ADD(f[j + c][k + c], g[j][k]); 
				if(k) h[j][k] = (h[j - 1][k - 1] + g[j - 1][k - 1]) % mod;
				if(k >= c + 1) h[j][k] = (h[j][k] - g[j - c - 1][k - c - 1] + mod) % mod; 
				ADD(f[j][0], h[j][k] * jcinv[k] % mod);
				ADD(f[j][k], h[j][k] * (mod - 1) % mod); 
			} 
		ADD(ans, f[n][0]);
	}
	cout << ans;
	return 0;
}
posted @ 2026-01-15 08:18  Zwi  阅读(1)  评论(0)    收藏  举报