后缀自动机练习题

版题,用来熟悉 SAM 以及其树形结构的用途。

难度都差不多,就没怎么注意排序。

SDOI2016 生成魔咒

对一个长度为 \(n\) 的字符串,每次动态地往 SAM 中插入一个字符,求每次插入之后不同子串的个数。

提一下另一种计算不同子串个数的方法。

由于一个子串必处于一个状态且仅在一个状态中,所以对子串计数即对每个状态的子串数量计数。

一个状态 \(u\) 的子串即 \(\operatorname{long}(u)\) 的一段连续后缀,一直到 \(\operatorname{short}\) 为止,即共 \(\operatorname{len}(u)-\operatorname{minlen}(u)+1\)

因为 \(\operatorname{minlen}(u)=\operatorname{len}(\operatorname{link}(u))+1\),化简一下,一个结点的贡献就是 \(\operatorname{len}(u)-\operatorname{len}(\operatorname{link}(u))\)

对每个结点计数求和即可。

回到本题,加入一个字符会多一个状态,对这个新状态求增量即可。

构造 SAM 的过程中分裂出的新节点 \(new\) 为什么不算?

因为分裂 \(new\) 的本质即把一个状态拆成两个,总和不变,只算新的即可。

#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define f(i,a,b) for(int i=(a);i>=(b);i--)
#define MIN(x,y) (x<y?x:y)
#define MAX(x,y) (x>y?x:y)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=2e5+10;
struct state
{
	int link,len;
	sd map<int,int> nex;
}st[N];
int last,siz;
void init()
{
	st[0].link=-1;
	st[0].len=0;
	siz++,last=0;
}
int extend(int c)
{
	int cur=siz++,p=last;
	st[cur].len=st[last].len+1;
	while(p!=-1&&!st[p].nex.count(c))
	{
		st[p].nex[c]=cur;
		p=st[p].link;
	}
	if(p==-1)
	{
		st[cur].link=0;
	}
	else
	{
		int q=st[p].nex[c];
		if(st[q].len==st[p].len+1)
		{
			st[cur].link=q;
		}
		else
		{
			int nw=siz++;
			st[nw].nex=st[q].nex;
			st[nw].len=st[p].len+1;
			st[nw].link=st[q].link;
			while(p!=-1&&st[p].nex[c]==q)
			{
				st[p].nex[c]=nw;
				p=st[p].link;
			}
			st[cur].link=st[q].link=nw;
		}
	}
	last=cur;
	return cur;
}
int n,ans;
void solve()
{
	n=read();
	init();
	F(i,1,n)
	{
		int x=read();
		int now=extend(x);
		ans+=st[now].len-st[st[now].link].len;
		put(ans);
	}
}
signed main()
{
	int T=1;
//	T=read();
	while(T--) solve();
    return 0;
}

TJOI2019 甲苯先生和大中锋的字符串

一开始读错题了。。。

一个状态会将 \([\operatorname{minlen}(u),\operatorname{len}(u)]\) 这一段的数量加 1,差分维护即可。

