bzoj 3676 后缀自动机+马拉车+树上倍增

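思路:用马拉车(Manacher)算法把需要检查的回文子串个数降到 O(n) 级别(只有在暴力向外扩展成功时才会出现新的回文串),然后对每个这样的回文串,从其右端点对应的前缀状态出发,在后缀自动机的后缀链接树上倍增,找到包含该串的状态并得到它的出现次数,答案取"长度 × 出现次数"的最大值。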

#include<bits/stdc++.h>
#define LL long long
#define fi first
#define se second
#define mk make_pair
#define PII pair<int, int>
#define PLI pair<LL, int>
#define ull unsigned long long
using namespace std;

// Size base for the fixed-size buffers in this file (arrays are declared with N or 2 * N entries).
constexpr int N = 300007;

// General-purpose constants; none of them is referenced by the code visible in this file.
constexpr int inf = 0x3f3f3f3f;
constexpr LL INF = 0x3f3f3f3f3f3f3f3fLL;
constexpr int mod = 1000000007;
constexpr double eps = 1e-8;
constexpr int base = 87;

int n;           // length of the input string (assigned in main)
int m;           // not referenced by the visible code
int p[2 * N];    // per-position palindrome radii written by SuffixAutomaton::solve()
char s[2 * N];   // input string, stored starting at index 1

/**
 * Suffix automaton over the letters 'a'..'z', together with the pieces that
 * solve() needs: per-state counts (getSize), a binary-lifting table over the
 * suffix links (f), and a Manacher scan of the global string s[1..n].
 *
 * State numbering: state 1 is the root, 0 means "no state".
 * The struct reads the file-level globals N, n, s and p.
 * Expected call order for one string: extend() for each character,
 * getSize(n) once, then solve(). Call init() before building another string.
 */
struct SuffixAutomaton {
    int last;            // value of `cur` before the most recent extend()
    int cur;             // state created by the most recent extend() (the root before any call)
    int cnt;             // number of states allocated so far
    int ch[N<<1][26];    // ch[v][x]: transition from v on letter x, 0 if absent
    int id[N<<1];        // id[1..cnt]: states sorted by non-decreasing dis (filled by getSize)
    int fa[N<<1];        // suffix link of each state; fa[1] stays 0
    int dis[N<<1];       // length of the longest string represented by the state
    int sz[N<<1];        // 1 for states created directly by extend(), 0 for clones;
                         // getSize() then adds every state's value into its suffix-link parent
    int c[N<<1];         // counting-sort buckets indexed by dis; sized like the per-state arrays
                         // because init() clears it once per state index
    int f[N<<1][20];     // f[v][j]: the 2^j-th ancestor of v along suffix links, 0 past the root
    int pos[N<<1];       // pos[i]: state reached from the root after reading s[1..i]

    // Only the counters are set here; the arrays are not cleared, so a freshly
    // constructed object must live in zero-initialized storage (as the global `sam` does).
    SuffixAutomaton() {last = cur = cnt = 1;}

    /** Clears every state used so far so that a new string can be built. */
    void init() {
        // Index 0 is included because getSize() accumulates the root's count into sz[0].
        for(int i = 0; i <= cnt; i++) {
            memset(ch[i], 0, sizeof(ch[i]));
            sz[i] = c[i] = dis[i] = fa[i] = 0;
        }
        last = cur = cnt = 1;
    }

    /**
     * Appends one letter to the automaton.
     * @param c   letter index in [0, 26); note that it shadows the member array `c`
     * @param id  length of the string after this letter is appended (stored as dis of
     *            the new state); shadows the member array `id`
     */
    void extend(int c, int id) {
        last = cur; cur = ++cnt;
        int p = last; dis[cur] = id;
        // Walk the suffix links, adding the missing transition on `c`.
        for(; p && !ch[p][c]; p = fa[p]) ch[p][c] = cur;
        if(!p) fa[cur] = 1;
        else {
            int q = ch[p][c];
            if(dis[q] == dis[p]+1) fa[cur] = q;
            else {
                // q also holds longer strings: split off a clone of length dis[p]+1.
                int nt = ++cnt; dis[nt] = dis[p]+1;
                memcpy(ch[nt], ch[q], sizeof(ch[q]));
                fa[nt] = fa[q]; fa[q] = fa[cur] = nt;
                // Redirect transitions that pointed at q; stop once the walk leaves the root.
                for(; p && ch[p][c]==q; p=fa[p]) ch[p][c] = nt;
            }
        }
        sz[cur] = 1;
    }

    /**
     * Orders the states by dis with a counting sort and accumulates sz along
     * the suffix links, from the largest dis to the smallest.
     * Requires the buckets `c` to be all zero on entry, so call it once per build.
     * @param n  upper bound on dis, i.e. the length of the string that was inserted
     */
    void getSize(int n) {
        for(int i = 1; i <= cnt; i++) c[dis[i]]++;
        for(int i = 1; i <= n; i++) c[i] += c[i-1];
        for(int i = cnt; i >= 1; i--) id[c[dis[i]]--] = i;
        for(int i = cnt; i >= 1; i--) {
            int p = id[i];
            sz[fa[p]] += sz[p];
        }
    }

    /**
     * Climbs from state p to its topmost suffix-link ancestor whose dis is still
     * at least len and returns len * sz of that state.
     * solve() passes the prefix state of the position where a palindrome ends,
     * so the state found is the one holding the palindrome of length len.
     * Requires the table f to have been filled (solve() does that).
     */
    LL query(int p, int len) {
        for(int j = 19; j >= 0; j--) {
            if(f[p][j] && dis[f[p][j]] >= len) p = f[p][j];
        }
        return 1ll*len*sz[p];
    }

    /**
     * Prints the maximum of len * sz over the palindromic substrings of s[1..n]
     * that Manacher's scan discovers by explicit character comparison.
     * Temporarily writes sentinels into s[0] and s[n+1]; both are restored
     * before returning, so s + 1 remains a NUL-terminated string.
     */
    void solve() {
        // Prefix states: follow the transitions of s from the root.
        for(int i = 1, state = 1; i <= n; i++) {
            state = ch[state][s[i]-'a'];
            pos[i] = state;
        }

        // Binary-lifting table over suffix links; row 0 is all zero and acts as "none".
        for(int v = 1; v <= cnt; v++) f[v][0] = fa[v];
        for(int j = 1; j < 20; j++)
            for(int v = 1; v <= cnt; v++)
                f[v][j] = f[f[v][j-1]][j-1];

        LL ans = 0;

        // Two distinct non-letter sentinels stop the expansion loops at both ends.
        const char savedFront = s[0], savedBack = s[n+1];
        s[0] = '-', s[n+1] = '+';

        // Odd lengths: p[i] is the radius including the centre, i.e. the palindrome
        // is s[i-p[i]+1 .. i+p[i]-1]; mx is one past the furthest right end seen so far.
        int mx = 0, center = 0;
        for(int i = 1; i <= n; i++) {
            if(mx > i) p[i] = min(mx-i, p[2*center-i]);   // copied from the mirror, not queried again
            else p[i] = 1, ans = max(ans, query(pos[i], 1));
            while(s[i+p[i]] == s[i-p[i]]) {
                p[i]++;
                ans = max(ans, query(pos[i+p[i]-1], 2*p[i]-1));
            }
            if(i+p[i] > mx) mx = i+p[i], center = i;
        }

        // Even lengths: the centre lies between i and i+1, p[i] is the half length,
        // i.e. the palindrome is s[i-p[i]+1 .. i+p[i]]; mx is the furthest right end.
        mx = 0, center = 0;
        for(int i = 1; i <= n; i++) {
            if(mx > i) p[i] = min(mx-i, p[2*center-i]);
            else p[i] = 0;
            while(s[i+p[i]+1] == s[i-p[i]]) {
                p[i]++;
                ans = max(ans, query(pos[i+p[i]], 2*p[i]));
            }
            if(i+p[i] > mx) mx = i+p[i], center = i;
        }

        s[0] = savedFront, s[n+1] = savedBack;
        printf("%lld\n", ans);
    }
} sam;

/**
 * Reads one whitespace-delimited word from stdin into s[1..n], feeds it to the
 * global automaton `sam` and lets sam.solve() print the answer.
 * Returns 1 (with a message on stderr) if the word contains a character
 * outside 'a'..'z'; returns 0 otherwise.
 */
int main() {
    // The field width keeps scanf inside `s` and keeps the automaton inside its
    // arrays; 300000 is the length that N = 300000 + 7 is sized for (keep in sync).
    // If no word can be read, fall back to the empty string, which prints 0.
    if(scanf("%300000s", s + 1) != 1) s[1] = '\0';
    n = static_cast<int>(strlen(s + 1));

    // The automaton has exactly 26 transitions per state, so any other
    // character would index outside the transition table.
    for(int i = 1; i <= n; i++) {
        if(s[i] < 'a' || s[i] > 'z') {
            fprintf(stderr, "input must consist of lowercase letters a-z\n");
            return 1;
        }
    }

    for(int i = 1; i <= n; i++)
        sam.extend(s[i]-'a', i);
    sam.getSize(n);
    sam.solve();
    return 0;
}

/*
*/

 

posted @ 2018-10-21 14:31  NotNight  阅读(111)  评论(0)  编辑  收藏  举报