KMP 算法

算法介绍

本题解介绍 KMP 算法。

该算法求解一个字符串在另一个字符串中出现的位置。

引入

先介绍朴素的暴力做法。

暴力做法就是枚举每一个起点,依次进行配对并记录。

设字符串为 \(s1,s2\),其长度分别为 \(n,m\)。总时间复杂度是 \(O(nm)\) 的。

显然这个时间复杂度不能接受,考虑优化。

KMP 算法

先思考时间复杂度高在哪里了。仔细思考可以发现如下一种情况:

\(s1\)aaaaabaaaaab
\(s2\)aaaaaa

可以发现反复比较了,且进行的是没有意义的比较。

所以抓住 \(s2\) 的性质,才是优化的关键。

由上述例子不难发现,\(s2\) 中隐含了大量信息。如果有重复比较的部分,我们是不希望进行的。

首先有一个定义:定义 \(next\) 数组。设字符串 \(s2\)\(i\) 个字符组成的前缀 \(s'\)。则 \(next_i\)\(s'\) 的一个非 \(s'\) 本身的子串 \(t\),满足 \(t\) 既是 \(s'\) 的前缀,又是 \(s'\) 的后缀,这样的字符串 \(t\) 的最大长度。

接下来考虑如何维护,如图所示,橙色代表 \(s2\),蓝色和绿色代表当前位置。

如果开头有一个红色字符串加蓝色字符,当前位置前有红色字符串(如图),则 \(next_i\) 为红色串的长度加一。

否则,如图所示,寻找更小的串,看看是否能转移,以此类推。

接下来就可以用用双指针求解了。令 \(i\)\(s_1\) 中的指针,\(j\) 表示匹配到了哪里。

根上述过程类似,若不匹配,\(j=next_j\),寻找最大能匹配的串。注意匹配成功也要使 \(j=next_j\)

代码实现

求解 \(next\) 数组:

for (int i = 2, j = 0; i <= m; i++) {
      while (j > 0 && s2[i] != s2[j + 1])
          j = nxt[j];//缩小字符串,看看是否匹配
      if (s2[j + 1] == s2[i]) j++;
      nxt[i] = j;
  }

求出现的位置:

for (int i = 1, j = 0; i <= n; i++) {
    while (j > 0 && s1[i] != s2[j + 1])
        j = nxt[j];
    if (s2[j + 1] == s1[i]) j++;
    if (j == m) {
        cout << i - m + 1 << "\n";
        j = nxt[j];//注意,要重新从 next[j] 开始匹配
    }
}

完整代码:

#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 5;
char s1[N], s2[N];
int nxt[N], n, m;
int main () {
    ios :: sync_with_stdio (0);
    cin.tie (0), cout.tie (0);
    cin >> (s1 + 1) >> (s2 + 1);
    n = strlen (s1 + 1);
    m = strlen (s2 + 1);
    for (int i = 2, j = 0; i <= m; i++) {
        while (j > 0 && s2[i] != s2[j + 1])
            j = nxt[j];
        if (s2[j + 1] == s2[i]) j++;
        nxt[i] = j;
    }// 求 next 数组
    for (int i = 1, j = 0; i <= n; i++) {
        while (j > 0 && s1[i] != s2[j + 1])
            j = nxt[j];
        if (s2[j + 1] == s1[i]) j++;
        if (j == m) {
            cout << i - m + 1 << "\n";
            j = nxt[j];
        }
    }// 匹配字符串
//  for (int i = 1; i <= m; i++)
//      cout << nxt[i] << " ";
// 输出 next 数组,可省略
    return 0;
}

算法证明

过程证明

由于 \(next\) 数组具有传递性,所以过程是对的。

这个传递性其实是一个相对模糊的概念,所以读者需要自行体会。

时间复杂度

每一次 \(j = next_j\),都一定会使 \(j\) 减小。而 \(j\) 的上限为 \(m\),所以 while 循环总次数为 \(m\) 次。时间复杂度为 \(O(n+m)\)

posted @ 2025-11-21 18:22  暴力算法  阅读(3)  评论(0)    收藏  举报