#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define f(i,a,b) for(int i=(a);i>=(b);i--)
#define MIN(x,y) (x<y?x:y)
#define MAX(x,y) (x>y?x:y)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=1e6+10;
struct state
{
	int link,len;
	sd map<int,int> nex;
}st[N];
int siz,last;
char s[N];
void init()
{
	st[0].link=-1;
	st[0].len=0;
	siz++;
	last=0;
}
int dp[N];
void extend(char c)
{
	int cur=siz++;
	dp[cur]=1;
	st[cur].len=st[last].len+1;
	int p=last;
	while(p!=-1&&!st[p].nex.count(c))
	{
		st[p].nex[c]=cur;
		p=st[p].link;
	}
	if(p==-1) st[cur].link=0;
	else
	{
		int q=st[p].nex[c];
		if(st[q].len==st[p].len+1) st[cur].link=q;
		else
		{
			int nw=siz++;
			st[nw].nex=st[q].nex;
			st[nw].link=st[q].link;
			st[nw].len=st[p].len+1;
			while(p!=-1&&st[p].nex[c]==q)
			{
				st[p].nex[c]=nw;
				p=st[p].link;
			}
			st[cur].link=st[q].link=nw;
		}
	}
	last=cur;
}
struct node
{
	int nex,to;
}a[N];
int tot,head[N];
void add(int u,int v)
{
	a[++tot].nex=head[u];
	head[u]=tot;
	a[tot].to=v;
}
void dfs(int u)
{
	for(int i=head[u];i;i=a[i].nex)
	{
		int v=a[i].to;
		dfs(v);
		dp[u]+=dp[v];
	}
}
int n,K;
int cnt[N];
void clear()
{
	F(i,0,siz-1) dp[i]=0,st[i].nex.clear(),st[i].len=head[i]=0;
	siz=0;
	tot=0;
}
void solve()
{
	clear();
	scanf("%s",s+1);
	n=strlen(s+1);K=read();
	F(i,1,n) cnt[i]=0;
	init();
	F(i,1,n) extend(s[i]);
	F(i,1,siz-1) add(st[i].link,i);
	dfs(0);
	int fl=0;
	F(i,1,siz-1) if(dp[i]==K)
	{
		fl=1;
		//len(link(i))+1~len(i)
		int l=st[st[i].link].len+1,r=st[i].len;
		cnt[l]++;
		cnt[r+1]--;
	}
	if(!fl) return put(-1);
	F(i,1,n) cnt[i]+=cnt[i-1];
	int ans=0,nice=0;
	f(i,n,1) if(cnt[i]>nice) ans=i,nice=cnt[i];
	put(ans);
}
signed main()
{
	int T=1;
	T=read();
	while(T--) solve();
    return 0;
}

CF802I Fake News (hard)

处理出 \(cnt_u\),一个 \(u\) 的贡献是其子串数量乘以 \(cnt^2_u\),随便算算即可。

#include<bits/stdc++.h>
#define sd std::
//#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define f(i,a,b) for(int i=(a);i>=(b);i--)
#define MIN(x,y) (x<y?x:y)
#define MAX(x,y) (x>y?x:y)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(long long x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(long long x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(long long x){print(x);putchar('\n');}
void printk(long long x){print(x);putchar(' ');}
const int N=2e5+10;
struct state
{
	int link,len;
	sd map<int,int> nex;
}st[N];
int siz,last;
void init()
{
	st[0].link=-1;
	st[0].len=0;
	siz++,last=0;
}
int cnt[N];
void extend(char c)
{
	int cur=siz++,p=last;
	st[cur].len=st[last].len+1;
	cnt[cur]=1;
	while(p!=-1&&!st[p].nex.count(c))
	{
		st[p].nex[c]=cur;
		p=st[p].link;
	}
	if(p==-1)
	{
		st[cur].link=0;
	}
	else
	{
		int q=st[p].nex[c];
		if(st[q].len==st[p].len+1)
		{
			st[cur].link=q;
		}
		else
		{
			int nw=siz++;
			st[nw].len=st[p].len+1;
			st[nw].link=st[q].link;
			st[nw].nex=st[q].nex;
			while(p!=-1&&st[p].nex[c]==q)
			{
				st[p].nex[c]=nw;
				p=st[p].link;
			}
			st[cur].link=st[q].link=nw;
		}
	}
	last=cur;
}
struct node
{
	int nex;
	int to;
}a[N];
int tot,head[N];
void add(int u,int v)
{
	a[++tot].nex=head[u];
	head[u]=tot;
	a[tot].to=v;
}
void dfs(int u)
{
	for(int i=head[u];i;i=a[i].nex)
	{
		int v=a[i].to;
		dfs(v);
		cnt[u]+=cnt[v];
	}
}
void clear()
{
	F(i,0,siz-1) st[i].nex.clear(),head[i]=0,cnt[i]=0;
	tot=siz=0;
	init();
}
char s[N];
int n;
void solve()
{
	clear();
	scanf("%s",s+1);
	n=strlen(s+1);
	F(i,1,n) extend(s[i]);
	F(i,1,siz-1) add(st[i].link,i);
	dfs(0);
	long long ans=0;
	F(i,1,siz-1)
	{
		int l=st[st[i].link].len+1,r=st[i].len;
		ans+=1ll*(r-l+1)*cnt[i]*cnt[i];
	}
	put(ans);
}
int main()
{
	int T=1;
	T=read();
	while(T--) solve();
    return 0;
}

