2020牛客暑期多校训练营(第八场) H Hard String Problem

题解:\(bin\) 巨的题解其实已经很详细了,这里讲几个可能会踩到的坑,\(kmp\) 求最小循环节时我们要先把原来的串翻倍,也就是在串后面再接上这个串本身,因为 \(kmp\) 中的 \(len - next[len]\) 所得出来的循环节长度是可能的最小循环节长度,但我们要求的是准确的,比如 \(abcdabc\) ,如果直接用 \(kmp\) 求的话,得到的最小循环节会是 \(abcd\) , 展开后为 \(abcdabcdabcd\) , 而原串展开后为 \(abcdabcabcdabc\) ,有明显的不同,然后我们将每个循环节最小表示,判断循环节是否相同,之后就是按 \(bin\) 巨的题解将字符串展开四倍,最短串展开到大于等于最长串的四倍,之后用广义 \(SAM\) 求解即可。

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 6e6 + 50;
const int maxn2 = 3e5 + 50;
string ss[maxn2], s;
int knex[maxn];
int n;
void getNext(int id){
    int tlen = ss[id].size();
    int j = 0, k = -1;
    knex[0] = -1;
    while(j < tlen){
        if(k == -1 || ss[id][j] == ss[id][k]) knex[++j] = ++k;
        else k = knex[k];
    }
}
 
int get_min(int id){
    int len = ss[id].size();
    int i = 0, j = 1, k = 0, t;
    while(i < len && j < len && k < len){
        t = ss[id][(i + k) % len] - ss[id][(j + k) % len];
        if(!t) k++;
        else {
            if(t > 0) i += k + 1;
            else j += k + 1;
            if(i == j) j++;
            k = 0;
        }
    }
    return min(i, j);
}
 
struct state
{
    int len, link, nex[26];
} st[maxn];
int sz, last;
 
void sam_init(){
    st[0].len = 0;
    st[0].link = -1;
    sz = 1, last = 0;
}
 
void sam_extend(int x){
    int cur = sz++;
    st[cur].len = st[last].len + 1;
    int p = last;
    while(p != -1 && !st[p].nex[x]){
        st[p].nex[x] = cur;
        p = st[p].link;
    }
    if(p == -1) st[cur].link = 0;
    else {
        int q = st[p].nex[x];
        if(st[p].len + 1 == st[q].len){
            st[cur].link = q;
        } else {
            int clone = sz++;
            st[clone].len = st[p].len + 1;
            st[clone].link = st[q].link;
            for(int i = 0; i < 26; i++){st[clone].nex[i] = st[q].nex[i];}
            while(p != -1 && st[p].nex[x] == q){
                st[p].nex[x] = clone;
                p = st[p].link;
            }
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}
 
LL val[maxn];
 
bool cmp(const string s1, const string s2){
    return s1.size() < s2.size();
}

int vis[maxn];
int main()
{
    cin >> n;
    for(int i = 1; i <= n; i++){
        cin >> ss[i];
        ss[i] += ss[i];
        getNext(i);
        int len = ss[i].size();
        len = len - knex[len];
        ss[i] = ss[i].substr(0, len);
    }
    for(int i = 1; i <= n; i++){
        int st = get_min(i);
        int len = ss[i].size();
        s = "";
        for(int j = 0; j < len; j++){
            s += ss[i][(st + j) % len];
        }
        ss[i] = s;
    }
     
    int flag = 1;
    for(int i = 2; i <= n; i++){
        if(ss[i] != ss[i - 1]) {
            flag = 0;
            break;
        }
    }
    if(flag){
        cout << -1 << '\n';
        return 0;
    }
    sort(ss + 1, ss + n + 1, cmp);
    for(int i = 2; i <= n; i++){
        s = ss[i];
        for(int j = 1; j <= 3; j++) ss[i] += s;
    }
    s = ss[1];
    while(ss[1].size() < ss[n].size()){
        ss[1] += s;
    }
    sam_init();
    for(int i = 1; i <= n; i++){
        int len = ss[i].size();
        last = 0;
        for(int j = 0; j < len; j++){
            sam_extend(ss[i][j] - 'a');
        }
    }
    for(int i = 1; i <= n; i++){
        int len = ss[i].size();
        int p = 0;
        for(int j = 0; j < len; j++){
            p = st[p].nex[ss[i][j] - 'a'];
            int u = p;
            while(vis[u] != i && u != 0) vis[u] = i, val[u]++, u = st[u].link;
        }
    }
    LL ans = 0;
    for(int i = 1; i < sz; i++){
        if(val[i] == n) ans += st[i].len - st[st[i].link].len;
    }
    cout << ans << '\n';
}
posted @ 2020-08-04 20:32  从小学  阅读(278)  评论(1编辑  收藏  举报