P6216 回文匹配

回文匹配

题目描述

对于一对字符串 \((s_1,s_2)\),若 \(s_1\) 的长度为奇数的子串 \((l,r)\) 满足 \((l,r)\) 是回文的,那么 \(s_1\) 的“分数”会增加 \(s_2\)\((l,r)\) 中出现的次数。

现在给出一对 \((s_1,s_2)\),请计算出 \(s_1\) 的“分数”。

答案对 \(2 ^ {32}\) 取模。

输入格式

第一行两个整数,\(n,m\),表示 \(s_1\) 的长度和 \(s_2\) 的长度。

第二行两个字符串,\(s_1,s_2\)

输出格式

一行一个整数,表示 \(s_1\) 的分数。

样例 #1

样例输入 #1

10 2
ccbccbbcbb bc

样例输出 #1

4

样例 #2

样例输入 #2

20 2
cbcaacabcbacbbabacca ba

样例输出 #2

4

提示

【样例解释】

对于样例一:

子串 \((1,5)\)\(s_2\) 出现了一次,子串 \((2,4)\)\(s_2\) 出现了一次。

子串 \((7,9)\)\(s_2\) 出现了一次,子串 \((6,10)\)\(s_2\) 出现了一次。


【数据范围】

本题采用捆绑测试。

  • 对于 \(100\%\) 的数据:\(1 \le n,m \le 3 \times 10 ^ 6\),字符串中的字符都是小写字母。

  • 详细的数据范围:

    Subtask 编号 \(n,m \le\) 分值
    \(1\) \(100\) \(15\)
    \(2\) \(10 ^ 3\) \(15\)
    \(3\) \(5 \times 10 ^ 3\) \(20\)
    \(4\) \(4 \times 10 ^ 5\) \(30\)
    \(5\) \(3 \times 10 ^ 6\) \(20\)

分析:

  1. KMP + Manacher + 前缀和
    KMP标记模式串在匹配串中出现的下标,Manacher处理出每个位置的最大回文半径,前缀和统计答案(想到这已经完成了 10%

  2. 本题的精华在于如何在 O(n) 求出模式串在各回文区间内的贡献值

  • 考虑模式串长度为 1 时的答案统计:
    在回文区间[L, R]内,当模式串匹配成功一次( cnt[i]++ ),就会在当前位置有一次贡献,在以后的对称位置也会有一次贡献(这一步可以在计算到后面时进行统计),用 n 步前缀和计算出区间贡献值:\(\sum_{L}^{mid}cnt[i] * (i - L + 1) + \sum_{mid + 1}^{R}cnt[i] * (R - i + 1)\)
  • 拓展出去,当模式串长度为 m 时,使用KMP处理好的匹配成功的下标( cnt[i]++ ),不同的是 [L, R] 不是从 [i - r[i] + 1, i + r[i ] - 1】 计算贡献,应当考虑的区间是 [i - r[i], i + r[i] - m] ,此外当区间长度小于 m 时直接continue

实现:

int n, m;
string a, b;
int r[N], cnt[N];
int ne[N];
void kmp(string s, string p)
{
    for (int i = 2, j = 0; i <= m; i++)
    {
        while (j && p[i] != p[j + 1])
            j = ne[j];
        if (p[i] == p[j + 1])
            j++;
        ne[i] = j;
    }

    for (int i = 1, j = 0; i <= n; i++)
    {
        while (j && s[i] != p[j + 1])
            j = ne[j];
        if (s[i] == p[j + 1])
            j++;
        if (j == m)
        {
            cnt[i - m + 1]++;
            j = ne[j];
        }
    }
}
void manacher(string s)
{
    int p = 1, mx = 1;
    for (int i = 1; i <= n; i++)
    {
        r[i] = min(mx - i, r[(p << 1) - i]);
        while (s[i - r[i]] == s[i + r[i]])
            r[i]++;
        if (i + r[i] > mx)
        {
            mx = i + r[i];
            p = i;
        }
    }
}
void cal()
{
    for (int i = 1; i <= n; i++)
        cnt[i] += cnt[i - 1];
    for (int i = 1; i <= n; i++)
        cnt[i] += cnt[i - 1];

    int res = 0;
    for (int i = 1; i <= n; i++)
    {
        if (2 * r[i] - 1 < m)
            continue;
        int ll = i - r[i], rr = i + r[i] - m;
        int mid = ll + rr >> 1;
        res = (((res + cnt[rr]) % MOD + cnt[ll - 1]) % MOD - cnt[mid] + MOD) % MOD;
        if ((ll + rr) & 1)
            res -= cnt[mid];
        else
            res -= cnt[mid - 1];
    }
    cout << res % MOD << endl;
}
void solve()
{
    cin >> n >> m >> a >> b;
    a = "$" + a, b = "$" + b;
    kmp(a, b);
    manacher(a);
    cal();
}
posted @ 2023-04-06 14:41  347Foricher  阅读(50)  评论(0)    收藏  举报