浅谈 KMP

本篇文章同步发表在洛谷博客

问题引入

给定一个字符串 \(s\) 和一个字符串 \(t\),问 \(s\) 中有哪些子串为 \(t\)

\(1 \le |s|,|t| \le 1000\) 的时候,这就是一个特别简单的题目,我们可以暴力枚举 \(s\) 的每个长度为 \(|t|\) 的子串,然后取出来看看是否等于 \(t\)。这样的时间复杂度是 \(O(|s| \times |t|)\) 的,在这个数据范围下可谓是轻而易举。

或也可以使用 string 中自带的 find 函数,时间复杂度同样 \(O(|s| \times |t|)\)。注意,find 的时间复杂度最高会到平方级别,并不是什么线性或者甚至说是 \(O(1)\) 的哦。

但是这里讨论的都是小规模数据。当 \(1 \le |s|,|t| \le 10^5\) 的时候,这些方法就都用不了了,我们又该如何解决这个问题呢?

什么是 KMP?

刚才提到当 \(|s|\)\(|t|\) 的规模到了 \(10^5\) 的情况下该怎么办。这时候刚才提到的两种朴素方法就都会超时了。

在这个情况下,KMP 就是一个非常好的选择!

是的,你可能会说 Hash 也行,没错,但是 Hash 再怎么说也存在一定的冲突概率,容易被卡,不是很好。而 KMP 却是一个非常保险的算法,当然是不可能像 Hash 一样出现什么哈希冲突的情况啦。

KMP 的时间复杂度是 \(O(|s|+|t|)\) 的,是不是非常便捷呢?

如何实现 KMP?

我们首先来看看暴力匹配的过程,能不能进行一些实现细节上的优化。

我们是一直往后找,很多时候的匹配做的都是无用功,要是可以延用之前比较出的结果来加快匹配速度就好了。

其实很好办的呐!我们依然是匹配,但不一下子匹配一整个串儿了,咱先就一个一个字符匹配——匹配到出现错误了,匹配不上了,哎,它不对应了,这个时候咱该干什么呢?按照常规套路,是不是应该把序列往后移一格,然后继续从第一个字符开始匹配呀。但是这里呢咱换种思想,我们考虑当前匹配的这部分串儿的一个 border——这东西表示最长公共前后缀,比如说字符串 \(\text{abcab}\),最长公共前后缀就是 \(2\),也就是这个 \(\text{ab}\),瞧 ${\color{red}\text{ab}}\text{c}{\color{red}\text{ab}} $,它在前后都出现啦。这个 border 要干啥呢?求出这个 border 的长度,然后把整个字符串位移到 border 在后缀出现的那部分,作为新匹配串的前缀,然后再继续往下匹配就成了。

于是现在的问题就变成了如何求出这个最长公共前后缀,也就是这所谓的 border。

\(nxt_i\) 表示 \(t\) 的前 \(i\) 个字符构成的前缀串儿的 border 的长度,当然是不能算上自己本身这个串的,不然就无意义了。

咋求?首先肯定直接从 \(2\) 开始遍历,因为长度为 \(1\) 的字符串——哈哈,就是一个字符嘛——是没有 border 的,或者说它的长度是 \(0\),因为不能算上自己。从 \(2\) 开始遍历,首先看,如果你可以直接和上一个匹配,那咱就直接匹配,那么 \(nxt_i\) 的答案就是 \(nxt_{i-1} + 1\) 了。但要是匹配不上了,那 \(nxt_i\) 的答案又是什么呢?在这里就有一个很厉害的做法——咱去求 border 的 border,找不到就再去求 border!从 \(i-1\) 开始一个类似递归的形式,只要匹配不上,就跳转到 border 页面,然后再判,再跳,直到找到。当然也有直到最后都还找不到的情况,这个时候 \(nxt_i\) 就是 \(0\) 了。

求完这个东西就可以根据上面那个经过优化的匹配思路去匹配啦,这样就实现完了,但是当你有了求 border 这件武器之后,你又发现了一种全新的匹配方式!把 \(t\)\(s\) 拼接起来,中间插个无意义字符如 #,连成一个更长的大字符串,然后对这个大字符串求 border,得到一个 \(nxt\) 序列。然后让 \(i\)\(|t|+2\) 开始找(我这里全部从 \(1\) 开始编号),直到 \(|t|+|s|+1\),如果哪里 \(nxt_i = |t|\),是不是说明这里就匹配到了呢?于是这样就比像刚才那样再去费尽心思匹配要轻松多了。

这就是 KMP 啦,是不是很厉害呢?

KMP 模版代码

