扩展KMP

这个算法主要靠画图理解,于是学习的时候画了不少图,正好写篇博客。

扩展KMP能解决的问题:
给定两个串\(S,T\),对于S的每一个后缀\(S[i...n]\)求和\(T\)\(LCP\)

\(exnxt_i\)表示后缀\(S[i...n]\)求和\(T\)\(LCP\),我们要做的就是求所有\(exnxt_i\)

我们先对\(T\)处理出\(nxt_i\)表示\(T\)的后缀\(T[i...m]\)\(T\)\(LCP\),如何处理之后再说。

1.如何求\(exnxt\)

假设我们已经求出了\([1,i-1]\)\(exnxt\),我们现在要求\(exnxt_i\)

我们维护\(p\)表示对于\([1,i-1]\)它们在\(S\)上最远和\(T\)匹配到了哪里,用式子说就是\(max_{j\in [1,i-1]}(j+extnxt_j-1)\),并且我们维护这是在哪个点取到的,记为\(p0\)

先上一张图:

我们现在是在\(p0\)匹配好的时候,我们观察\(i\)的位置,不难发现\(i\)此时对应着\(i-p0+1\)

我们设\(L=nxt_{i-p0+1}\),那么根据已知信息,即\(nxt\)的定义,我们能知道:
\(T[1...L]=T[i+p0-1...(i+p0-1)+L-1]=S[i...i+L-1]\)

即下图三线相等:

这时如果\(i+L-1<p\)(注意不能取等,我们并不知道\(p\)之后的信息),那么我们实际已经求完了,此时\(exnxt_i=L\)

代码非常简单(代码中\(r\)即为\(p\)):

if(i+nxt[i-p0+1]-1<r)exnxt[i]=nxt[i-p0+1];

接下来考虑如果\(i+L-1\geqslant p\)会怎样:


这时我们只能知道如下三条线是相等的,即:
\(T[1...p-i+1]=T[i-p0+1...(i-p0+1)+(p-i+1)-1]=S[i...i+(p-i+1)-1]\)

于是我们让\(extnxt_i\)先有一个候选答案\(p-i+1\)之后,我们暴力匹配,不断扩展\(extnxt_i\)

代码是这样的(代码中的\(r\)就是\(p\)):

exnxt[i]=max(r-i+1,0);
while(s[i+exnxt[i]]==t[exnxt[i]+1])exnxt[i]++;

之后我们让\(p0=i\),更新\(p\)的值(代码中的\(r\)就是\(p\)):

p0=i,r=i+exnxt[i]-1;

感性理解的话因为\(p\)的增长是\(O(n)\)的,所以整个算法的复杂度是\(O(n)\)的。

2.如何求\(nxt\)

我们发现\(nxt\)的定义和\(exnxt\)的定义十分相像,不过一个是\(T\)\(T\)匹配,一个是\(S\)\(T\)匹配。

于是我们用同样的方法就可以求出\(nxt\):暴力算出\(nxt_1,nxt_2\),之后按照1.中的方法求即可。

模板题

说来我好像是题解中第二个从1开始数数的。。。不过扶苏的代码太神仙了,我看不懂。

code:

#include<bits/stdc++.h>
using namespace std;
const int maxn=100010;
int n,m;
int nxt[maxn],exnxt[maxn];
char s[maxn],t[maxn];
inline void getnxt()
{
	nxt[1]=m;nxt[2]=0;
	while(t[1+nxt[2]]==t[2+nxt[2]])nxt[2]++;
	for(int i=3,p0=2,r=p0+nxt[p0]-1;i<=m;i++)
	{
		if(i+nxt[i-p0+1]-1<r)nxt[i]=nxt[i-p0+1];
		else 
		{
			nxt[i]=max(r-i+1,0);
			while(t[nxt[i]+1]==t[i+nxt[i]])nxt[i]++;
			p0=i,r=i+nxt[i]-1;
		}
	}
}
inline void getexnxt()
{
	exnxt[1]=0;
	while(s[1+exnxt[1]]==t[1+exnxt[1]])exnxt[1]++;
	for(int i=2,p0=1,r=p0+exnxt[p0]-1;i<=n;i++)
	{
		if(i+nxt[i-p0+1]-1<r)exnxt[i]=nxt[i-p0+1];
		else 
		{
			exnxt[i]=max(r-i+1,0);
			while(s[i+exnxt[i]]==t[exnxt[i]+1])exnxt[i]++;
			p0=i,r=i+exnxt[i]-1;
		}
	}
}
int main()
{
	//freopen("test.in","r",stdin);
	//freopen("test.out","w",stdout); 
	scanf("%s%s",s+1,t+1);
	n=strlen(s+1);m=strlen(t+1);
	getnxt();getexnxt();
	for(int i=1;i<=m;i++)printf("%d ",nxt[i]);
	puts("");
	for(int i=1;i<=n;i++)printf("%d ",exnxt[i]);
	return 0;
}
posted @ 2019-12-20 17:36  nofind  阅读(111)  评论(0编辑  收藏  举报