字符串作业

KMP

学习了 KMP 的递归写法,感觉比循环写法码量小,也容易理解。

KMP 基于表示最长公共真前后缀数组 \(fail_i\) 表示前 \(i\) 个字符组成的字符串的最长公共真前后缀。最长公共真前后缀也称 border。

例如,aaa\(fail_3\)\(2\) 而不是 \(3\)

对于单模匹配,涉及函数 getnxt(x,c) 表示在位置 \(x\) 后面续上字符 \(c\) 和模式串匹配的长度多长。

递归跳出条件为 \(t_x=c\) 返回 \(x+1\)。如果不满足上述条件且 \(x=0\) 返回 \(0\) 退出。

然后递归时,即在 \(x+1\) 处失配了。考虑从 \(fail_x\) 处匹配,由于 \(fail_x\) 的意义,前面一定能匹配,再进行判断可以最大限度降低复杂度。

文本串指针不回退,复杂度 \(\mathcal O(n)\),而模式串指针最多前进 \(\mathcal O(n)\) 次,回退次数显然与前进次数量级相同,势能差不多是 \(2n\)。匹配可以做到 \(\mathcal O(n)\) 的复杂度。

\(fail\) 数组怎么求?

在模式串上,一个指针往后扫,每新来一个字符,尝试和前面的 border 拼接,否则跳 border……

这岂不是和匹配一模一样?

因此 getnxt 函数兼具两个功能。那么就要 \(\mathcal O(m)\) 的时间预处理失配数组。总时间复杂度 \(\mathcal O(m+n)\)

给出 KMP 模板题的代码,要求输出模式串每次出现第一个字符匹配的位置和 \(fail\) 数组。

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e6+10;
string s,t;
int fail[N];
int getnxt(int x,char c)
{
	if(t[x+1]==c)
	{
		return x+1;
	}
	if(x==0)
	{
		return 0;
	}
	return getnxt(fail[x],c);
}
signed main()
{
	cin>>s>>t;
	int n=s.size();
	int m=t.size();
	s=" "+s;
	t=" "+t;
	for(int i=2;i<=m;++i)
	{
		fail[i]=getnxt(fail[i-1],t[i]);
	}
	for(int i=1,x=0;i<=n;++i)
	{
		x=getnxt(x,s[i]);
		if(x==m)
		{
			cout<<i-x+1<<"\n";
		}
	}
	for(int i=1;i<=m;++i)
	{
		cout<<fail[i]<<" ";
	}
	return 0;
}

如果递归卡常了,可以无痛修改成非递归的函数。

int getnxt(int x,char c)
{
	while(x>0&&t[x+1]!=c)
	{
		x=fail[x];
	}
	if(t[x+1]==c)
	{
		return x+1;
	}
	return 0;
}

题目 1:动物园

题意

对字符串求一个 \(num\) 数组,表示前 \(i\) 个字符所有不重叠的 border 的数量。按照一定格式输出答案。

题解

考虑失配树。

可以发现,每个 \(i\) 对应唯一的 \(fail_i\),每一个 \(fail_i\) 对应唯一的 \(fail_{fail_i}\dots\)。这决定了如果以 \(fail_i\) 作为 \(i\) 的父节点,则会构成一棵树,即失配树。

失配树可以总结出以下三个性质:

  1. 祖先节点编号永远比子节点大。
  2. 任意一点 \(i\) 子树中的每个点所代表的前缀都以 \(i\) 点代表的前缀为后缀(这句很绕,但是是由 border 的性质自然发现的)。
  3. 失配树深度即该点代表前缀的 border 个数(这个应用于本题)。

回到本题,那么只需要对每个节点向上倍增,跳到第一个编号 \(\leq i\) 的节点。其深度即答案。

复杂度 \(\mathcal O(n\log n)\)

虽然 \(fail_i\) 就是 \(i\) 的父节点,可是写成显式建树的形式总会出错,不知道哪里写呲了,倍增数组隐式建树就能对。

