后缀自动机题单

bzoj3473

简单的想法就是把这些串的广义\(\mathrm{SAM}\)建出来,然后对每个节点求出它代表的串出现在了多少个原串中。假设这个已经求出,接下来我们对每个节点求出它及其祖先节点的贡献(因为它们对应了最长串的一连串后缀),在求每个串的答案时在\(\mathrm{SAM}\)匹配就好了。

那么怎么求每个节点的串出现在了多少个原串中呢?暴力的想法是在\(\mathrm{SAM}\)上匹配,然后对其祖先打上标记看起来复杂度是\(O(|s|^2)\)的,但是由于\(\mathrm{SAM}\)的节点数是\(O(\sum |s|)\)的,在配合上一些不等式技巧可证得打标记的总时间复杂度是\(O(L\sqrt L)\)的(\(L=\sum |s|\))

#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<bitset>
#include<math.h>
#include<stack>
#include<queue>
#include<set>
#include<map>
using namespace std;
typedef long long ll;
typedef long double db;
typedef pair<int,int> pii;
const int N=100000+100;
const db pi=acos(-1.0);
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define go(u,i) for (register int i=head[u];i;i=sq[i].nxt)
#define fir first
#define sec second
#define mp make_pair
#define pb push_back
#define maxd 998244353
#define eps 1e-8
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
    while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
    return x*f;
}

int n,m,lst,tot=1,siz[N<<1],ch[N<<1][26],len[N<<1],sum[N<<1],ord[N<<1],fa[N<<1],tax[N<<1],cnt[N<<1],col[N<<1];
char s[N];
vector<char> str[N];

void insert(int x)
{
	if ((ch[lst][x]) && (len[ch[lst][x]]==len[lst]+1))
	{
		lst=ch[lst][x];
		return;
	}
	int np=(++tot),p=lst,flag=0;len[np]=len[p]+1;
	while ((p) && (!ch[p][x])) {ch[p][x]=np;p=fa[p];}
	if (!p) fa[np]=1;
	else
	{
		int q=ch[p][x];
		if (len[q]==len[p]+1) fa[np]=q;
		else
		{
			if (len[np]==len[p]+1) flag=1;
			int nq=(++tot);len[nq]=len[p]+1;
			memcpy(ch[nq],ch[q],sizeof(ch[nq]));
			fa[nq]=fa[q];fa[np]=fa[q]=nq;
			while ((p) && (ch[p][x]==q)) {ch[p][x]=nq;p=fa[p];}
			if (flag) np=nq;
		}
	}
	siz[np]=1;lst=np;
}

int main()
{
	n=read();m=read();
	rep(i,1,n)
	{
		scanf("%s",s+1);
		int len=strlen(s+1);lst=1;
		rep(j,1,len) 
		{
			insert(s[j]-'a');
			str[i].pb(s[j]);
		}
	}
	rep(i,1,tot) tax[len[i]]++;
	rep(i,1,tot) tax[i]+=tax[i-1];
	per(i,tot,1) ord[tax[len[i]]--]=i;
	rep(i,1,n)
	{
		int now=1,len=str[i].size();
		rep(j,0,len-1)
		{
			int x=str[i][j]-'a';
			now=ch[now][x];
			int tmp=now;
			while ((tmp) && (col[tmp]!=i))
			{
				col[tmp]=i;cnt[tmp]++;
				tmp=fa[tmp];
			}
		}
	}
	cnt[1]=0;
	rep(i,1,tot)
	{
		int u=ord[i],f=fa[u];
		sum[u]=sum[f];
		if (cnt[u]>=m) sum[u]+=len[u]-len[f];
	}
	rep(i,1,n)
	{
		ll ans=0;int now=1,len=str[i].size();
		rep(j,0,len-1)
		{
			int x=str[i][j]-'a';
			now=ch[now][x];
			ans+=sum[now];
		}
		printf("%lld ",ans);
	}
	return 0;
}

hdu5343

最naive的想法就是将两个串的子串数目乘起来,这样显然会算重。