放个代码供参考。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 2e6+5;
string s,t,p;
int nxt[N],n,m,k;
int read(){
    int su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
int main(){
    cin>>s>>t;p=t+'#'+s;
    n=s.size(),m=t.size(),k=p.size();
    s=" "+s,t=" "+t,p=" "+p;
    for(int i=2;i<=k;i++){
        int j=i-1;
        while(j>0){
            if(p[nxt[j]+1]==p[i])
                {nxt[i]=nxt[j]+1;break;}
            j=nxt[j];
        }
    }for(int i=m+2;i<=k;i++)
        if(nxt[i]==m)cout<<i-2*m<<"\n";
    for(int i=1;i<=m;i++)
        cout<<nxt[i]<<" ";cout<<"\n";
    return 0;
}

例题选讲

这边选择几道比较经典的例题进行讲解。

P4391 Radio Transmission 无线传输

结论题,答案就是 \(n-nxt_n\),原理不用过多解释,因为就是一个很直觉性的,随便画个图就能知道。和板题代码几乎一样。

P9606 ABB

这个也是比较板子的题目,只需要取原字符串翻转后的结果,让它和原字符串拼接在一起,中间加个无意义字符,做 border,最后让 \(n\) 减掉那个 \(nxt_L\) 即可(\(L\) 表示拼接后的字符串的长度)。

CF1137B Camp Schedule

跟 KMP 关系不大,重点是要能熟练运用 border 的求解方式。

由于这个题目是给定你一个字符串 \(s\) 然后要求你重排它,使得里面出现的 \(t\) 的次数尽可能多。而且它保证了一个特别重要的性质,那就是不论是 \(s\) 还是 \(t\) 都只由 01 组成。

首先可以考虑把 \(s\) 拆一下,拆成 \(s0\)0\(s1\)1,到时候直接重组便可。当然 \(t\) 也是要拆的,拆成 \(t0\)0\(t1\)1

这个时候,我们要先对 \(t\) 求一趟 border,得出这个对应的 \(nxt\) 数组。拿到这个 \(nxt\) 数组之后,我们取出 \(nxt_{|t|}\) 的值,并让 \(|t|\) 减去它,这就是在 \(t\) 后为再产生一个 \(t\) 而所需要花费的总长度。当然了,要从 \(t\) 中真的提取出这一段字符串内容,同样也对其进行拆分,拆成 \(w0\)0\(w1\)1,方便后面的判断。

想想看,现在我们什么都有了,究竟要怎么构造才是最优的?显而易见,当 \(s\)01 的个数都是够用的情况下,我们首先拼出一个完整的 \(t\),接着不断按照 border 求解之后得到的情况进行拼接,尽量凑出尽可能多的 \(t\)。直到某个时候 0 或者 1 有哪个不够了,那么也就没法子再多拼出一个 \(t\) 了,这个时候把剩余的 01 随意拼接到串末即可。

这样是不是就结束了?是不是很简单呢?

放个代码。后面的拼接部分其实也可以除法求出最多能拼多少个,不过我这里用的是循环暴力枚举的一个方式,也是一样的啦。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 1e6+5;
int n,m,s0,s1,t0,t1,w0,w1,nxt[N];
string s,t,sub,ans;
int read(){
    int su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
int main(){
    cin>>s>>t;
    n=s.size(),m=t.size();
    s=" "+s,t=" "+t;
    for(int i=1;i<=n;i++)
        if(s[i]=='1')s1++;else s0++;
    for(int i=1;i<=m;i++)
        if(t[i]=='1')t1++;else t0++;
    for(int i=2;i<=m;i++){
        int h=i-1;
        while(h>0){
            if(t[nxt[h]+1]==t[i])
                {nxt[i]=nxt[h]+1;break;}
            h=nxt[h];
        }
    }for(int i=1;i<=nxt[m];i++)
        if(t[i]=='1')w1++;else w0++;
    if(m>n){for(int i=1;i<=n;i++)cout<<s[i];return 0;}
    if(s0<t0||s1<t1){for(int i=1;i<=n;i++)cout<<s[i];return 0;}
    ans=t;s0-=t0,s1-=t1;
    for(int i=nxt[m]+1;i<=m;i++)sub+=t[i];
    for(int i=m+1;i<=n;i++){
        if(s0<t0-w0||s1<t1-w1)break;
        s0-=(t0-w0),s1-=(t1-w1),ans+=sub;
    }while(s0--)ans+="0";while(s1--)ans+="1";
    for(int i=1;i<=n;i++)cout<<ans[i];cout<<"\n";
    return 0;
}

CF1200E Compress Words

很明显这个东西也是要进行拼凑,但是要把重复的地方去掉。

就是要找到两个字符串重复的地方嘛,靠前的字符串的后缀,以及靠后的字符串的前缀。这不也可以用 border 实现吗?翻转一下,拼接,不就可以求了吗?

呃,但是这个东西好像直接干会超时……是的没错,因为这样子的时间复杂度是 \(O(n \sum |s|)\) 的!大概是 \(10^{11}\) 的级别,绝对接受不了。实在是因为这个 \(ans\) 太长了啊,让它一直不停地去和各个 \(s_i\) 算 border,一次又一次,不超时才怪呢!

怎么办,怎么办呢?注意到不论这个 \(ans\) 有多长,最极端的情况也莫过于这个 \(s_i\) 完全融合进原来的 \(ans\),也就是说这个 border,换句话说求出来的这个 \(nxt_k\)(其中 \(k\) 表示 \(ans\)\(s_i\) 的拼接字符串的总长度)为 \(|s_i|\) 嘛!这是最大的情况了!既然 border 最大也就这样,那我们的 \(ans\) 还给这么多干嘛?去吃闲饭的吗?多浪费时间呐!于是咱就干脆只截 \(ans\) 的后 \(|s_i|\) 位去给做 border 匹配,这样的话这速度就快多了,只剩下个 \(O(\sum 2|s|)\) 了,一点点常数有什么问题嘛!

于是就结束啦,代码特别简单。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 1e5+5;
const int M = 1e6+5;
int n,nxt[M];
string s[N],ans;
int read(){
    int su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
string border(string s,string t){
    string p=" "+t+"#"+s;
    int k=p.size()-1;
    for(int i=1;i<=k;i++)nxt[i]=0;
    for(int i=2;i<=k;i++){
        int j=i-1;
        while(j>0){
            if(p[nxt[j]+1]==p[i])
                {nxt[i]=nxt[j]+1;break;}
            j=nxt[j];
        }
    }string tmp="";
    for(int i=nxt[k];i<t.size();i++)
        tmp+=t[i];return tmp;
}
int main(){
    n=read();
    for(int i=1;i<=n;i++)cin>>s[i];
    ans=s[1];
    for(int i=2;i<=n;i++){
        string tmp="";
        int x=ans.size(),y=s[i].size();
        for(int j=max(0,x-y);j<ans.size();j++)
            tmp+=ans[j];
        ans+=border(tmp,s[i]);
    }
    cout<<ans<<"\n";
    return 0;
}

CF631D Messenger

发现这个东西和往常的 KMP 匹配不一样呐,这东西它多了个奇奇怪怪的封装——也难怪,按照这个规模算一下,拉长之后岂不有 \(2 \times 10^{10}\) 长?谁存的下,谁又做得来呢?不说别的,都读不进来呢!

于是只能在这个封装的基础上进行 KMP 的操作。当然了,我们要先把它弄成最简封装,换句话说,封装中不能存在相邻两个字符相同,否则就要把它们合并起来!

我们发现,封装之后的 KMP 和之前不一样了,之前是全都要相同,现在只要中间一坨全部相同,左右端点只需要字符匹配,并且个数足够即可。那么 \(m=1\)\(m=2\)(所有提到的 \(n\)\(m\) 均是最简封装情况下的)的需要你去特判一下,因为这俩玩意儿不存在中间一坨,玩不了 KMP。

搞完这两种特殊情况就来真的了,依然 KMP,依然求 border,不过这个时候那个 \(t\) 别给全塞进去了,头和尾塞不得,因为这是最后要判断的。搞完之后看谁的 \(nxt\) 值是 \(m-2\),是的话再判断下头和尾行不行,如果行就多一种情况啦,就可以更新答案啦!最后输出就行了。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 4e5+5;
struct node{LL x;char c;}a[N],b[N],p[N];
LL n,m,k,cn,cm,Ans,nxt[N];
LL read(){
    LL su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
bool operator == (const node &A , const node &B){
    return (A.x==B.x&&A.c==B.c);
}
bool operator <= (const node &A , const node &B){
    return (A.x<=B.x&&A.c==B.c);
}
int main(){
    n=read(),m=read();
    for(int i=1;i<=n;i++){
        LL p=read();char h;cin>>h;
        if(h==a[cn].c)a[cn].x+=p;else a[++cn]={p,h};
    }n=cn;
    for(int i=1;i<=m;i++){
        LL p=read();char h;cin>>h;
        if(h==b[cm].c)b[cm].x+=p;else b[++cm]={p,h};
    }m=cm;
    if(m==1){
        for(int i=1;i<=n;i++)
            if(b[1]<=a[i])Ans+=a[i].x-b[1].x+1;
    }else if(m==2){
        for(int i=1;i<n;i++)
            if(b[1]<=a[i]&&b[2]<=a[i+1])Ans++;
    }else{
        for(int i=2;i<m;i++)p[++k]=b[i];
        for(int i=0;i<=n;i++)p[++k]=a[i];
        for(int i=2;i<=k;i++){
            int j=i-1;
            while(j>0){
                if(p[nxt[j]+1]==p[i])
                    {nxt[i]=nxt[j]+1;break;}
                j=nxt[j];
            }
        }for(int i=m;i<k;i++)
            if(nxt[i]==m-2&&b[1]<=p[i-m+2]&&b[m]<=p[i+1])Ans++;
    }cout<<Ans<<"\n";
    return 0;
}

简单总结

KMP,它通常用来处理字符串匹配问题,可以方便快速地查找到一个字符串在另一个字符串中的出现情况。和它一起出现的是 border,它可以快速求出字符串的任意前缀的最长公共前后缀,延伸运用多用于求解两个字符串的最长公共前后缀、回文串匹配情况等,运用多种多样,是非常好用的线性算法。

码这么多字也不容易,还麻烦你留个赞支持一下,真是太感谢啦!

posted @ 2025-11-10 21:59  嘎嘎喵  阅读(33)  评论(9)    收藏  举报