倍增数组平时习惯把向上 \(2^j\) 这一维写在后面,但本题把这一维写在前面,竟能带来不可思议的常数优化,大约能快一秒钟。

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e6+10;
const int mod=1e9+7;
string s;
int fail[N];
int head[N];
struct lsqxx
{
	int to;
	int nxt;
}edge[N<<1];
int cntn=0;
void add_edge(int u,int v)
{
	edge[cntn].to=v;
	edge[cntn].nxt=head[u];
	head[u]=cntn++;
}
int getnxt(int x,char c)
{
	if(s[x+1]==c)
	{
		return x+1;
	}
	if(x==0)
	{
		return 0;
	}
	return getnxt(fail[x],c);
}
int f[32][N];
int num[N];
void dfs(int u,int fa)
{
	f[0][u]=fa;
	for(int i=head[u];~i;i=edge[i].nxt)
	{
		int v=edge[i].to;
		if(v==fa)
		{
			continue;
		}
		dfs(v,u);
	}
}
void _init(int n)
{
	for(int j=1;j<=30;++j)
	{
		for(int i=1;i<=n;++i)
		{
			f[j][i]=f[j-1][f[j-1][i]];
		}
	}
}
int jump(int u,int i)
{
	for(int j=30;j>=0;--j)
	{
		if(f[j][u]*2>i)
		{
			u=f[j][u];
		}
	}
	return u;
}
int c[N];
void solve()
{	
	memset(head,-1,sizeof(head));
	cntn=0;
	cin>>s;
	int n=s.size();
	s=" "+s;
	for(int i=2;i<=n;++i)
	{
		fail[i]=getnxt(fail[i-1],s[i]);
	}
	for(int i=1;i<=n;++i)
	{
		add_edge(i,fail[i]);
		add_edge(fail[i],i);
		c[i]=c[fail[i]]+1;
	}
	dfs(0,0);
	int ans=1;
	for(int i=2;i<=n;++i)
	{
		int u=i;
		u=jump(u,i);
		ans=ans*c[u];
		ans%=mod;
	}
	cout<<ans<<"\n";
}
signed main()
{
	ios::sync_with_stdio(false);
	cin.tie(0);cout.tie(0);
	int T;
	cin>>T;
	while(T--)
	{
		solve();
	}
	return 0;
}

可以用单调栈优化到 \(\mathcal O(n)\)。这是出于失配树的编号具有单调性。由于每层的节点之间大小不固定,左指针可能左右飘,但是始终是 \(\mathcal O(n)\) 的。

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e6+10;
const int mod=1e9+7;
string s;
int fail[N];
int head[N];
int ans=1;
struct lsqxx
{
	int to;
	int nxt;
}edge[N<<1];
int cntn=0;
void add_edge(int u,int v)
{
	edge[cntn].to=v;
	edge[cntn].nxt=head[u];
	head[u]=cntn++;
}
int getnxt(int x,char c)
{
	if(s[x+1]==c)
	{
		return x+1;
	}
	if(x==0)
	{
		return 0;
	}
	return getnxt(fail[x],c);
}
int stk[N];
int lt,rt;
void dfs(int u,int fa)
{
	if(u)
	{
		stk[++rt]=u;
		while((stk[lt]<<1)<=u)
		{
			++lt;
		}
		while((stk[lt]<<1)>u)
		{
			--lt;
		}
		ans*=lt;
		ans%=mod;
	}
	for(int i=head[u];~i;i=edge[i].nxt)
	{
		int v=edge[i].to;
		if(v==fa)
		{
			continue;
		}
		dfs(v,u);
	}
	--rt;
}
void solve()
{	
	memset(head,-1,sizeof(head));
	cntn=0;
	ans=1;
	rt=lt=1;
	stk[lt]=0;
	cin>>s;
	int n=s.size();
	s=" "+s;
	for(int i=2;i<=n;++i)
	{
		fail[i]=getnxt(fail[i-1],s[i]);
	}
	for(int i=1;i<=n;++i)
	{
		add_edge(i,fail[i]);
		add_edge(fail[i],i);
	}
	dfs(0,-1);
	cout<<ans<<"\n";
	
}
signed main()
{
	ios::sync_with_stdio(false);
	cin.tie(0);cout.tie(0);
	int T;
	cin>>T;
	while(T--)
	{
		solve();
	}
	return 0;
}

题解区说这个是有限状态自动机,感觉很牛但并不懂。

题目 2:【模板】失配树

题意

\(q\) 次询问字符串长度分别为 \(x\)\(y\) 的两个前缀的最长公共 border 长度。

题解

由失配树的定义,一个点所有祖先都是其 border,本题等价于查询 \(\operatorname {LCA}\)

但不尽然,如果两个点在同一个子树内,由于自身不能是一个 border,答案应该是 \(\operatorname {LCA}\) 的父节点,特判即可。

