Loj2083.「NOI2016」优秀的拆分

「NOI2016」优秀的拆分

#2083. 「NOI2016」优秀的拆分 - 题目 - LibreOJ (loj.ac)

\(Description\)

求将字符串 \(S\) 所有子串拆分为 \(AABB\) 形式的总个数。

\(Solution\)

\(f_{i}\) 表示以位置 \(i\) 开头的 \(AA\) 串的个数,\(g_i\) 表示以位置 \(i\) 结尾的 \(AA\) 串的个数。

那么最后答案为 \(\sum g_i f_{i + 1}\)

字符串 \(hash\) 可以 \(O(n^2)\) 暴力求出 \(f_i\),瓶颈在于如何快速求出 \(f_i\)

考虑求所有长度为 \(2len\)\(AA\) 串,将原串每 \(len\) 位设置关键点,可以发现每个 \(AA\) 串一定经过两个关键点。

发现可以以关键点为界将 \(A\) 分为 \(LCP\)\(LCS\)

考虑求出两个相邻关键点 \(x, \ y\)\(lcp\)\(lcs\),可以发现能形成 \(AA\) 串的充要条件是 \(lcp + lcs > len\),并且形成 \(AA\) 串的开头为一段区间 \([x - lcp + 1, x + lcs - len]\),结尾类似。

求出原串正反的 \(sa\),总时间复杂度 \(O(n \log n)\)

\(Code\)

nclude <bits/stdc++.h>

using namespace std;

#define N 30000
#define L 15

#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define fd(i, x, y) for(int i = x; i >= y; i --)
#define mcp(a, b) memcpy(a, b, sizeof b)
#define Mes(a, x) memset(a, x, sizeof a)

int sl[N + 1], sr[N + 1], Log[N + 1];
int n;

struct SA {
    int sa[N + 1], rk[N + 1], ht[N + 1], id[N + 1], oldrk[N << 1], buc[N + 1], px[N + 1], ft[N + 1][L + 1];
    int sk;
    void mysort() {
        fill(buc, buc + 1 + sk, 0);
        fo(i, 1, n) ++ buc[ px[i] = rk[id[i]] ];
        fo(i, 1, sk) buc[i] += buc[i - 1];
        fd(i, n, 1) sa[ buc[px[i]] -- ] = id[i];
    }
    bool pd(int x, int y, int z) { return oldrk[x] == oldrk[y] && oldrk[x + z] == oldrk[y + z]; }
    void build(char ch[]) {
        Mes(rk, 0), Mes(oldrk, 0);
        sk = 26;
        fo(i, 1, n) rk[ id[i] = i ] = ch[i] - 'a' + 1;
        mysort();
        for (int w = 1, p = 0; w <= n; w <<= 1, p = 0) {
            fo(i, n - w + 1, n) id[ ++ p ] = i;
            fo(i, 1, n) if (sa[i] > w)
                id[ ++ p ] = sa[i] - w;
            mysort();
            mcp(oldrk, rk);
            sk = 0;
            fo(i, 1, n)
                rk[sa[i]] = pd(sa[i], sa[i - 1], w) ? sk : ++ sk;
            if (sk == n) {
                fo(i, 1, n) sa[rk[i]] = i;
                break;
            }
        }
        sk = 0;
        fo(i, 1, n) {
            if (sk) -- sk;
            while (ch[i + sk] == ch[sa[rk[i] - 1] + sk])
                ++ sk;
            ht[rk[i]] = sk;
        }
        fo(i, 1, n) ft[i][0] = ht[i];
        fo(j, 0, L - 1) fo(i, 1, n)
            ft[i][j + 1] = (i + (1 << j) <= n) ? min(ft[i][j], ft[i + (1 << j)][j]) : ft[i][j];
    }
    int dt;
    int get(int l, int r) {
        l = rk[l], r = rk[r];
        if (l > r) swap(l, r);
        if (++ l == r) return ft[l][0];
        dt = Log[r - l + 1];
        return min(ft[l][dt], ft[r - (1 << dt) + 1][dt]);
    }
} s1, s2;

char ch[N + 1];
int main() {
    int T; scanf("%d\n", &T);
    Log[1] = 0;
    fo(i, 2, N)
        Log[i] = Log[i >> 1] + 1;
    while (T --) {
        scanf("%s\n", ch + 1);
        n = strlen(ch + 1);
        s1.build(ch);
        fd(i, (n >> 1), 1)
            swap(ch[i], ch[n - i + 1]);
        s2.build(ch);
        fill(sl + 1, sl + 1 + n, 0);
        fill(sr + 1, sr + 1 + n, 0);
        int l, r, lcp, lcs, dlen, pl, pr;
        fd(len, (n >> 1), 1) {
            fd(k, (n / len) - 1, 1) {
                l = len * k, r = len * (k + 1);
                lcs = s1.get(l, r);
                lcp = s2.get(n - l + 1, n - r + 1);
                if (lcp + lcs > len) {
                    lcp = min(lcp, len), lcs = min(lcs, len);
                    dlen = lcp + lcs - len - 1;
                    ++ sl[l - lcp + 1], -- sl[l - lcp + 1 + dlen + 1];
                    ++ sr[r + lcs - 1], -- sr[r + lcs - 1 - dlen - 1];
                }
            }
        }
        fo(i, 2, n) sl[i] += sl[i - 1];
        fd(i, n - 1, 1) sr[i] += sr[i + 1];
        long long ans = 0;
        fo(i, 2, n) ans += 1ll * sr[i - 1] * sl[i];
        printf("%lld\n", ans);
    }
    return 0;
}
posted @ 2021-05-11 22:09  buzzhou  阅读(63)  评论(0编辑  收藏  举报