浅谈 KMP
本篇文章同步发表在洛谷博客。
问题引入
给定一个字符串 \(s\) 和一个字符串 \(t\),问 \(s\) 中有哪些子串为 \(t\)。
当 \(1 \le |s|,|t| \le 1000\) 的时候,这就是一个特别简单的题目,我们可以暴力枚举 \(s\) 的每个长度为 \(|t|\) 的子串,然后取出来看看是否等于 \(t\)。这样的时间复杂度是 \(O(|s| \times |t|)\) 的,在这个数据范围下可谓是轻而易举。
或也可以使用 string 中自带的 find 函数,时间复杂度同样 \(O(|s| \times |t|)\)。注意,find 的时间复杂度最高会到平方级别,并不是什么线性或者甚至说是 \(O(1)\) 的哦。
但是这里讨论的都是小规模数据。当 \(1 \le |s|,|t| \le 10^5\) 的时候,这些方法就都用不了了,我们又该如何解决这个问题呢?
什么是 KMP?
刚才提到当 \(|s|\) 和 \(|t|\) 的规模到了 \(10^5\) 的情况下该怎么办。这时候刚才提到的两种朴素方法就都会超时了。
在这个情况下,KMP 就是一个非常好的选择!
是的,你可能会说 Hash 也行,没错,但是 Hash 再怎么说也存在一定的冲突概率,容易被卡,不是很好。而 KMP 却是一个非常保险的算法,当然是不可能像 Hash 一样出现什么哈希冲突的情况啦。
KMP 的时间复杂度是 \(O(|s|+|t|)\) 的,是不是非常便捷呢?
如何实现 KMP?
我们首先来看看暴力匹配的过程,能不能进行一些实现细节上的优化。
我们是一直往后找,很多时候的匹配做的都是无用功,要是可以延用之前比较出的结果来加快匹配速度就好了。
其实很好办的呐!我们依然是匹配,但不一下子匹配一整个串儿了,咱先就一个一个字符匹配——匹配到出现错误了,匹配不上了,哎,它不对应了,这个时候咱该干什么呢?按照常规套路,是不是应该把序列往后移一格,然后继续从第一个字符开始匹配呀。但是这里呢咱换种思想,我们考虑当前匹配的这部分串儿的一个 border——这东西表示最长公共前后缀,比如说字符串 \(\text{abcab}\),最长公共前后缀就是 \(2\),也就是这个 \(\text{ab}\),瞧 ${\color{red}\text{ab}}\text{c}{\color{red}\text{ab}} $,它在前后都出现啦。这个 border 要干啥呢?求出这个 border 的长度,然后把整个字符串位移到 border 在后缀出现的那部分,作为新匹配串的前缀,然后再继续往下匹配就成了。
于是现在的问题就变成了如何求出这个最长公共前后缀,也就是这所谓的 border。
设 \(nxt_i\) 表示 \(t\) 的前 \(i\) 个字符构成的前缀串儿的 border 的长度,当然是不能算上自己本身这个串的,不然就无意义了。
咋求?首先肯定直接从 \(2\) 开始遍历,因为长度为 \(1\) 的字符串——哈哈,就是一个字符嘛——是没有 border 的,或者说它的长度是 \(0\),因为不能算上自己。从 \(2\) 开始遍历,首先看,如果你可以直接和上一个匹配,那咱就直接匹配,那么 \(nxt_i\) 的答案就是 \(nxt_{i-1} + 1\) 了。但要是匹配不上了,那 \(nxt_i\) 的答案又是什么呢?在这里就有一个很厉害的做法——咱去求 border 的 border,找不到就再去求 border!从 \(i-1\) 开始一个类似递归的形式,只要匹配不上,就跳转到 border 页面,然后再判,再跳,直到找到。当然也有直到最后都还找不到的情况,这个时候 \(nxt_i\) 就是 \(0\) 了。
求完这个东西就可以根据上面那个经过优化的匹配思路去匹配啦,这样就实现完了,但是当你有了求 border 这件武器之后,你又发现了一种全新的匹配方式!把 \(t\) 和 \(s\) 拼接起来,中间插个无意义字符如 #,连成一个更长的大字符串,然后对这个大字符串求 border,得到一个 \(nxt\) 序列。然后让 \(i\) 从 \(|t|+2\) 开始找(我这里全部从 \(1\) 开始编号),直到 \(|t|+|s|+1\),如果哪里 \(nxt_i = |t|\),是不是说明这里就匹配到了呢?于是这样就比像刚才那样再去费尽心思匹配要轻松多了。
这就是 KMP 啦,是不是很厉害呢?
KMP 模版代码
放个代码供参考。
#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 2e6+5;
string s,t,p;
int nxt[N],n,m,k;
int read(){
int su=0,pp=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
return su*pp;
}
int main(){
cin>>s>>t;p=t+'#'+s;
n=s.size(),m=t.size(),k=p.size();
s=" "+s,t=" "+t,p=" "+p;
for(int i=2;i<=k;i++){
int j=i-1;
while(j>0){
if(p[nxt[j]+1]==p[i])
{nxt[i]=nxt[j]+1;break;}
j=nxt[j];
}
}for(int i=m+2;i<=k;i++)
if(nxt[i]==m)cout<<i-2*m<<"\n";
for(int i=1;i<=m;i++)
cout<<nxt[i]<<" ";cout<<"\n";
return 0;
}
例题选讲
这边选择几道比较经典的例题进行讲解。
P4391 Radio Transmission 无线传输
结论题,答案就是 \(n-nxt_n\),原理不用过多解释,因为就是一个很直觉性的,随便画个图就能知道。和板题代码几乎一样。
P9606 ABB
这个也是比较板子的题目,只需要取原字符串翻转后的结果,让它和原字符串拼接在一起,中间加个无意义字符,做 border,最后让 \(n\) 减掉那个 \(nxt_L\) 即可(\(L\) 表示拼接后的字符串的长度)。
CF1137B Camp Schedule
跟 KMP 关系不大,重点是要能熟练运用 border 的求解方式。
由于这个题目是给定你一个字符串 \(s\) 然后要求你重排它,使得里面出现的 \(t\) 的次数尽可能多。而且它保证了一个特别重要的性质,那就是不论是 \(s\) 还是 \(t\) 都只由 0 和 1 组成。
首先可以考虑把 \(s\) 拆一下,拆成 \(s0\) 个 0 和 \(s1\) 个 1,到时候直接重组便可。当然 \(t\) 也是要拆的,拆成 \(t0\) 个 0 和 \(t1\) 个 1。
这个时候,我们要先对 \(t\) 求一趟 border,得出这个对应的 \(nxt\) 数组。拿到这个 \(nxt\) 数组之后,我们取出 \(nxt_{|t|}\) 的值,并让 \(|t|\) 减去它,这就是在 \(t\) 后为再产生一个 \(t\) 而所需要花费的总长度。当然了,要从 \(t\) 中真的提取出这一段字符串内容,同样也对其进行拆分,拆成 \(w0\) 个 0 和 \(w1\) 个 1,方便后面的判断。
想想看,现在我们什么都有了,究竟要怎么构造才是最优的?显而易见,当 \(s\) 的 0 和 1 的个数都是够用的情况下,我们首先拼出一个完整的 \(t\),接着不断按照 border 求解之后得到的情况进行拼接,尽量凑出尽可能多的 \(t\)。直到某个时候 0 或者 1 有哪个不够了,那么也就没法子再多拼出一个 \(t\) 了,这个时候把剩余的 0 和 1 随意拼接到串末即可。
这样是不是就结束了?是不是很简单呢?
放个代码。后面的拼接部分其实也可以除法求出最多能拼多少个,不过我这里用的是循环暴力枚举的一个方式,也是一样的啦。
#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 1e6+5;
int n,m,s0,s1,t0,t1,w0,w1,nxt[N];
string s,t,sub,ans;
int read(){
int su=0,pp=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
return su*pp;
}
int main(){
cin>>s>>t;
n=s.size(),m=t.size();
s=" "+s,t=" "+t;
for(int i=1;i<=n;i++)
if(s[i]=='1')s1++;else s0++;
for(int i=1;i<=m;i++)
if(t[i]=='1')t1++;else t0++;
for(int i=2;i<=m;i++){
int h=i-1;
while(h>0){
if(t[nxt[h]+1]==t[i])
{nxt[i]=nxt[h]+1;break;}
h=nxt[h];
}
}for(int i=1;i<=nxt[m];i++)
if(t[i]=='1')w1++;else w0++;
if(m>n){for(int i=1;i<=n;i++)cout<<s[i];return 0;}
if(s0<t0||s1<t1){for(int i=1;i<=n;i++)cout<<s[i];return 0;}
ans=t;s0-=t0,s1-=t1;
for(int i=nxt[m]+1;i<=m;i++)sub+=t[i];
for(int i=m+1;i<=n;i++){
if(s0<t0-w0||s1<t1-w1)break;
s0-=(t0-w0),s1-=(t1-w1),ans+=sub;
}while(s0--)ans+="0";while(s1--)ans+="1";
for(int i=1;i<=n;i++)cout<<ans[i];cout<<"\n";
return 0;
}
CF1200E Compress Words
很明显这个东西也是要进行拼凑,但是要把重复的地方去掉。
就是要找到两个字符串重复的地方嘛,靠前的字符串的后缀,以及靠后的字符串的前缀。这不也可以用 border 实现吗?翻转一下,拼接,不就可以求了吗?
呃,但是这个东西好像直接干会超时……是的没错,因为这样子的时间复杂度是 \(O(n \sum |s|)\) 的!大概是 \(10^{11}\) 的级别,绝对接受不了。实在是因为这个 \(ans\) 太长了啊,让它一直不停地去和各个 \(s_i\) 算 border,一次又一次,不超时才怪呢!
怎么办,怎么办呢?注意到不论这个 \(ans\) 有多长,最极端的情况也莫过于这个 \(s_i\) 完全融合进原来的 \(ans\),也就是说这个 border,换句话说求出来的这个 \(nxt_k\)(其中 \(k\) 表示 \(ans\) 和 \(s_i\) 的拼接字符串的总长度)为 \(|s_i|\) 嘛!这是最大的情况了!既然 border 最大也就这样,那我们的 \(ans\) 还给这么多干嘛?去吃闲饭的吗?多浪费时间呐!于是咱就干脆只截 \(ans\) 的后 \(|s_i|\) 位去给做 border 匹配,这样的话这速度就快多了,只剩下个 \(O(\sum 2|s|)\) 了,一点点常数有什么问题嘛!
于是就结束啦,代码特别简单。
#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 1e5+5;
const int M = 1e6+5;
int n,nxt[M];
string s[N],ans;
int read(){
int su=0,pp=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
return su*pp;
}
string border(string s,string t){
string p=" "+t+"#"+s;
int k=p.size()-1;
for(int i=1;i<=k;i++)nxt[i]=0;
for(int i=2;i<=k;i++){
int j=i-1;
while(j>0){
if(p[nxt[j]+1]==p[i])
{nxt[i]=nxt[j]+1;break;}
j=nxt[j];
}
}string tmp="";
for(int i=nxt[k];i<t.size();i++)
tmp+=t[i];return tmp;
}
int main(){
n=read();
for(int i=1;i<=n;i++)cin>>s[i];
ans=s[1];
for(int i=2;i<=n;i++){
string tmp="";
int x=ans.size(),y=s[i].size();
for(int j=max(0,x-y);j<ans.size();j++)
tmp+=ans[j];
ans+=border(tmp,s[i]);
}
cout<<ans<<"\n";
return 0;
}
CF631D Messenger
发现这个东西和往常的 KMP 匹配不一样呐,这东西它多了个奇奇怪怪的封装——也难怪,按照这个规模算一下,拉长之后岂不有 \(2 \times 10^{10}\) 长?谁存的下,谁又做得来呢?不说别的,都读不进来呢!
于是只能在这个封装的基础上进行 KMP 的操作。当然了,我们要先把它弄成最简封装,换句话说,封装中不能存在相邻两个字符相同,否则就要把它们合并起来!
我们发现,封装之后的 KMP 和之前不一样了,之前是全都要相同,现在只要中间一坨全部相同,左右端点只需要字符匹配,并且个数足够即可。那么 \(m=1\) 和 \(m=2\)(所有提到的 \(n\) 和 \(m\) 均是最简封装情况下的)的需要你去特判一下,因为这俩玩意儿不存在中间一坨,玩不了 KMP。
搞完这两种特殊情况就来真的了,依然 KMP,依然求 border,不过这个时候那个 \(t\) 别给全塞进去了,头和尾塞不得,因为这是最后要判断的。搞完之后看谁的 \(nxt\) 值是 \(m-2\),是的话再判断下头和尾行不行,如果行就多一种情况啦,就可以更新答案啦!最后输出就行了。
#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 4e5+5;
struct node{LL x;char c;}a[N],b[N],p[N];
LL n,m,k,cn,cm,Ans,nxt[N];
LL read(){
LL su=0,pp=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
return su*pp;
}
bool operator == (const node &A , const node &B){
return (A.x==B.x&&A.c==B.c);
}
bool operator <= (const node &A , const node &B){
return (A.x<=B.x&&A.c==B.c);
}
int main(){
n=read(),m=read();
for(int i=1;i<=n;i++){
LL p=read();char h;cin>>h;
if(h==a[cn].c)a[cn].x+=p;else a[++cn]={p,h};
}n=cn;
for(int i=1;i<=m;i++){
LL p=read();char h;cin>>h;
if(h==b[cm].c)b[cm].x+=p;else b[++cm]={p,h};
}m=cm;
if(m==1){
for(int i=1;i<=n;i++)
if(b[1]<=a[i])Ans+=a[i].x-b[1].x+1;
}else if(m==2){
for(int i=1;i<n;i++)
if(b[1]<=a[i]&&b[2]<=a[i+1])Ans++;
}else{
for(int i=2;i<m;i++)p[++k]=b[i];
for(int i=0;i<=n;i++)p[++k]=a[i];
for(int i=2;i<=k;i++){
int j=i-1;
while(j>0){
if(p[nxt[j]+1]==p[i])
{nxt[i]=nxt[j]+1;break;}
j=nxt[j];
}
}for(int i=m;i<k;i++)
if(nxt[i]==m-2&&b[1]<=p[i-m+2]&&b[m]<=p[i+1])Ans++;
}cout<<Ans<<"\n";
return 0;
}
简单总结
KMP,它通常用来处理字符串匹配问题,可以方便快速地查找到一个字符串在另一个字符串中的出现情况。和它一起出现的是 border,它可以快速求出字符串的任意前缀的最长公共前后缀,延伸运用多用于求解两个字符串的最长公共前后缀、回文串匹配情况等,运用多种多样,是非常好用的线性算法。
码这么多字也不容易,还麻烦你留个赞支持一下,真是太感谢啦!

浙公网安备 33010602011771号