洛谷P4770 [NOI2018]你的名字 后缀自动机+线段树合并

洛谷P4770 [NOI2018]你的名字

题意

给定一个字符串\(S\),有\(Q\)次询问,每次询问给定一个区间\([l,r]\)和一个字符串\(T\),问\(T\)中有多少本质不同的子串且不是\(S[l;r]\)的子串。

\(|S|\le 5\cdot 10^5,Q\le 10^5,\sum|T| \le 10^6\)

分析

\(S\)建后缀自动机,用线段树合并维护\(right\)集合,对\(T\)中的每个前缀\([1;i]\)\(S[l;r]\)上匹配,设最长长度为\(mx[i]\),对\(T\)建后缀自动机再按拓扑序更新一遍\(mx[fa[i]]=max(mx[fa[i]],mx[i])\),答案就是\(\sum len[i]-max(mx[i],len[fa[i]])\)。怎么在\(s[l;r]\)上匹配呢,假设在\([1;i-1]\)这个前缀上匹配的长度为\(L\),在\(S\)\(sam\)上的点为\(u\),若\(u\)存在到\(T[i]\)的转移边转移到点\(x\),且\(x\)\(right\)集合在区间\([l+L,r]\)中有一个元素,说明可以继续匹配下去,更新长度,否则\(L--\),继续尝试转移,如果\(L==len[fa[u]]\)\(u=fa[u]\)

Code

#include<bits/stdc++.h>
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
#define sz(a) int(a.size())
#define rson mid+1,r,rs[p]
#define pii pair<int,int>
#define lson l,mid,ls[p]
#define ll long long
#define pb push_back
#define mp make_pair
#define se second
#define fi first
using namespace std;
const double eps=1e-8;
const int mod=1e9+7;
const int N=2e6+10;
const int M=5e5+10;
const int inf=1e9;
int n,m,q;
char s[N],t[N];
int sum[N],id[N],rt[N];
vector<int>g[N];
struct SegmentTree{
    int tr[M*40];
    int ls[M*40],rs[M*40],tot;
    void up(int x,int l,int r,int &p){
        if(!p) p=++tot;
        if(l==r){
            tr[p]=l;
            return;
        }
        int mid=l+r>>1;
        if(x<=mid) up(x,lson);
        else up(x,rson);
        tr[p]=max(tr[ls[p]],tr[rs[p]]);
    }
    int merge(int x,int y,int l,int r){
        if(!x||!y) return x+y;
        int p=++tot,mid=l+r>>1;
        if(l==r){
            tr[p]=l;
        }else{
            ls[p]=merge(ls[x],ls[y],l,mid);
            rs[p]=merge(rs[x],rs[y],mid+1,r);
            tr[p]=max(tr[ls[p]],tr[rs[p]]);
        }
        return p;
    }
    int qy(int dl,int dr,int l,int r,int p){
        if(!p||dl>dr) return 0;
        if(l==dl&&r==dr) return tr[p];
        int mid=l+r>>1;
        if(dr<=mid) return qy(dl,dr,lson);
        else if(dl>mid) return qy(dl,dr,rson);
        else return max(qy(dl,mid,lson),qy(mid+1,dr,rson));
    }
}seg;
struct SAM{
    int last,cnt;int ch[N][27],fa[N],len[N],mx[N];
    int newnode(){
        ++cnt;
        mx[cnt]=0;
        memset(ch[cnt],0,sizeof ch[cnt]);
        return cnt;
    }
    void insert(int c){
        int p=last,np=newnode();last=np;len[np]=len[p]+1;
        for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=np;
        if(!p) fa[np]=1;
        else {
            int q=ch[p][c];
            if(len[q]==len[p]+1) fa[np]=q;
            else  {
                int nq=newnode();len[nq]=len[p]+1;
                memcpy(ch[nq],ch[q],sizeof ch[q]);
                fa[nq]=fa[q],fa[q]=fa[np]=nq;
                for(;ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
            }
        }
    }
    void init(){
        last=cnt=1;
        mx[cnt]=0;
        memset(ch[cnt],0,sizeof ch[cnt]);
    }
    void dfs(int u){
        for(int x:g[u]){
            dfs(x);
            rt[u]=seg.merge(rt[u],rt[x],1,n);
        }
    }
    ll gao(){
        for(int i=1;i<=cnt;i++) sum[i]=0;
        for(int i=1;i<=cnt;i++) sum[len[i]]++;
        for(int i=1;i<=cnt;i++) sum[i]+=sum[i-1];
        for(int i=1;i<=cnt;i++) id[sum[len[i]]--]=i;
        for(int i=cnt;i>=1;i--) mx[fa[id[i]]]=max(mx[fa[id[i]]],mx[id[i]]);
        ll ans=0;
        for(int i=2;i<=cnt;i++) ans+=max(0,len[i]-max(mx[i],len[fa[i]]));
        return ans;
    }
    void build(){
        for(int i=2;i<=cnt;i++) g[fa[i]].pb(i);
        dfs(1);
    }
}S,T;
void solve(int cas){
    int l,r;
    scanf("%s%d%d",t+1,&l,&r);
    m=strlen(t+1);
    T.init();
    int u=1,L=0;
    for(int i=1;i<=m;i++){
        int c=t[i]-'a';
        while(u!=1&&(!(S.ch[u][c]&&seg.qy(l+L,r,1,n,rt[S.ch[u][c]])))){
            L--;
            if(L==S.len[S.fa[u]]) u=S.fa[u];
        }
        if(S.ch[u][c]&&seg.qy(l+L,r,1,n,rt[S.ch[u][c]])) u=S.ch[u][c],L++;
        T.insert(c);
        T.mx[T.last]=L;
    }
    printf("%lld\n",T.gao());
}
int main(){
    scanf("%s",s+1);
    n=strlen(s+1);
    S.init();
    for(int i=1;i<=n;i++){
        S.insert(s[i]-'a');
        seg.up(i,1,n,rt[S.last]);
    }
    S.build();
    scanf("%d",&q);
    for(int i=1;i<=q;i++){
        solve(i);
    }
    return 0;
}
posted @ 2020-12-04 15:55  xyq0220  阅读(99)  评论(0编辑  收藏  举报