使用的是重链剖分,正好许久没打过树剖了。

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e6+10;
string s;
int fail[N];
int head[N];
struct lsqxx
{
	int to;
	int nxt;
}edge[N<<1];
int cntn=0;
void add_edge(int u,int v)
{
	edge[cntn].to=v;
	edge[cntn].nxt=head[u];
	head[u]=cntn++;
}
int getnxt(int x,char c)
{
	if(s[x+1]==c)
	{
		return x+1;
	}
	if(x==0)
	{
		return 0;
	}
	return getnxt(fail[x],c);
}
int ctt=0;
int dep[N],fa[N],siz[N],son[N];
void dfs1(int u,int f,int d)
{
	dep[u]=d;
	fa[u]=f;
	siz[u]=1;
	int maxson=-1;
	for(int i=head[u];~i;i=edge[i].nxt)
	{
		int v=edge[i].to;
		if(v==f)
		{
			continue;
		}
		dfs1(v,u,d+1);
		siz[u]+=siz[v];
		if(siz[v]>maxson)
		{
			son[u]=v;
			maxson=siz[v];
		}
	}
}
int dfn[N],top[N];
void dfs2(int u,int topf)
{
	dfn[u]=++ctt;
	top[u]=topf;
	if(!son[u])
	{
		return;
	}
	dfs2(son[u],topf);
	for(int i=head[u];~i;i=edge[i].nxt)
	{
		int v=edge[i].to;
		if(v==fa[u]||v==son[u])
		{
			continue;
		}
		dfs2(v,v);
	}
}
int lca(int x,int y)
{
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])
		{
			swap(x,y);
		}
		x=fa[top[x]];
	}
	return dep[x]<dep[y]?x:y;
}
signed main()
{
	ios::sync_with_stdio(false);
	cin.tie(0);cout.tie(0);
	memset(head,-1,sizeof(head));
	cntn=0;
	cin>>s;
	int n=s.size();
	s=" "+s;
	for(int i=2;i<=n;++i)
	{
		fail[i]=getnxt(fail[i-1],s[i]);
	}
	for(int i=1;i<=n;++i)
	{
		add_edge(i,fail[i]);
		add_edge(fail[i],i);
	}
	dfs1(0,-1,1);
	dfs2(0,0);
	int q;
	cin>>q;
	while(q--)
	{
		int x,y;
		cin>>x>>y;
		int LCA=lca(x,y);
		if(LCA==x||LCA==y)
		{
			LCA=fa[LCA];
		}
		cout<<LCA<<"\n";
	}
	return 0;
}

题目 3:PRE-Prefixuffix

题意

称两个字符串循环相同,即将一个字符串的后缀挪到前面和另一个字符串相同。对一个字符串求最长的循环相同前后缀(要求不能重叠)。

题解

可以知道,原字符串能被写成 \(ABSBA\) 的形式。可以知道 \(A\) 是一个 border,\(B\) 是中间的一个 border。

考虑一个暴力,前后删除等量字符,暴力求中间的 border 长度。如果删掉的那些字符也是 border 就可以贡献给答案。

\(f_i\) 表示上述暴力删掉 \(i\) 个字符剩下部分最长 border 的长度,有如下不等式成立:

\[f_{i+1}+2\geq f_i \]

\(f_i\) 相当于在 \(f_{i+1}\) 的基础上在左右两边分别加了一个字符。最多就是像 bacab 左右分别插入 ab 这种恰好耦合的情况,可以使得长度 \(+2\),不可能更长了。

然后 \(f\) 就可以线性推出来(代码里又把数组滚没了)。用字符串哈希就可以解决。

自然溢出的 BKDR 似乎被卡了,选一个大一点的模数即可。

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e6+10;
const int mod=1e9+7;
const int base=998244353;
int b[N];
int h[N];
int n;
string s;
void _init()
{
	b[0]=1;
	for(int i=1;i<=n;++i)
	{
		b[i]=b[i-1]*base%mod;
	}
	for(int i=1;i<=n;++i)
	{
		h[i]=(h[i-1]*base%mod+s[i])%mod;
	}
}
int gethash(int l,int r)
{
	return (h[r]-h[l-1]*b[r-l+1]%mod+mod)%mod;
}
signed main()
{
	
	cin>>n;
	cin>>s;
	s=" "+s;
	int ans=0;
	int L=0;
	_init();
	for(int i=n/2;i>=0;--i)
	{
		L=min(L+2,n/2-i);
		while(L>0&&gethash(i+1,i+L)!=gethash(n-i-L+1,n-i))
		{
			--L;
		}
		if(gethash(1,i)==gethash(n-i+1,n))
		{
			ans=max(ans,i+L);
		}
	}
	cout<<ans;
	return 0;
}

