LOJ 2720 「NOI2018」你的名字——后缀自动机
题目:https://loj.ac/problem/2720
自己总是分不清 “SAM上一个点的 len[ ] ” 和 “一个串的前缀在 SAM 上匹配的 len ”。
于是原本想的 68 分做法是,求出 T 的本质不同子串个数,减去 T 在 S 的 SAM 上走的 fail 树的链并权值。SAM 上一个点的权值就是它代表的子串个数(len[ cr ] - len[ fa ])。
其实不行。因为 T 走到 S 的 SAM 的某个点,不是能匹配该点代表的所有子串,而是只作为 T 的一个前缀匹配了一个最长的 S 的子串。
考虑求 T 的本质不同子串,通过给 T 建 SAM 来求。 SAM 的每个点表示 T 的一个(一些位置)前缀的一些后缀。
T 的 SAM 的每个点给答案的贡献不是 len[ cr ] - len[ fa ] ,要减去与 S 重合的部分。
所以就是想知道 T 的每个前缀的后缀能与 S 的子串匹配的最大长度。 T 在 S 的 SAM 上走一遍即可。
然后 T 的 SAM 上每个点找一个它代表的子串们可以对应的 T 的前缀,已知该前缀的最长后缀匹配长度是 l2[ i ] ,那么 SAM 该点的贡献就是 max( 0 , len[ cr ] - max( len[ fa ] , l2[ i ] ) ) 。
如果是 S 的一个区间的子串,就让 T 在 S 的 SAM 上走的时候,如果不能匹配在这个区间里,就一直跳 fa 。
用主席树求出 S 的 SAM 每个点的 right 集合具体是什么,当前 T 走着,想匹配 tlen+1 长度,就是看看 [ ql + tlen -1 , qr ] 里有没有元素。
注意不是一旦失败就跳 fa ,而是要在当前节点一直 tlen -- ;如果 tlen 减得和 len[ fa ] 一样,才跳 fa 。
一直 -- ,复杂度是对的。因为 T 走一步,tlen 最多加1; 如果减成0,就会退出,所以 tlen 的移动长度最多 \( 2*\sum |T| \) 。
写那个 tlen -- 的 while 循环的时候要注意一些。
注意每次给 T 建自动机之前要清空上次的数组。而且是清空 2*len 那么多。
注意线段树合并,新建的节点要继承原来的 ls 、 rs 之类的信息!
#include<cstdio> #include<cstring> #include<algorithm> #define ls Ls[cr] #define rs Rs[cr] #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int Mx(int a,int b){return a>b?a:b;} int Mn(int a,int b){return a<b?a:b;} const int N=1e6+5,M=4e7+5,K=26; int n,lst=1,cnt=1,len[N],fa[N],go[N][K],tx[N],q[N]; int ql,qr,l2[N],tot,rt[N],Ls[M],Rs[M],sm[M]; char s[N]; namespace S1{ int lst,cnt,len[N],fa[N],go[N][K],dy[N]; void init(int n) { lst=cnt=1; for(int i=1,lm=n*2;i<=lm;i++)//////n*2!!!!! memset(go[i],0,sizeof go[i]); } int Ins(int w,int bh) { int p=lst,np=++cnt; lst=np;len[np]=len[p]+1;dy[np]=bh; for(;p&&!go[p][w];p=fa[p])go[p][w]=np; if(!p){fa[np]=1;return np;} int q=go[p][w]; if(len[q]==len[p]+1){fa[np]=q;return np;} int nq=++cnt; len[nq]=len[p]+1;dy[nq]=dy[q];// fa[nq]=fa[q]; fa[q]=nq; fa[np]=nq; memcpy(go[nq],go[q],sizeof go[q]); for(;go[p][w]==q;p=fa[p])go[p][w]=nq; return np; } void solve() { int m=strlen(s+1); init(m);///// for(int i=1;i<=m;i++)Ins(s[i]-'a',i); ll ans=0; for(int i=2;i<=cnt;i++) ans+=Mx(0,len[i]-Mx(len[fa[i]],l2[dy[i]])); printf("%lld\n",ans); } } int Ins(int w) { int p=lst,np=++cnt; lst=np;len[np]=len[p]+1; for(;p&&!go[p][w];p=fa[p])go[p][w]=np; if(!p){fa[np]=1;return np;} int q=go[p][w]; if(len[q]==len[p]+1){fa[np]=q;return np;} int nq=++cnt; len[nq]=len[p]+1; fa[nq]=fa[q]; fa[q]=nq; fa[np]=nq; memcpy(go[nq],go[q],sizeof go[q]); for(;go[p][w]==q;p=fa[p])go[p][w]=nq; return np; } void Rsort() { for(int i=1;i<=cnt;i++)tx[len[i]]++; for(int i=1;i<=n;i++)tx[i]+=tx[i-1]; for(int i=1;i<=cnt;i++)q[tx[len[i]]--]=i; } int nwnd(int pr) { int cr=++tot; ls=Ls[pr];rs=Rs[pr]; sm[cr]=sm[pr]; return cr; } void build(int l,int r,int &cr,int p) { cr=++tot; sm[cr]=1; if(l==r)return; int mid=l+r>>1; if(p<=mid)build(l,mid,ls,p); else build(mid+1,r,rs,p); } void mrg(int l,int r,int &cr,int pr) { if(!cr||!pr){cr=nwnd(cr|pr); return;} cr=nwnd(cr); int mid=l+r>>1;//not cr=++tot!!! mrg(l,mid,ls,Ls[pr]); mrg(mid+1,r,rs,Rs[pr]); sm[cr]=sm[ls]+sm[rs]; } int qry(int l,int r,int cr,int L,int R) { if(L>R)return 0; if(l>=L&&r<=R)return sm[cr]; int mid=l+r>>1; if(L>mid)return qry(mid+1,r,rs,L,R); if(R<=mid)return qry(l,mid,ls,L,R); return qry(l,mid,ls,L,R)+qry(mid+1,r,rs,L,R); } bool chk(int cr,int tlen) { return qry(1,n,rt[cr],ql+tlen-1,qr);} int main() { freopen("name.in","r",stdin); freopen("name.out","w",stdout); scanf("%s",s+1); n=strlen(s+1); for(int i=1;i<=n;i++) { int d=Ins(s[i]-'a'); build(1,n,rt[d],i); } Rsort(); for(int i=cnt;i>1;i--) mrg(1,n,rt[fa[q[i]]],rt[q[i]]); int Q=rdn(); while(Q--) { scanf("%s",s+1); ql=rdn();qr=rdn(); int cr=1,m=strlen(s+1),tlen=0; for(int i=1;i<=m;i++) { int w=s[i]-'a'; while(1) { if(go[cr][w]&&chk(go[cr][w],tlen+1)) { tlen++; cr=go[cr][w]; break;} if(!tlen)break; if(!go[cr][w]) {cr=fa[cr];tlen=len[cr];continue;} tlen--; if(tlen==len[fa[cr]])cr=fa[cr]; } l2[i]=tlen; } S1::solve(); } return 0; }