LOJ 2720 「NOI2018」你的名字——后缀自动机

题目:https://loj.ac/problem/2720

自己总是分不清 “SAM上一个点的 len[ ] ” 和 “一个串的前缀在 SAM 上匹配的 len ”。

于是原本想的 68 分做法是,求出 T 的本质不同子串个数,减去 T 在 S 的 SAM 上走的 fail 树的链并权值。SAM 上一个点的权值就是它代表的子串个数(len[ cr ] - len[ fa ])。

其实不行。因为 T 走到 S 的 SAM 的某个点,不是能匹配该点代表的所有子串,而是只作为 T 的一个前缀匹配了一个最长的 S 的子串。

考虑求 T 的本质不同子串,通过给 T 建 SAM 来求。 SAM 的每个点表示 T 的一个(一些位置)前缀的一些后缀。

 T 的 SAM 的每个点给答案的贡献不是 len[ cr ] - len[ fa ] ,要减去与 S 重合的部分。

所以就是想知道 T 的每个前缀的后缀能与 S 的子串匹配的最大长度。 T 在 S 的 SAM 上走一遍即可。

然后 T 的 SAM 上每个点找一个它代表的子串们可以对应的 T 的前缀,已知该前缀的最长后缀匹配长度是 l2[ i ] ,那么 SAM 该点的贡献就是 max( 0 , len[ cr ] - max( len[ fa ] , l2[ i ] ) ) 。

如果是 S 的一个区间的子串,就让 T 在 S 的 SAM 上走的时候,如果不能匹配在这个区间里,就一直跳 fa 。

用主席树求出 S 的 SAM 每个点的 right 集合具体是什么,当前 T 走着,想匹配 tlen+1 长度,就是看看 [ ql + tlen -1 , qr ] 里有没有元素。

注意不是一旦失败就跳 fa ,而是要在当前节点一直 tlen -- ;如果 tlen 减得和 len[ fa ] 一样,才跳 fa 。

一直 -- ,复杂度是对的。因为 T 走一步,tlen 最多加1; 如果减成0,就会退出,所以 tlen 的移动长度最多 \( 2*\sum |T| \) 。

写那个 tlen -- 的 while 循环的时候要注意一些。

注意每次给 T 建自动机之前要清空上次的数组。而且是清空 2*len 那么多。

注意线段树合并,新建的节点要继承原来的 ls 、 rs 之类的信息!

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ls Ls[cr]
#define rs Rs[cr]
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
int Mx(int a,int b){return a>b?a:b;}
int Mn(int a,int b){return a<b?a:b;}
const int N=1e6+5,M=4e7+5,K=26;
int n,lst=1,cnt=1,len[N],fa[N],go[N][K],tx[N],q[N];
int ql,qr,l2[N],tot,rt[N],Ls[M],Rs[M],sm[M];
char s[N];
namespace S1{
  int lst,cnt,len[N],fa[N],go[N][K],dy[N];
  void init(int n)
  {
    lst=cnt=1;
    for(int i=1,lm=n*2;i<=lm;i++)//////n*2!!!!!
      memset(go[i],0,sizeof go[i]);
  }
  int Ins(int w,int bh)
  {
    int p=lst,np=++cnt; lst=np;len[np]=len[p]+1;dy[np]=bh;
    for(;p&&!go[p][w];p=fa[p])go[p][w]=np;
    if(!p){fa[np]=1;return np;}
    int q=go[p][w];
    if(len[q]==len[p]+1){fa[np]=q;return np;}
    int nq=++cnt; len[nq]=len[p]+1;dy[nq]=dy[q];//
    fa[nq]=fa[q]; fa[q]=nq; fa[np]=nq;
    memcpy(go[nq],go[q],sizeof go[q]);
    for(;go[p][w]==q;p=fa[p])go[p][w]=nq;
    return np;
  }
  void solve()
  {
    int m=strlen(s+1); init(m);/////
    for(int i=1;i<=m;i++)Ins(s[i]-'a',i);
    ll ans=0;
    for(int i=2;i<=cnt;i++)
      ans+=Mx(0,len[i]-Mx(len[fa[i]],l2[dy[i]]));
    printf("%lld\n",ans);
  }
}
int Ins(int w)
{
  int p=lst,np=++cnt; lst=np;len[np]=len[p]+1;
  for(;p&&!go[p][w];p=fa[p])go[p][w]=np;
  if(!p){fa[np]=1;return np;}
  int q=go[p][w];
  if(len[q]==len[p]+1){fa[np]=q;return np;}
  int nq=++cnt; len[nq]=len[p]+1;
  fa[nq]=fa[q]; fa[q]=nq; fa[np]=nq;
  memcpy(go[nq],go[q],sizeof go[q]);
  for(;go[p][w]==q;p=fa[p])go[p][w]=nq;
  return np;
}
void Rsort()
{
  for(int i=1;i<=cnt;i++)tx[len[i]]++;
  for(int i=1;i<=n;i++)tx[i]+=tx[i-1];
  for(int i=1;i<=cnt;i++)q[tx[len[i]]--]=i;
}
int nwnd(int pr)
{
  int cr=++tot; ls=Ls[pr];rs=Rs[pr];
  sm[cr]=sm[pr]; return cr;
}
void build(int l,int r,int &cr,int p)
{
  cr=++tot; sm[cr]=1;
  if(l==r)return; int mid=l+r>>1;
  if(p<=mid)build(l,mid,ls,p);
  else build(mid+1,r,rs,p);
}
void mrg(int l,int r,int &cr,int pr)
{
  if(!cr||!pr){cr=nwnd(cr|pr); return;}
  cr=nwnd(cr); int mid=l+r>>1;//not cr=++tot!!!
  mrg(l,mid,ls,Ls[pr]); mrg(mid+1,r,rs,Rs[pr]);
  sm[cr]=sm[ls]+sm[rs];
}
int qry(int l,int r,int cr,int L,int R)
{
  if(L>R)return 0;
  if(l>=L&&r<=R)return sm[cr];
  int mid=l+r>>1;
  if(L>mid)return qry(mid+1,r,rs,L,R);
  if(R<=mid)return qry(l,mid,ls,L,R);
  return qry(l,mid,ls,L,R)+qry(mid+1,r,rs,L,R);
}
bool chk(int cr,int tlen)
{ return qry(1,n,rt[cr],ql+tlen-1,qr);}
int main()
{
  freopen("name.in","r",stdin);
  freopen("name.out","w",stdout);
  scanf("%s",s+1); n=strlen(s+1);
  for(int i=1;i<=n;i++)
    {
      int d=Ins(s[i]-'a');
      build(1,n,rt[d],i);
    }
  Rsort();
  for(int i=cnt;i>1;i--)
    mrg(1,n,rt[fa[q[i]]],rt[q[i]]);
  int Q=rdn();
  while(Q--)
    {
      scanf("%s",s+1); ql=rdn();qr=rdn();
      int cr=1,m=strlen(s+1),tlen=0;
      for(int i=1;i<=m;i++)
    {
      int w=s[i]-'a';
      while(1)
        {
          if(go[cr][w]&&chk(go[cr][w],tlen+1))
        { tlen++; cr=go[cr][w]; break;}
          if(!tlen)break;
          if(!go[cr][w])
        {cr=fa[cr];tlen=len[cr];continue;}
          tlen--; if(tlen==len[fa[cr]])cr=fa[cr];
        }
      l2[i]=tlen;
    }
      S1::solve();
    }
  return 0;
}

 

posted on 2019-05-04 20:00  Narh  阅读(136)  评论(0编辑  收藏

导航