P6216 回文匹配 解题报告

Description

对于一对字符串 $(s_1,s_2)$,若 $s_1$ 的长度为奇数的子串 $(l,r)$ 满足 $(l,r)$ 是回文的,那么 $s_1$ 的“分数”会增加 $s_2$ 在 $(l,r)$ 中出现的次数。

现在给出一对 $(s_1,s_2)$,请计算出 $s_1$ 的“分数”。

Solution

先考虑最暴力的做法。枚举子区间,暴力匹配,$O(n^4)$。将匹配改成 KMP,可以优化到 $O(n^3)$。优化区间的枚举方法,边扫边匹配,可以优化到 $O(n^2)$。

我们可以一开始对 $s1, s2$ 跑一遍 KMP 匹配,将匹配成功的部分的左端点 $a_i$ 设为 $1$, 其它点设为 $0$, 问题转换为区间和。

看到回文子区间,应该能想到 manachar 算法。考虑答案的计算方式,设以 $i$ 为对称点的最长回文区间为 $[l,r]$, 我们实际上就是在求 $s[l:r] + s[l+1:r-1] + s[l+2][r-2] ... + s[i:i]$ 的答案。但是会有不合法情况,对半取就是一个经典的前缀和问题。由于会有不合法的情况,我们可以一开始将 $r$ 更新为 $r - m + 1$, 然后将 $(l+r)/2$ 作为计算时新的对称点(只是计算,并不是真的对称点),就不会有不合法情况了。

CODE

#include<bits/stdc++.h>
using namespace std;
#define uint unsigned int
inline int read() {
    int x = 0, f = 1; char c = getchar();
    while (c < '0' || c > '9') {if (c == '-') f = -f; c = getchar();}
    while (c >= '0' && c <= '9') {x = (x << 3) + (x << 1) + (c ^ 48); c = getchar();}
    return x * f;
}
const int N = 3e6 + 10;
int n, m, nxt[N], a[N];
char s[N], t[N];
inline void kmp() {
    nxt[1] = 0;
    for (int i = 2, j = 0; i <= m; ++ i) {
        while (j && t[i] != t[j+1]) j = nxt[j];
        if (t[i] == t[j+1]) j ++;
        nxt[i] = j;
    }
    for (int i = 1, j = 0; i <= n; ++ i) {
        while (j && s[i] != t[j+1]) j = nxt[j];
        if (s[i] == t[j+1]) ++ j;
        if (j == m) a[i-m+1] = 1; 
    }
}
int p[N];
inline void manachar() {
    int mx = 0, id = 0;
    for (int i = 1; i <= n; ++ i) {
        if (i < mx) p[i] = min(mx - i, p[2*id-i]);
        else p[i] = 1;
        while (s[i+p[i]] == s[i-p[i]]) ++ p[i];
        if (i + p[i] - 1 > mx) mx = i + p[i] - 1, id = i; 
    }
}
uint s1[N], s2[N];
inline void solve() {
    uint ans = 0, ans2 = 0;
    for (int i = 1; i <= n; ++ i) s1[i] = s1[i-1] + a[i], s2[i] = s2[i-1] + a[i] * i;
    for (int i = 1; i <= n; ++ i) {
        int l = i - p[i] + 1, r = i + p[i] - m;
        if (l > r) continue;
        int mid = l + r >> 1;
        ans += s2[mid] - s2[l-1] - (s1[mid] - s1[l-1]) * (l - 1);
        if (r != mid) ans += (s1[r] - s1[mid]) * (r + 1) - (s2[r] - s2[mid]); 
    }
    printf("%u\n", ans);
}
int main () {
    n = read(); m = read();
    scanf ("%s", s + 1); scanf("%s", t + 1);
    s[0] = '$'; s[n+1] = t[n+1] = '#';
    kmp(); manachar(); solve();
    return 0;
}
View Code

 

posted @ 2023-01-09 21:29  LikC1606  阅读(51)  评论(0)    收藏  举报