APIO2014 回文串

伪广义后缀自动机练习题。

考虑将 \(s\)\(s\) 的反串拼起来,中间用特殊字符分隔。

则变为大字符串的出现过两次且 \(s\)\(s\) 的反串中都出现过的子串贡献最大值。

考虑处理出 \(cnt_u\) 代表状态 \(u\)\(\operatorname{endpos}\) 大小。

显然,实际出现次数为 \(\dfrac{cnt_u}{2}\)。注意 \(cnt_u\) 为奇数的情况,相当于在 \(s\) 或在 \(s\) 反串中多出现一次,是不计算贡献的,即 \(\left \lfloor \dfrac{cnt_u}{2} \right \rfloor\)

(这是调代码过程中出现的,其实我也不太清楚一个回文串为什么正串反串出现次数能不一样的,可能是我代码实现有问题?)

这个做法因为复制了原串,加上 SAM 本身的两倍空间,要开 \(4n\) 空间,有可能会爆。

以下是经过卡空间之后的代码,可以通过洛谷数据。

#include<bits/stdc++.h>
#define sd std::
#define F(i,a,b) for(i=(a);i<=(b);i++)
#define f(i,a,b) for(i=(a);i>=(b);i--)
#define MIN(x,y) (x<y?x:y)
#define MAX(x,y) (x>y?x:y)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
const int N=1.2e6+5;
struct state
{
	int len,link,nex[27];
}st[N];
int last,siz;
void init()
{
	st[0].link=-1;
	siz++;
}
sd bitset<N> dp[2];//记录两个
int cnt[N];
void extend(char c,int op)
{
	int cur=siz++,p=last;
	dp[op][cur]=1;
	cnt[cur]=1;
	st[cur].len=st[last].len+1;
	while(p!=-1&&!st[p].nex[c])
	{
		st[p].nex[c]=cur;
		p=st[p].link;
	}
	if(p==-1) st[cur].link=0;
	else
	{
		int q=st[p].nex[c];
		if(st[q].len==st[p].len+1)
		{
			st[cur].link=q;
		}
		else
		{
			int nw=siz++;
			st[nw].link=st[q].link;
			st[nw].len=st[p].len+1;
			for(int i=0;i<26;i++) st[nw].nex[i]=st[q].nex[i];
			while(p!=-1&&st[p].nex[c]==q)
			{
				st[p].nex[c]=nw;
				p=st[p].link;
			}
			st[cur].link=st[q].link=nw;
		}
	}
	last=cur;
}
#define a st
#define to nex[1]
#define NEX nex[2]
int tot,head[N];
void add(int u,int v)
{
	a[++tot].NEX=head[u];
	head[u]=tot;
	a[tot].to=v;
}
void dfs(int u)
{
	for(int i=head[u];i;i=a[i].NEX)
	{
		int v=a[i].to;
		dfs(v);
		dp[0][u]=dp[0][u]|dp[0][v];
		dp[1][u]=dp[1][u]|dp[1][v];
		cnt[u]+=cnt[v];
	}
}
int n,i;
char s[N];
void solve()
{
	scanf("%s",s+1);
	n=strlen(s+1);
	init();
	F(i,1,n) extend(s[i]-'a',0);
	extend(26,0);
	f(i,n,1) extend(s[i]-'a',1);
	for(int i=1;i<=siz-1;i++) st[i].nex[0]=st[i].nex[1]=0;
	F(i,1,siz-1) add(st[i].link,i);
	dfs(0);
	long long ans=0;
	F(i,0,siz-1) if(dp[1][i]&dp[0][i]) ans=MAX(ans,1ll*cnt[i]/2*st[i].len);
	printf("%lld",ans);
}
signed main()
{
	int T=1;
//	T=read();
	while(T--) solve();
    return 0;
}