AC 自动机

AC 自动机和 KMP 息息相关,因为 AC 自动机就是多模匹配。

大概思路就是把这些模式串组织到 Trie 上面,然后考虑 \(fail\) 数组。此时它表示节点 \(i\) 代表的前缀,在整个树上找某一个节点使得这个节点代表的前缀是节点 \(i\) 代表前缀的后缀,并且这个后缀是最长的。

其实就类似于 KMP 里面的定义,但由于树和链的不同,定义繁琐了一些。

那么 \(fail\) 数组自然是从它的父节点转移过来,实际实现中,使用的是广搜来做。

getnxt 函数也是类似 KMP 地进行实现。

\(fail\) 数组同样能构成一棵树,每次失配的时候就是不断地向父节点跳跃。

模板题给出了三种应用场景:

  1. 多少个模式串出现过。
  2. 哪些模式串出现次数最多。
  3. 每个模式串的出现次数分别是多少。

如果某个前缀匹配了,那么这个前缀的每个后缀都会匹配,也就是失配树上的每个祖先都会被匹配。记一个 \(w\) 数组用于树上差分,每次匹配上的时候就打一个标记,最后一次从下往上累加。每个模式串的结尾处的 \(w\) 值即出现的次数。

那么以上三种应用场景可以一一解决,给出场景 \(3\) 的代码。

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e6+10;
const int M=2e5+10;
int tr[M][26];
int ee[M];
int fail[M];
int idx=0;
int findnxt(int x,int c)
{
	if(tr[x][c])
	{
		return tr[x][c];
	}
	if(!x)
	{
		return 0;
	}
	return tr[x][c]=findnxt(fail[x],c);
}
int get(char x)
{
	return x-'a';
}
int pos[M];
void insert(string s,int id)
{
	int cur=0;
	int len=s.size();
	s=" "+s;
	for(int i=1;i<=len;++i)
	{
		if(tr[cur][get(s[i])]==0)
		{
			tr[cur][get(s[i])]=++idx;
		}
		cur=tr[cur][get(s[i])];
	}
	ee[cur]++;
	pos[id]=cur;
}
void build()
{
	queue<int>q;
	for(int i=0;i<26;++i)
	{
		if(tr[0][i])
		{
			fail[tr[0][i]]=0;
			q.push(tr[0][i]);
		}
	}
	while(!q.empty())
	{
		int u=q.front();
		q.pop();
		for(int i=0;i<26;++i)
		{
			int v=tr[u][i];
			if(v)
			{
				fail[v]=findnxt(fail[u],i);
				q.push(v);
			}
		}
	}
}
int head[M];
struct lsqxx
{
	int to;
	int nxt;
}edge[M];
int cnt=0;
int w[M];//失配树上计数

void add_edge(int u,int v)
{
	edge[cnt].to=v;
	edge[cnt].nxt=head[u];
	head[u]=cnt++;
}
void dfs(int u)
{
	for(int i=head[u];~i;i=edge[i].nxt)
	{
		int v=edge[i].to;
		dfs(v);
		w[u]+=w[v];
	}
}
string s[M];
signed main()
{
	memset(head,-1,sizeof(head));
	int n;
	string t;
	cin>>n;
	for(int i=1;i<=n;++i)
	{
		cin>>s[i];
		insert(s[i],i);
	}
	build();
	cin>>t;
	int lent=t.size();
	t=" "+t;
	for(int i=1;i<=idx;++i)
	{
		add_edge(fail[i],i);
	}
	for(int i=1,x=0;i<=lent;++i)
	{
		x=findnxt(x,get(t[i]));
		w[x]++;
	}
	dfs(0);
	for(int i=1;i<=n;++i)
	{
		cout<<w[pos[i]]<<"\n";
	}

	
	return 0;
}

题目:谐音替换

题意

给出一些替换规则表示 \(s_1\) 可以被替换成 \(s_2\)。然后 \(q\) 次询问 \(t_1\) 能不能用以上规则经过一次替换变成 \(t_2\)

题解

还是没弄出正解。

首先特判掉 \(s_1=s_2\)\(|t_1|\neq|t_2|\) 的情况。

不难发现,\(s_1\)\(s_2\) 无非就是中间只有一段不同。两段不同显然是不能一次替换的。设 \(s_1\)\(ACB\) 的形式,\(s_2\)\(ADB\) 的形式,那么拼接成 \(A|C?D|B\) 的形式(注意此处扩大了字符集),模式串也是同理这么拼接。