考虑对于每个合法串,我们让其与串\(A\)的匹配长度尽可能大,以保证不会重复计数。

这个过程可以通过在\(\mathrm{SAM}\)上进行dp来解决,在串\(B\)\(\mathrm{SAM}\)上直接跑本质不同的子串个数的dp,在串\(A\)上还要加上在这个位置终止的方案数。两个dp都可以通过枚举下一个字母来解决。

#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<bitset>
#include<math.h>
#include<stack>
#include<queue>
#include<set>
#include<map>
using namespace std;
typedef long long ll;
typedef long double db;
typedef pair<int,int> pii;
typedef unsigned long long ull;
const int N=100000+100;
const db pi=acos(-1.0);
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define go(u,i) for (register int i=head[u];i;i=sq[i].nxt)
#define fir first
#define sec second
#define mp make_pair
#define pb push_back
#define maxd 998244353
#define eps 1e-8
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
    while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
    return x*f;
}

int n,m;
ull f[N<<1],g[N<<1];
char s[N],t[N];
bool visf[N<<1],visg[N<<1];

struct Suffix_Automaton{
    int tot,lst,fa[N<<1],ch[N<<1][26],len[N<<1];
    
    Suffix_Automaton() {tot=lst=1;}
    
    void insert(int x)
    {
        int np=(++tot),p=lst;lst=np;len[np]=len[p]+1;
        memset(ch[np],0,sizeof(ch[np]));
        while ((p) && (!ch[p][x])) {ch[p][x]=np;p=fa[p];}
        if (!p) {fa[np]=1;return;}
        int q=ch[p][x];
        if (len[q]==len[p]+1) {fa[np]=q;return;}
        int nq=(++tot);len[nq]=len[p]+1;
        memcpy(ch[nq],ch[q],sizeof(ch[nq]));
        fa[nq]=fa[q];fa[np]=fa[q]=nq;
        while ((p) && (ch[p][x]==q)) {ch[p][x]=nq;p=fa[p];}
    }
    
    void clr()
    {
        tot=lst=1;fa[1]=0;len[1]=0;
        memset(ch[1],0,sizeof(ch[1]));
    }
}sam1,sam2;

void dfs2(int u)
{
    if (!u) return;
    if (visg[u]) return;
    g[u]=1;visg[u]=1;
    rep(i,0,25)
    {
        int v=sam2.ch[u][i];
        if (v) {dfs2(v);g[u]+=g[v];}
    }
}

ll calc(int c) {return g[sam2.ch[1][c]];} 

void dfs1(int u)
{
    if (!u) return;
    if (visf[u]) return;
    f[u]=1;visf[u]=1;
    rep(i,0,25)
    {
        int v=sam1.ch[u][i];
        if (v) {dfs1(v);f[u]+=f[v];}
        else f[u]+=calc(i);
    }
}

int main()
{
    int T=read();
    while (T--)
    {
        scanf("%s",s+1);n=strlen(s+1);
        scanf("%s",t+1);m=strlen(t+1);
        rep(i,1,n) sam1.insert(s[i]-'a');
        rep(i,1,m) sam2.insert(t[i]-'a');
        rep(i,1,sam1.tot) visf[i]=0;
        rep(i,1,sam2.tot) visg[i]=0;
        dfs2(1);dfs1(1);
        printf("%llu\n",f[1]);
        sam1.clr();sam2.clr();
    }
    return 0;
}

bzoj1396

出现次数为1的子串在\(\mathrm{SAM}\)上对应的节点显然是那些\(\mathrm{endpos}\)集合大小为\(1\)的点,对于每一个这样的点,记其\(\mathrm{endpos}\)中的元素为\(r\), 其对应的子串长度为\([mn,mx]\).

  • \(\forall p\in[r-mn+1,r]\),该节点的最短的串有可能成为\(p\)的答案。

  • \(\forall p\in[r-mx+1,r-mn]\), 串\(s[p:r]\)有可能成为\(p\)的答案。

对上面两种情况分别开一棵线段树维护即可。

