题解:qoj3227 Substring Pairs
题意:给出 \(n,m,c\),要求找出有多少个字符串对 \((s,t)\) 满足:
-
\(|s| = n,|t|=m\)
-
\(t\) 在 \(s\) 的连续子串中出现过。
-
\(s,t\) 的字符集大小为 \(c\)。
\(n\le 200,m\le 50,c\le1000\)
做法:
首先既然是要出现过,那么很自然地想用整体的方案数 \(c^{n+m}\) 减去没出现过的,现在考虑怎么计算没出现过的。
很自然地再容斥,至少出现 \(0\) 次的减去至少出现 \(1\) 次的……关键是我们发现串会重叠,这样统计出现次数非常麻烦。那么我们思考,如果我们有一个 \(s,t\) 怎么去弄,正常来说找到一次我们就直接跳到 border 进行匹配,那么这里同样,我们假设在 \(i<j\) 这两个位置出现过且他们重叠了,那么我们就扫到 \(i\) 的时候,弹到 \(j\) 而不是弹到 \(i+m\) 的位置。
那我们不知道 \(t\) 该怎么办?其实是我们只关心 border 的情况而不关心具体的值,可以直接爆搜发现 \(m=50\) 时其实只有大概 \(2000\) 个。这里说一下怎么爆搜,每次直接加入一个新 border,因为 \(m\le 50\) 所以可以状压,每次加入一个新 bit,但是这样可能同时导致一些其他的 border 加入,这里直接用并查集维护一下判断一下就可以。
然后有了这个 \(t\) 的 border 状态我们就可以考虑 dp 去帮忙容斥了。我们直接枚举 \(t\),记 \(dp_i\) 代表目前位置到了 \(i\),并且我在末尾放了一个 \(t\) 的结果,转移有两种,一种是我直接叠在了 \(i\) 这个串上面,直接枚举我多出去了 \(x\),那么画图就可以发现我需要一个长为 \(i-x\) 长的 border,直接判即可。另一种是我不叠上,那我直接枚举下一个放的位置去转移就可以,注意中间没有限制的位需要乘上一个 \(c\) 的贡献,因为可以随意选。而对于被覆盖的位置,我们只需要确定 \(t\) 就可以唯一确定,所以不需要计算贡献,只需要 dp 完之后乘上 \(t\) 串的确定方式,确定方式可以直接用并查集算出来有多少个连通块,那就乘上多少个 \(c\) 就可以了。记得每多一个串 \(t\) 容斥系数乘上 \(-1\)。
做完了吗?其实没有,注意到比如 \(aa\) 这种串会在 \(ab\) 这种串中被减一次,这显然是不合法的,那我在搜完 border 之后,直接减去 border 位置是这个串超集的贡献,也就是算出来有多少个数刚好 border 位置是我要的这个串。
时间复杂度 \(O(2000nm)\)。
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 2e5, mod = 1e9 + 7;
int n, m, c, pw[maxn];
struct dsu {
int pre[55];
int fnd(int x) {
return (pre[x] == x ? x : pre[x] = fnd(pre[x]));
}
void unn(int x, int y) {
pre[fnd(x)] = fnd(y);
}
void init() {
for (int i = 1; i <= m; i++)
pre[i] = i;
}
} tree;
int get_to(int x) {
tree.init();
for (int i = 1; i < m; i++) {
if((x >> i - 1) & 1) {
for (int j = 1; j <= i; j++)
tree.unn(j, m - i + j);
}
}
for (int i = 1; i < m; i++) {
if(!((x >> i - 1) & 1)) {
x ^= (1ll << i - 1);
for (int j = 1; j <= i; j++) {
if(tree.fnd(j) != tree.fnd(m - i + j)) {
x ^= (1ll << i - 1);
break;
}
}
}
}
return x;
}
int get_val(int x) {
int res = 1; tree.init();
for (int i = 1; i < m; i++) {
if((x >> i - 1) & 1) {
for (int j = 1; j <= i; j++)
tree.unn(j, m - i + j);
}
}
for (int i = 1; i <= m; i++)
if(tree.pre[i] == i)
res = res * c % mod;
return res;
}
map<int, int> mp;
void dfs(int nw) {
if(mp[nw])
return ;
mp[nw] = get_val(nw);
for (int i = 1; i < m; i++) {
if((nw >> i - 1) & 1)
continue;
dfs(get_to(nw | (1ll << i - 1)));
}
}
int dp[205];
signed main() {
cin >> n >> m >> c;
pw[0] = 1;
for (int i = 1; i <= n + m; i++)
pw[i] = pw[i - 1] * c % mod;
dfs(0);
int ans = pw[n + m];
for (map<int, int>::iterator it1 = --mp.end(); ; it1--) {
for (map<int, int>::iterator it2 = --mp.end(); ; it2--) {
if((it1 -> first & it2 -> first) == it1 -> first && it1 -> first != it2 -> first)
mp[it1 -> first] = (mp[it1 -> first] - mp[it2 -> first] + mod) % mod;
if(it2 == mp.begin())
break;
}
if(it1 == mp.begin())
break;
}
for(map<int, int>::iterator it = mp.begin(); it != mp.end(); it++) {
memset(dp, 0, sizeof(dp));
dp[0] = 1;
int res = 0;
for (int i = 0; i <= n; i++) {
if(i >= m) {
for (int j = 1; j < m && i + j <= n; j++)
if((it -> first & (1ll << m - j - 1)))
dp[i + j] = (dp[i + j] - dp[i] + mod) % mod;
}
for (int j = m; i + j <= n; j++)
dp[i + j] = (dp[i + j] - dp[i] * pw[j - m] % mod + mod) % mod;
res = (res + dp[i] * pw[n - i] % mod) % mod;
}
// cout << dp[n] << endl;
ans = (ans - res * it -> second % mod + mod) % mod;
}
cout << ans << endl;
return 0;
}

浙公网安备 33010602011771号