BZOJ1396 识别子串【SAM+SegmentTree】

BZOJ1396 识别子串

给定一个串\(s\),对于串中的每个位置,输出经过这个位置且只在\(s\)中出现一次的子串的最短长度

朴素的想法是,我们要找到那些只出现一次的子串,之后遍历每个串,把串所覆盖的区域区间和串长取\(min\)
考虑优化,根据\(s\)串先建立\(SAM\),然后计算出每个状态的\(endpos\)集合的大小,其中大小为\(1\)的状态所表示的一系列子串必然只在原串中出现一次,对于\(endpos\)大小为\(1\)的某个状态\(u\),其表示的子串的最短长度为\(len_{link_u}+1\),最长长度为\(len_u\),假设子串结束的位置为\(firstpos_u\),那么对于\([firstpos_u-len_{link_u}+1,firstpos_u]\)这段区间,需要和\(len_{link_u}\)\(min\),而对于区间\([firstpos_u-len_u+1,firstpos_u-len_{link_u}]\)来说,区间上的每个位置\(i\)要和\(firstpos_u-i+1\)\(min\),可以在更新的时候只考虑\(firstpos_u\)的贡献,最后计算的时候在减去\(i-1\)即可,所以根据上述方法,需要建立两棵线段树来维护,其中区间取\(min\)可以通过先排序然后直接赋值来解决

//#pragma GCC optimize("O3")
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<bits/stdc++.h>
using namespace std;
function<void(void)> ____ = [](){ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0);};
const int MAXN = 2e5+7;
char s[MAXN];
struct SegmentTree{
    int lazy[MAXN<<2],l[MAXN<<2],r[MAXN<<2];
    #define ls(rt) rt << 1
    #define rs(rt) rt << 1 | 1
    void pushdown(int rt){
        if(!lazy[rt]) return;
        lazy[ls(rt)] = lazy[rt]; lazy[rs(rt)] = lazy[rt];
        lazy[rt] = 0;
    }
    void build(int L, int R, int rt = 1){
        l[rt] = L; r[rt] = R;
        if(l[rt] + 1 == r[rt]){
            lazy[rt] = MAXN;
            return;
        }
        int mid = (L + R) >> 1;
        build(L,mid,ls(rt)); build(mid,R,rs(rt));
    }
    void update(int L, int R, int x, int rt = 1){
        if(l[rt]>=R or L>=r[rt]) return;
        if(L<=l[rt] and r[rt]<=R){
            lazy[rt] = x;
            return;
        }
        pushdown(rt);
        update(L,R,x,ls(rt)); update(L,R,x,rs(rt));
    }
    int query(int pos, int rt = 1){
        if(l[rt] + 1 == r[rt]) return lazy[rt];
        int mid = (l[rt] + r[rt]) >> 1;
        pushdown(rt);
        if(pos<mid) return query(pos,ls(rt));
        else return query(pos,rs(rt));
    }
}ST1,ST2;
struct SAM{
    int len[MAXN],link[MAXN],ch[MAXN][26],tot,last,cnt[MAXN],c[MAXN],sa[MAXN],firstpos[MAXN];
    SAM(){ link[0] = -1; }
    void extend(int c){
        int np = ++tot, p = last;
        firstpos[np] = len[np] = len[p] + 1; cnt[np] = 1;
        while(p!=-1 and !ch[p][c]){
            ch[p][c] = np;
            p = link[p];
        }
        if(p==-1) link[np] = 0;
        else{
            int q = ch[p][c];
            if(len[p]+1==len[q]) link[np] = q;
            else{
                int clone = ++tot;
                len[clone] = len[p] + 1;
                link[clone] = link[q];
                firstpos[clone] = firstpos[q];
                memcpy(ch[clone],ch[q],sizeof(ch[q]));
                link[np] = link[q] = clone;
                while(p!=-1 and ch[p][c]==q){
                    ch[p][c] = clone;
                    p = link[p];
                }
            }
        }
        last = np;
    }
    void Radix_sort(){
        for(int i = 0; i <= tot; i++) c[i] = 0;
        for(int i = 0; i <= tot; i++) c[len[i]]++;
        for(int i = 1; i <= tot; i++) c[i] += c[i-1];
        for(int i = tot; i >= 0; i--) sa[c[len[i]]--] = i;
    }
    void solve(char *s){
        int l = strlen(s);
        for(int i = 0; i < l; i++) extend(s[i]-'a');
        Radix_sort();
        for(int i = tot + 1; i > 1; i--) cnt[link[sa[i]]] += cnt[sa[i]];
        vector<pair<int,pair<int,int> > > vec;
        for(int i = 1; i <= tot; i++) if(cnt[i]==1) vec.emplace_back(make_pair(firstpos[i],make_pair(len[link[i]]+1,len[i])));
        ST1.build(1,l+1); ST2.build(1,l+1);
        sort(vec.begin(),vec.end(),[](const pair<int,pair<int,int>> &lhs, const pair<int,pair<int,int>> &rhs){
            return lhs.second.first > rhs.second.first;
        });
        for(int i = 0; i < (int)vec.size(); i++) ST1.update(vec[i].first-vec[i].second.first+1,vec[i].first+1,vec[i].second.first);
        sort(vec.begin(),vec.end(),[](const pair<int,pair<int,int>> &lhs, const pair<int,pair<int,int>> &rhs){
            return lhs.first > rhs.first;
        });
        for(int i = 0; i < (int)vec.size(); i++) ST2.update(vec[i].first-vec[i].second.second+1,vec[i].first-vec[i].second.first+1,vec[i].first);
        for(int i = 1; i <= l; i++) printf("%d\n",min(ST1.query(i),ST2.query(i)-i+1));
    }
}sam;
int main(){
    scanf("%s",s);
    sam.solve(s);
    return 0;
}
posted @ 2020-04-16 01:52  _kiko  阅读(131)  评论(0编辑  收藏  举报