day10T1改错记
题面
有两个串\(A\)和\(B\),长度分别为\(n\)和\(m\),只含'Z','P','S','B'四个大写字母,定义\(B\)在第\(p\)位(\(0 \le p \le n - m\))匹配\(A\)为对\(B\)的每个位置\(B_i\),在\(A_{\max (0, p + i - k)}\)到\(A_{min(n - 1, p + i + k)}\)中都存在与\(B_i\)相同的字符,现给出\(A\),\(B\)和\(k\),问\(B\)在多少个位置和\(A\)匹配
题外话
怕是几天来改得最快的一道题了……
解析
只有四个字母,就考虑分开处理
枚举每个字符
设\(a_i\)表示\(A\)串的第\(i\)个位置能不能匹配这个字符,\(b_i\)表示\(B\)串的第\(i\)个位置是否是这个字符(\(a_i, b_i\)都是\(0/1\))
那么\(p\)位置能匹配的位置数量就是\(\sum_{i = 0}^{m - 1} b_i \cdot a_{p + i} = \sum_{i = 0}^{n - p - 1} b_i \cdot a_{p + i}\),上界可以扩展因为\(m - 1\)之后的位置\(b\)为\(0\),不会影响结果
设这个式子为\(f(p)\),把\(a_i\)翻转一下,即\(g(i) = a_{n - 1 - i}\),那么:
\[\begin{align}
f(p) & = \sum_{i = 0}^{n - p - 1} b_i\cdot a_{p + i} \\
& = \sum_{i = 0}^{n - p - 1} b_i \cdot g(n - p - 1 - i)
\end{align}
\]
然后把\(f\)翻转,得到:
\[f(n - 1 - p) = \sum_{i = 0}^{n - p - 1} b_i \cdot g(n - p - 1 - i)
\]
这不就是个卷积吗,\(FFT\)或者\(NTT\)都可以做
对每个字符求出各个位置能匹配的数量,显然只有当一个位置各个字符匹配数量的和为\(m\)的时候\(B\)才能在这个位置匹配\(A\),最后判一下统计答案就行了
代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
#include <algorithm>
#define MAXN 200005
typedef long long LL;
const int mod = 998244353;
const char set[] = {'Z', 'P', 'S', 'B'};
int qpower(int, int, int);
void pre_rev(int);
void NTT(int *, int, int);
int K, n, m, ans, match[MAXN], rev[MAXN << 2], f[MAXN << 2], g[MAXN << 2];
char A[MAXN], B[MAXN];
inline void inc(int &x, int y) { x += y; if (x >= mod) x -= mod; }
inline void dec(int &x, int y) { x -= y; if (x < 0) x += mod; }
inline int add(int x, int y) { x += y; return x >= mod ? x - mod : x; }
inline int sub(int x, int y) { x -= y; return x < 0 ? x + mod : x; }
int main() {
freopen("base.in", "r", stdin);
freopen("base.out", "w", stdout);
scanf("%d%s%s", &K, A, B);
n = strlen(A), m = strlen(B);
if (m > n) { puts("0"); return 0; }
for (int i = 0; i < 4; ++i) {
memset(f, 0, sizeof f); memset(g, 0, sizeof g);
for (int j = 0, last = -1; j < n; ++j) {
if (A[j] == set[i]) last = j;
if ((~last) && last >= j - K) g[j] = 1;
}
for (int j = n - 1, last = -1; j >= 0; --j) {
if (A[j] == set[i]) last = j;
if((~last) && last <= j + K) g[j] = 1;
}
for (int j = 0; j < m; ++j)
if (B[j] == set[i]) f[j] = 1;
std::reverse(g, g + n);
int sz;
for (sz = 0; (1 << sz) < (n << 1); ++sz);
pre_rev(sz);
NTT(f, sz, 1); NTT(g, sz, 1);
for (int i = 0; i < (1 << sz); ++i) f[i] = (LL)f[i] * g[i] % mod;
NTT(f, sz, -1);
for (int i = 0; i <= n - m; ++i) inc(match[i], f[n - 1 - i]);
//debug
//for (int i = 0; i <= n - m; ++i) printf("%d ", f[n - 1 - i]);
//puts("");
}
for (int i = 0; i <= n - m; ++i) if (match[i] == m) ++ans;
printf("%d\n", ans);
return 0;
}
int qpower(int x, int y, int p) {
int res = 1;
while (y) {
if (y & 1) res = (LL)res * x % p;
x = (LL)x * x % p; y >>= 1;
}
return res;
}
void NTT(int *arr, int sz, int tp) {
for (int i = 0; i < (1 << sz); ++i)
if (rev[i] > i) std::swap(arr[i], arr[rev[i]]);
for (int len = 2, half = 1; len <= (1 << sz); len <<= 1, half <<= 1) {
int wn = qpower(3, (mod - 1) / len, mod);
if (tp == -1) wn = qpower(wn, mod - 2, mod);
for (int i = 0; i < (1 << sz); i += len) {
int w = 1;
for (int j = 0; j < half; ++j, w = (LL)w * wn % mod) {
int x = arr[i + j], y = (LL)arr[i + j + half] * w % mod;
inc(arr[i + j], y); dec(arr[i + j + half] = x, y);
}
}
}
if (tp == -1) {
int inv = qpower(1 << sz, mod - 2, mod);
for (int i = 0; i < (1 << sz); ++i) arr[i] = (LL)arr[i] * inv % mod;
}
}
inline void pre_rev(int sz) {
for (int i = 0; i < (1 << sz); ++i) rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << sz - 1));
}
//Rhein_E

浙公网安备 33010602011771号