#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<bitset>
#include<math.h>
#include<stack>
#include<queue>
#include<set>
#include<map>
using namespace std;
typedef long long ll;
typedef long double db;
typedef pair<int,int> pii;
const int N=100000+100;
const db pi=acos(-1.0);
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define go(u,i) for (register int i=head[u];i;i=sq[i].nxt)
#define fir first
#define sec second
#define mp make_pair
#define pb push_back
#define maxd 998244353
#define eps 1e-8
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
    while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
    return x*f;
}

struct Segment_Tree{
	int seg[N<<2],tag[N<<2];
	
	void pushdown(int id)
	{
		if (tag[id]!=maxd)
		{
			seg[id<<1]=min(seg[id<<1],tag[id]);
			seg[id<<1|1]=min(seg[id<<1|1],tag[id]);
			tag[id<<1]=min(tag[id<<1],tag[id]);
			tag[id<<1|1]=min(tag[id<<1|1],tag[id]);
			tag[id]=maxd;
		}
	}

	void build(int id,int l,int r)
	{
		seg[id]=tag[id]=maxd;
		if (l==r) return;
		int mid=(l+r)>>1;
		build(id<<1,l,mid);build(id<<1|1,mid+1,r);
	}

	void modify(int id,int l,int r,int ql,int qr,int v)
	{
		if (ql>qr) return;
		if ((l>=ql) && (r<=qr))
		{
			seg[id]=min(seg[id],v);
			tag[id]=min(tag[id],v);
			return;
		}
		pushdown(id);int mid=(l+r)>>1;
		if (ql<=mid) modify(id<<1,l,mid,ql,qr,v);
		if (qr>mid) modify(id<<1|1,mid+1,r,ql,qr,v);
		seg[id]=min(seg[id<<1],seg[id<<1|1]);
	}

	int query(int id,int l,int r,int pos)
	{	
		if (l==r) return seg[id];
		pushdown(id);
		int mid=(l+r)>>1;
		if (pos<=mid) return query(id<<1,l,mid,pos);
		else return query(id<<1|1,mid+1,r,pos);
	}
}seg1,seg2;

int n,tot=1,lst=1,tax[N<<1],ord[N<<1],ch[N<<1][26],fa[N<<1],siz[N<<1],pos[N<<1],len[N<<1];
char s[N];

void insert(int x,int id)
{
	int np=(++tot),p=lst;lst=np;len[np]=len[p]+1;
	siz[np]=1;pos[np]=id;
	while ((p) && (!ch[p][x])) {ch[p][x]=np;p=fa[p];}
	if (!p) {fa[np]=1;return;}
	int q=ch[p][x];
	if (len[q]==len[p]+1) {fa[np]=q;return;}
	int nq=(++tot);len[nq]=len[p]+1;
	memcpy(ch[nq],ch[q],sizeof(ch[q]));
	fa[nq]=fa[q];fa[np]=fa[q]=nq;
	while ((p) && (ch[p][x]==q)) {ch[p][x]=nq;p=fa[p];}
}

int main()
{
	scanf("%s",s+1);
	n=strlen(s+1);
	rep(i,1,n) insert(s[i]-'a',i);
	seg1.build(1,1,n);seg2.build(1,1,n);
	rep(i,1,tot) tax[len[i]]++;
	rep(i,1,n) tax[i]+=tax[i-1];
	rep(i,1,tot) ord[tax[len[i]]--]=i;
	per(i,tot,1)
	{
		int u=ord[i];siz[fa[u]]+=siz[u];
		if (siz[u]!=1) continue;
		int mx=len[u],mn=len[fa[u]]+1,p=pos[u];
		seg1.modify(1,1,n,p-mn+1,p,mn);
		seg2.modify(1,1,n,p-mx+1,p-mn,p+1);
	}
	rep(i,1,n)
	{
		int ans1=seg1.query(1,1,n,i),ans2=seg2.query(1,1,n,i)-i;
		printf("%d\n",min(ans1,ans2));
	}
	return 0;
}
posted @ 2020-03-20 00:08  EncodeTalker  阅读(139)  评论(0编辑  收藏  举报