HAOI2016 找相同字符

也是伪后缀自动机。

将两个串拼起来形成大串。

记录每个串在两个分串中出现的次数 \(cnt_1\)\(cnt_2\)。稍微算算即可。

一个状态的贡献是 \(p_u\times cnt_1(u)\times cnt_2(u)\)\(p_u\) 就是这个状态内有多少个子串,这个说过怎么算的。

#include<bits/stdc++.h>
#define sd std::
//#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define f(i,a,b) for(int i=(a);i>=(b);i--)
#define MIN(x,y) (x<y?x:y)
#define MAX(x,y) (x>y?x:y)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=2e6+10;
char s[N],t[N];
int n,m;
struct state
{
	int len,link;
	sd map<int,int> nex;
}st[N];
int siz,last;
void init()
{
	st[0].link=-1;
	st[0].len=0;
	siz++,last=0;
}
int dp[N][2],cnt[N][2];
void extend(char c,int op)
{
	int cur=siz++,p=last;
	dp[cur][op]=1;
	cnt[cur][op]=1;
	st[cur].len=st[last].len+1;
	while(p!=-1&&!st[p].nex.count(c))
	{
		st[p].nex[c]=cur;
		p=st[p].link;
	}
	if(p==-1)
	{
		st[cur].link=0;
	}
	else
	{
		int q=st[p].nex[c];
		if(st[q].len==st[p].len+1)
		{
			st[cur].link=q;
		}
		else
		{
			int nw=siz++;
			st[nw].link=st[q].link;
			st[nw].len=st[p].len+1;
			st[nw].nex=st[q].nex;
			while(p!=-1&&st[p].nex[c]==q)
			{
				st[p].nex[c]=nw;
				p=st[p].link;
			}
			st[cur].link=st[q].link=nw;
		}
	}
	last=cur;
}
struct node
{
	int nex;
	int to;
}a[N<<1];
int tot,head[N];
void add(int u,int v)
{
	a[++tot].nex=head[u];
	head[u]=tot;
	a[tot].to=v;
}
void dfs(int u)
{
	for(int i=head[u];i;i=a[i].nex)
	{
		int v=a[i].to;
		dfs(v);
		dp[u][1]|=dp[v][1];
		dp[u][0]|=dp[v][0];
		cnt[u][1]+=cnt[v][1];
		cnt[u][0]+=cnt[v][0];
	}
}
void solve()
{
	scanf("%s%s",s+1,t+1);
	n=strlen(s+1);m=strlen(t+1);
	init();
	F(i,1,n) extend(s[i],0);
	extend('{',0);
	F(i,1,m) extend(t[i],1);
	F(i,1,siz-1) add(st[i].link,i);
	dfs(0);
	long long ans=0;
	F(i,1,siz-1) if(dp[i][0]&&dp[i][1])
	{
		int l=st[st[i].link].len+1,r=st[i].len;
		ans+=1ll*(r-l+1)*cnt[i][0]*cnt[i][1];
	}
	printf("%lld",ans);
}
int main()
{
	int T=1;
//	T=read();
	while(T--) solve();
    return 0;
}
posted @ 2024-12-28 15:41  _E_M_T  阅读(23)  评论(0)    收藏  举报