ZigZagKmp
Think twice, code once.

题意简述

给定长度为 \(n\) 的字符串 \(S,T\) ,求有多少个不同的 \(T\) 的子串 \(t\) ,满足 \(t\)\(S\) 的一个子序列。
\(1\le n\le 3000\)

算法分析

子串的个数是 \(\mathcal{O}(n^2)\) 的,子序列的个数是 \(\mathcal{O}(2^n)\) 的,因此考虑枚举所有子串,判断是否是 \(S\) 的子序列。

如何快速判断一个字符串是母串的子序列?直接上子序列自动机就好了。由于枚举过程是增量枚举的,因此总复杂度为 \(\mathcal{O}(n^2\log n)\) 或者 \(\mathcal{O}(n^2+n|\Sigma|)\) 的,取决于子序列自动机的实现方法。

但是我们枚举的子串可能有相同的,需要去重,hash即可,因为字符串总量比较大,用 双模hash 比较保险。

熟悉子序列自动机的可以跳过下面一段:

子算法1 子序列自动机

由名称,不难得出其用途。子序列自动机可以判断一个串是否是母串的子序列。

下设询问串为 \(P\) ,母串为 \(S\)

考虑这个询问串在母串上匹配的过程,假设当前询问串的前 \(i\) 位都是母串的子序列,且在母串中匹配到 \(cur\) 。形式化的讲, \(P[1:i]\)\(S\) 的子序列,且 \(P[i]=S[cur]\)

现在我们要匹配 \(P[i+1]\) ,如果能匹配上,那么 \(S\) 串在 \(cur\) 位置后一定存在一个位置 \(k\) 能匹配上,即 \(\exists k>cur\ ,\ P[i+1]=S[k]\)

但是 \(S\) 串后面可能有若干个合法的 \(k\) ,我们应该取哪一个呢?

我们应该取最靠前的那一个,即 \(k>cur\ ,\ \forall j\in (cur,k]\ ,\ P[i+1]\neq S[j]\)

为什么这样的贪心是正确的?因为这个过程有决策包容性。即我们取最靠前的符合要求的 \(k\) ,不会使得答案变差。

图片.png

后面黄色框表示如果选择 \(k_2\)\(P\) 串后面可能的一种子序列匹配,在我们选择 \(k_1\) 的时候这种后面的匹配仍然是可达的,因此不会丢失答案。

接下来有两种实现,根据不同情况应选择不同实现方法:

  1. nxt[i][c] 表示位置 \(i\) 之后第一个为 \(c\) 的字符,记录一个 lst[c] 表示当前范围内 \(c\) 最后一次的出现位置,倒序扫描一遍即可。构建时空复杂度为 \(\mathcal{O}(n|\Sigma|)\),查询时间复杂度为 \(\mathcal{O}(|P|)\)
  2. \(|\Sigma|\)vector ,存储每一种字符的出现位置,查询的时候二分位置即可,构造时空复杂度为 \(\mathcal{O}(n)\) ,查询时间复杂度为 \(\mathcal{O}(|P|\log n)\)

一般来说,对于字符集较小,查询量较大的题目,推荐使用第一种写法。对于字符集较大,或者空间较为紧张的题目,推荐使用第二种写法。

实现方法1:

int nxt[maxn][26];//假定为字符集为所有小写字符
int lst[26];
int n;
void build(char *S){
	n=strlen(S+1);
	for(int j=0;j<26;++j)lst[j]=n+1;
	for(int i=n;i>=0;--i){
		for(int j=0;j<26;++j)nxt[i][j]=lst[j];
		lst[S[i]-'a']=i;
	}
}
bool query(char *P){
	int cur=0,np=strlen(P+1);
	for(int i=1;i<=np;++i){
		cur=nxt[cur][P[i]-'a'];
		if(cur>n)return 0;
	}
	return 1;
}

实现方法2:

int n;
vector<int>ps[26];
void build(char *S){
	n=strlen(S+1);
	for(int i=1;i<=n;++i)ps[S[i]-'a'].push_back(i);
	for(int j=0;j<26;++j)ps[j].push_back(n+1);//防止越界,便于处理
}
bool query(char *P){
	int cur=0,np=strlen(P+1);
	for(int i=1;i<=np;++i){
		int nxt=*upper_bound(ps[P[i]-'a'].begin(),ps[P[i]-'a'].end(),cur);
		if(nxt>n)return 0;
		cur=nxt;
	}
	return 1;
}

能够正确写出双模HASH的可以跳过下面一段:

子算法2 HASH

可能有很多同学在初学字符串 HASH 的时候写的 HASH 是假的(错误率很高)(包括我自己)

字符串 HASH 核心思想是把字符串看作一个 BAS 进制数,因为显然存不下,考虑取模,比较常用的 BAS=\(131,13331\),常用的取模是unsigned long long自然溢出。

第一个要注意的地方是模数要足够大。由生日悖论, \(\sqrt n\) 个值域为 \([0,n)\) 的数存在相同数的概率超过 \(50\%\) ,如果模数是 int 范围的,则长度为 \(10^5\) 左右的随机字符串已经很容易产生冲突。可参见 Hash Killer II

但是我们仅使用自然溢出也会出问题,因为有对着卡的方法,参见 Hash Killer I

因此我们通过双底数/双模数的方法处理,具体的,我们取两个不同的BASMod,分别计算 HASH ,两个 HASH 都相同才认为是相同的的。