一开始是暴力遍历失配树的。但其实没必要,因为失配树是固定的。预处理一个 sum 数组表示 \(i\) 节点往上跳能跳到多少个模式串的结尾,包括 \(i\) 自身。

因为匹配本质上就是树上差分,所以 sum 数组直接能得到答案。只需要 \(\mathcal O(L_1)\) 建 AC 自动机,\(\mathcal O(L_2)\) 匹配。

感觉翻车完全是自己琢磨的不够好导致的垃圾做法。

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=2e5+10;
const int M=5e6+10;
string com(string s,string t)
{
	if(s==t)
	{
		return s;
	}
	int i=0;
	int n=s.size();
	while(i<n&&s[i]==t[i])
	{
		++i;
	}
	int j=0;
	while(j<n&&s[n-1-j]==t[n-1-j])
	{
		++j;
	}
	return s.substr(0,i)+"|"+s.substr(i,n-i-j)+"$"+t.substr(i,n-i-j)+"|"+s.substr(n-j,j);
}
string s[N];
int idx=0;
int head[M];
struct lsqxx
{
	int to;
	int nxt;
}edge[M];
int cnt=0;
void add_edge(int u,int v)
{
	edge[cnt].to=v;
	edge[cnt].nxt=head[u];
	head[u]=cnt++;
}
int tr[M][28];
int get(char x)
{
	if(x>='a'&&x<='z')
	{
		return x-'a';
	}
	if(x=='|')
	{
		return 26;
	}
	return 27;
}
int w[M];
int pos[N];
int ee[M];
void insert(string s,int id)
{
	int cur=0;
	int len=s.size();
	s=" "+s;
	for(int i=1;i<=len;++i)
	{
		if(!tr[cur][get(s[i])])
		{
			tr[cur][get(s[i])]=++idx;
		}
		cur=tr[cur][get(s[i])];
	}
	ee[cur]++;
	pos[id]=cur;
}
int fail[M];
int getnxt(int x,int c)
{
	if(tr[x][c])
	{
		return tr[x][c];
	}
	if(!x)
	{
		return 0;
	}
	return tr[x][c]=getnxt(fail[x],c);
}
void build()
{
	queue<int>q;
	for(int i=0;i<28;++i)
	{
		if(tr[0][i])
		{
			q.push(tr[0][i]);
			fail[tr[0][i]]=0;
		}
	}
	while(!q.empty())
	{
		int u=q.front();
		q.pop();
		for(int i=0;i<28;++i)
		{
			int v=tr[u][i];
			if(v)
			{
				fail[v]=getnxt(fail[u],i);
				q.push(v);
			}
		}
	}
}
int curtime=0;
int times[M];
void dfs(int u)
{
	if(times[u]!=curtime)
	{
		times[u]=curtime;
		w[u]=0;
	}
	for(int i=head[u];~i;i=edge[i].nxt)
	{
		int v=edge[i].to;
		dfs(v);
		w[u]+=w[v];
	}
}
int sum[M];
void csum()
{
	queue<int>q;
	q.push(0);
	sum[0]=ee[0];
	while(!q.empty())
	{
		int u=q.front();
		q.pop();
		for(int i=head[u];~i;i=edge[i].nxt)
		{
			int v=edge[i].to;
			sum[v]=ee[v]+sum[u];
			q.push(v);
		}
	}
}
signed main()
{
	memset(head,-1,sizeof(head));
	int n,q;
	cin>>n>>q;
	for(int i=1;i<=n;++i)
	{
		string s1,s2;
		cin>>s1>>s2;
		if(s1==s2)
		{
			n--;
			i--;
			continue;
		}
		s[i]=com(s1,s2);
//		cout<<s[i]<<"\n";
		insert(s[i],i);
	}
	build();
	for(int i=1;i<=idx;++i)
	{
		add_edge(fail[i],i);
	}
	csum();
	string t;
	while(q--)
	{
		++curtime;
		string t1,t2;
		cin>>t1>>t2;
		if(t1.size()!=t2.size())
		{
			cout<<0<<"\n";
			continue;
		}
		t=com(t1,t2);
//		cout<<t<<"\n";
		int lent=t.size();
		t=" "+t;
		int ans=0;
		for(int i=1,x=0;i<=lent;++i)
		{
			int c=get(t[i]);
			x=getnxt(x,c);
			ans+=sum[x];
		}
		cout<<ans<<"\n";
	}
	return 0;
}
posted @ 2026-02-28 23:25  Dexember  阅读(0)  评论(0)    收藏  举报