这种方法目前似乎没有很好的方法卡掉,具体可参见 Hash Killer Ⅲ

处理完错误率的问题,下面来处理效率问题。

先是构建的过程,考虑定义式(可能有多种定义,仅举一例):

\[H(S)=\sum_{i=0}^{|S|-1}BAS^i\cdot S[n-i] \]

为了方便后面计算,还应记录前缀和,即:

\[SH(i)=\sum_{j=0}^{i-1}BAS^j\cdot S[i-j] \]

暴力计算是 \(\mathcal{O}(n\log n)\) 的,这个过程可以用秦九韶算法优化:

\[H(S)=((((S[0]\cdot BAS+S[1])\cdot BAS+S[2])\cdot BAS\cdots)+S[n-1])\cdot BAS+S[n] \]

这样避免了快速幂,时间复杂度为 \(\mathcal{O}(n)\)

接下来是查询过程。我们查询子串 \(S[l:r]\) ,则对应答案为:

\[H(S[l:r])=SH(r)-SH(l-1)\cdot BAS^{r-l+1} \]

暴力实现是 \(\mathcal{O}(\log n)\) 的,我们预处理出所有的 \(BAS^k\) ,这样复杂度降为 \(\mathcal{O}(1)\)

现在的复杂度是线性的,接下来是一些常数优化和一些细节:

  • 使用 unsigned long long 自然溢出,减少取模
  • 如果是两个数相加/相减,且能保证都在 \([0,Mod)\) 范围内,可以使用减法代替取模
  • 底数和模数不能过大,应保证 \(\max\{BAS,Mod\}\times Mod< 2^{62}\),否则在乘法过程中可能会超出 long long/unsigned long long 范围
  • 推荐使用 unsigned long long 而不是 long long 。尤其注意 自然溢出不能使用 long long ,因为 long long 的溢出是 UB

有关本题的一个细节:

由于只有一个询问,去重应使用 sort+unique 或手写哈希表,map/unordered_map 常数巨大,通过此题比较困难。

代码实现

我的代码里采用的是第二种子序列自动机实现方法。

有关 HASH,我的代码没有完全做到上面的优化,且第二个模数是 int 范围的,有一定的优化空间。

#include<bits/stdc++.h>
using namespace std;
#define maxn 1000005
#define maxm 2000005
#define inf 0x3f3f3f3f
#define LL long long
#define ull unsigned long long
#define db double
#define ldb long double
#define mod 1000000007
#define eps 1e-9
#define local
void file(string s){freopen((s+".in").c_str(),"r",stdin);freopen((s+".out").c_str(),"w",stdout);}
template <typename Tp> void read(Tp &x){
	int fh=1;char c=getchar();x=0;
	while(c>'9'||c<'0'){if(c=='-'){fh=-1;}c=getchar();}
	while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+(c&15);c=getchar();}x*=fh;
}
int n,m;
char S[maxn],T[maxn];
vector<int>ps[26];

struct HS_node{
	ull hs1,hs2;
	HS_node operator +(HS_node y)const{
		return (HS_node){hs1+y.hs1,(hs2+y.hs2)%mod};
	}
	HS_node operator -(HS_node y)const{
		return (HS_node){hs1-y.hs1,(hs2-y.hs2+mod)%mod};
	}
	HS_node operator *(HS_node y)const{
		return (HS_node){hs1*y.hs1,(hs2*y.hs2)%mod};
	}
	bool operator <(HS_node y)const{
		return hs1==y.hs1?hs2<y.hs2:hs1<y.hs1;
	}
	bool operator ==(HS_node y)const{
		return hs1==y.hs1&&hs2==y.hs2;
	}
};
struct MY_Hash{
	const ull Bas1=131,Bas2=13331;
	HS_node pw[maxn],sh[maxn];
	void build(const char *str){//构建hash
		int nn=strlen(str+1);
		pw[0]=(HS_node){1,1};
		for(int i=1;i<=nn;++i)pw[i]=pw[i-1]*(HS_node){Bas1,Bas2};
		for(int i=1;i<=nn;++i)sh[i]=sh[i-1]*(HS_node){Bas1,Bas2}+(HS_node){str[i],str[i]};
	}
	HS_node get_hash(int l,int r){
		return sh[r]-sh[l-1]*pw[r-l+1];
	}
}hh;
HS_node aa[9000005];
int ans;
signed main(){
	#ifndef local
		file("block");
	#endif
	read(n);
	scanf("%s",S+1);
	scanf("%s",T+1);
	hh.build(T);
	for(int i=1;i<=n;++i)ps[S[i]-'a'].push_back(i);//子序列自动机构建
	for(int i=0;i<26;++i)ps[i].push_back(n+1);//防止超出边界,push一个终止符
	for(int i=1;i<=n;++i){
		int cur=0;
		for(int j=i;j<=n;++j){
			int nxt=*upper_bound(ps[T[j]-'a'].begin(),ps[T[j]-'a'].end(),cur);//子序列自动机的转移
			if(nxt>n)break;
			aa[++m]=hh.get_hash(i,j);
			cur=nxt;
		}
	}
	sort(aa+1,aa+m+1);
	ans=unique(aa+1,aa+m+1)-aa-1;//去重
	printf("%d\n",ans);
	fclose(stdin);
	fclose(stdout);
	return 0;
}
posted on 2021-03-27 20:27  ZigZagKmp  阅读(123)  评论(0编辑  收藏  举报