后缀数组

后缀数组

\(sa_i\) 表示排名为 \(i\) 的后缀的起始下标(即该后缀为 \(S[sa_i, n]\)),\(rk_i\) 表示后缀 \(S[i, n]\) 的排名,二者合称后缀数组(Suffix Array)。不难得到:

\[sa_{rk_i}=rk_{sa_i}=i \]

倍增求解

P3809 【模板】后缀排序

考虑已经将每个串以前 \(2^{k - 1}\) 个字符为关键字排好序,接下来需要以前 \(2^k\) 个字符为关键字排好序。

由于 \(2^k = 2^{k - 1} + 2^{k - 1}\) ,于是只要做一次双关键字排序即可。其中第一关键字为 \(S[i, i + 2^{k - 1} - 1]\) ,第二关键字为 \(S[i + 2^{k - 1}, i + 2^k - 1]\) ,这两个信息的大小比较在上一轮排序中均可找到。

一些优化:

  • 当倍增长度 \(2^k\) 不小于字符串长度 \(n\) 时,就可以停止倍增。

  • 并且当所有 \(rk_i\) 都互不相同(即所有后缀的排名已经确定)时,就可以提前停止排序。

使用计数排序,每轮排序为 \(O(n)\),共 \(O(\log n)\) 轮,总时间复杂度 \(O(n \log n)\)。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 7;

char str[N];

int n;

namespace SA {
// sa[r]: start of the suffix with rank r; rk[i]: rank of suffix S[i, n].
// oldrk has 2N slots because oldrk[sa[i] + k] is read and sa[i] + k may exceed n.
int sa[N], rk[N], cnt[N], id[N], oldrk[N << 1];

// Builds sa/rk of the global string str[1..n] by prefix doubling with counting sort.
inline void prework(int n) {
    // Zero oldrk[n + 1 .. 2n] so that a missing second half compares as rank 0.
    memset(oldrk + n + 1, 0, sizeof(int) * n);
    int m = 1 << 8; // key range; the initial keys are the character codes
    memset(cnt + 1, 0, sizeof(int) * m);

    // Counting sort of all suffixes by their first character.
    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = str[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    // At the top of each round, rk orders suffixes by their first k characters.
    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // id = suffixes ordered by the second key rk[i + k]. Suffixes with i + k > n
        // have an empty second half and therefore come first.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort of id by the first key rk.
        memset(cnt + 1, 0, sizeof(int) * m);

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        // Re-rank: a new rank starts whenever the (first, second) key pair changes.
        // m ends as the number of distinct ranks, which is the key range of the next round.
        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), m = 0;

        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++m;

            rk[sa[i]] = m;
        }

        // All ranks are distinct, so the order is final.
        if (m == n)
            break;
    }
}
} // namespace SA

signed main() {
    scanf("%s", str + 1);
    n = strlen(str + 1);

    SA::prework(n);

    // Output the suffix array: suffix start positions in increasing suffix order.
    for (int rank = 1; rank <= n; rank++) {
        printf("%d ", SA::sa[rank]);
    }

    return 0;
}

LCP 相关

设 \(\mathrm{lcp}(i, j) = \mathrm{LCP}(S[sa_i, n], S[sa_j, n])\),\(ht_i = \mathrm{lcp}(i - 1, i)\),则有:

  • LCP Lemma:\(\forall 1 \le i < j < k \le n, \mathrm{lcp}(i, k) = \min(\mathrm{lcp}(i, j), \mathrm{lcp}(j, k))\)
  • LCP Theorem:\(\forall 1 \le i < j \le n, \mathrm{lcp}(i, j) = \min_{k = i + 1}^j {ht_k}\)
  • LCP Corollary:\(\forall 1 \le i \le j < k \le n, \mathrm{lcp}(i, j) \ge \mathrm{lcp}(i, k)\)

于是求解两个后缀的 LCP 可以转化为 \(ht\) 数组上的 RMQ 问题。

\(h_i = ht_{rk_i}\) ,则:

\[h_i \ge h_{i - 1} - 1 \]

证明:设后缀 \(i - 1\) 在 \(sa\) 中排在它前一位的后缀为 \(k\)(即 \(k = sa_{rk_{i - 1} - 1}\))。把这两个后缀的首字母都砍掉,它们就分别变成了后缀 \(i\) 与后缀 \(k + 1\)。

若两个后缀的首字母不同,则 \(h_{i - 1} = 0\),结论显然成立。否则删去首字母后二者的先后顺序不变(后缀 \(k + 1\) 仍排在后缀 \(i\) 之前),且它们的 LCP 长度为 \(h_{i - 1} - 1\)。由 LCP Theorem,这个 LCP 等于二者排名之间所有 \(ht\) 的最小值,而 \(h_i = ht_{rk_i}\) 正是该区间中的一项,因此 \(h_i \ge h_{i - 1} - 1\)。

按 \(i = 1, 2, \dots, n\) 的顺序用指针维护 \(h_i\)(每步至多减一,总增量不超过 \(n\)),即可 \(O(n)\) 求出 \(ht\)。

for (int i = 1, k = 0; i <= n; ++i) {
    if (k)
        --k; // h[i] >= h[i - 1] - 1

    int j = sa[rk[i] - 1]; // j 是 i 相邻的一个后缀

    while (i + k <= n && j + k <= n && str[i + k] == str[j + k])
        ++k;

    ht[rk[i]] = k;
}

应用

P2408 不同子串个数

求长为 \(n\) 的字符串的不同子串个数。

\(n \le 10^5\)

发现每个子串都是某个后缀的前缀,所有后缀的前缀共有 \(\frac{n(n + 1)}{2}\) 个。按排名顺序考虑后缀,排名为 \(i\) 的后缀与排名为 \(i - 1\) 的后缀重复的前缀恰有 \(ht_i\) 个(在排名较小的后缀处统计贡献),因此答案即为 \(\frac{n(n + 1)}{2} - \sum_{i = 2}^n ht_i\)。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;

char str[N];

int n;

namespace SA {
// sa[r]: start of the rank-r suffix; rk[i]: rank of suffix i; ht[r]: LCP of ranks r - 1 and r.
int sa[N], rk[N], cnt[N], id[N], oldrk[N], ht[N];

// Builds sa, rk and ht for the global string str[1..n] (prefix doubling + counting sort).
inline void prework() {
    int m = 1 << 9; // key range; the initial keys are character codes
    memset(cnt, 0, sizeof(int) * (m + 1));

    // Counting sort by the first character.
    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = str[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // id = suffixes ordered by second key; an empty second half (i + k > n) sorts first.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort of id by the first key.
        memset(cnt, 0, sizeof(int) * (m + 1));

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), tot = 0;

        // Re-rank by (first, second) key. This function only writes oldrk[1..n]; it relies
        // on oldrk[0] and oldrk[n + 1] being 0 (zero-initialized global) as sentinels.
        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++tot;

            rk[sa[i]] = tot;
        }

        // All ranks distinct: the order is final.
        if (tot == n)
            break;

        m = tot; // next key range = number of distinct ranks
    }

    // Height array: h[i] = ht[rk[i]] >= h[i - 1] - 1, so k drops by at most one per step.
    for (int i = 1, k = 0; i <= n; ++i) {
        if (k)
            --k;

        int j = sa[rk[i] - 1]; // suffix ranked just before suffix i (sa[0] = 0 for rank 1)

        while (i + k <= n && j + k <= n && str[i + k] == str[j + k])
            ++k;

        ht[rk[i]] = k;
    }
}
} // namespace SA

signed main() {
    scanf("%d%s", &n, str + 1);
    SA::prework();
    ll ans = 1ll * n * (n + 1) / 2;

    for (int i = 2; i <= n; ++i)
        ans -= SA::ht[i];

    printf("%lld", ans);
    return 0;
}

SP1811 LCS - Longest Common Substring

求两个串的最长公共子串。

\(n \le 2.5 \times 10^5\)

把两个串拼起来,中间塞一个无关字符,答案即为 \(ht\) 数组的最大值(需要保证前后两个后缀串分别来自给出的两个串)。

#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 7;

char s[N], p[N];

int n, m;

namespace SA {
// sa[r]: start of the rank-r suffix; rk[i]: rank of suffix i; ht[r]: LCP of ranks r - 1 and r.
int sa[N], rk[N], cnt[N], id[N], oldrk[N], ht[N];

// Builds sa, rk and ht for str[1..n] (prefix doubling + counting sort).
// The parameter n shadows the global n inside this function.
inline void prework(char *str, int n) {
    int m = 1 << 9; // key range; the initial keys are character codes
    memset(cnt, 0, sizeof(int) * (m + 1));

    // Counting sort by the first character.
    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = str[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // id = suffixes ordered by second key; an empty second half (i + k > n) sorts first.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort of id by the first key.
        memset(cnt, 0, sizeof(int) * (m + 1));

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), tot = 0;

        // Re-rank by (first, second) key. Only oldrk[1..n] is written here; oldrk[0] and
        // oldrk[n + 1] are expected to be 0 and serve as sentinels.
        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++tot;

            rk[sa[i]] = tot;
        }

        // All ranks distinct: the order is final.
        if (tot == n)
            break;

        m = tot; // next key range = number of distinct ranks
    }

    // Height array: h[i] = ht[rk[i]] >= h[i - 1] - 1, so k drops by at most one per step.
    for (int i = 1, k = 0; i <= n; ++i) {
        if (k)
            --k;

        int j = sa[rk[i] - 1]; // suffix ranked just before suffix i (sa[0] = 0 for rank 1)

        while (i + k <= n && j + k <= n && str[i + k] == str[j + k])
            ++k;

        ht[rk[i]] = k;
    }
}
} // namespace SA

signed main() {
    scanf("%s%s", s + 1, p + 1);
    n = strlen(s + 1);
    m = strlen(p + 1);

    // Build s + '#' + p in place inside s; '#' separates the two strings.
    s[n + 1] = '#';
    memcpy(s + n + 2, p + 1, sizeof(char) * m);

    const int total = n + 1 + m;
    SA::prework(s, total);

    int best = 0;

    for (int rank = 2; rank <= total; ++rank) {
        const int lo = min(SA::sa[rank - 1], SA::sa[rank]);
        const int hi = max(SA::sa[rank - 1], SA::sa[rank]);

        // Only adjacent suffixes coming from different strings count.
        if (lo <= n && hi > n)
            best = max(best, SA::ht[rank]);
    }

    printf("%d", best);
    return 0;
}

P3181 [HAOI2016] 找相同字符

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。

\(n \le 2 \times 10^5\)

本质上就是算所有后缀两两之间的 LCP 和,但是需要来自两个不同串。考虑把两个串拼起来(中间塞一个无关字符)做一次,然后再减去两个串分别做一次的贡献。问题转化为 \(ht\) 数组上所有子区间的最小值之和,不难用单调栈解决。

类似的题:P4248 [AHOI2013] 差异

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 4e5 + 7;

int L[N], R[N], sta[N];
char s[N], p[N];

int n, m;

namespace SA {
// sa[r]: start of the rank-r suffix; rk[i]: rank of suffix i; ht[r]: LCP of ranks r - 1 and r.
int sa[N], rk[N], cnt[N], id[N], oldrk[N], ht[N];

// Builds sa, rk and ht for str[1..n] (prefix doubling + counting sort).
// May be called repeatedly with different strings and lengths.
inline void prework(char *str, int n) {
    // oldrk[n + 1] is the "empty second half" sentinel and must be 0. It is the only slot
    // past n the re-ranking loop can read: equal first keys imply both suffixes have
    // length >= k, so sa[i] + k <= n + 1. A previous call with a longer string leaves a
    // real rank there, which could wrongly merge two distinct suffixes, so reset it.
    oldrk[n + 1] = 0;
    int m = 1 << 9; // key range; the initial keys are character codes
    memset(cnt, 0, sizeof(int) * (m + 1));

    // Counting sort by the first character.
    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = str[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // id = suffixes ordered by second key; an empty second half (i + k > n) sorts first.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort of id by the first key.
        memset(cnt, 0, sizeof(int) * (m + 1));

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), tot = 0;

        // Re-rank by (first, second) key; oldrk[0] stays 0 as the sentinel before sa[1].
        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++tot;

            rk[sa[i]] = tot;
        }

        // All ranks distinct: the order is final.
        if (tot == n)
            break;

        m = tot; // next key range = number of distinct ranks
    }

    // Height array: h[i] = ht[rk[i]] >= h[i - 1] - 1, so k drops by at most one per step.
    for (int i = 1, k = 0; i <= n; ++i) {
        if (k)
            --k;

        int j = sa[rk[i] - 1]; // suffix ranked just before suffix i (sa[0] = 0 for rank 1)

        while (i + k <= n && j + k <= n && str[i + k] == str[j + k])
            ++k;

        ht[rk[i]] = k;
    }
}
} // namespace SA

// Returns the sum of LCPs over all unordered pairs of distinct suffixes of str[1..n],
// i.e. the sum of min(ht[l..r]) over all subarrays of ht[2..n] (monotonic stacks).
inline ll solve(char *str, int n) {
    SA::prework(str, n);
    const int *h = SA::ht;

    // L[i]: smallest index such that h[L[i] .. i - 1] are all strictly greater than h[i].
    int top = 0;
    for (int i = 2; i <= n; ++i) {
        while (top > 0 && h[sta[top]] > h[i])
            top--;
        L[i] = (top == 0) ? 2 : sta[top] + 1;
        sta[++top] = i;
    }

    // R[i]: largest index such that h[i + 1 .. R[i]] are all >= h[i].
    top = 0;
    for (int i = n; i >= 2; --i) {
        while (top > 0 && h[sta[top]] >= h[i])
            top--;
        R[i] = (top == 0) ? n : sta[top] - 1;
        sta[++top] = i;
    }

    // Each subarray is attributed to its leftmost minimum position i.
    ll sum = 0;
    for (int i = 2; i <= n; ++i) {
        const ll leftChoices = i - L[i] + 1;
        const ll rightChoices = R[i] - i + 1;
        sum += h[i] * leftChoices * rightChoices;
    }

    return sum;
}

signed main() {
    scanf("%s%s", s + 1, p + 1);
    n = strlen(s + 1), m = strlen(p + 1);
    ll res = solve(s, n) + solve(p, m);
    s[n + 1] = '#', memcpy(s + n + 2, p + 1, sizeof(char) * m);
    printf("%lld", solve(s, n + 1 + m) - res);
    return 0;
}

P4094 [HEOI2016/TJOI2016] 字符串

给定字符串, \(m\) 次询问 \(S[a, b]\) 所有子串与 \(S[c, d]\) 的 LCP 最大值。

\(n, m \le 10^5\)

二分答案 \(mid\),问题转化为:起始位置在 \([a, b - mid + 1]\) 内的所有后缀中,是否存在一个与后缀 \(c\) 的 LCP \(\geq mid\)(\(mid\) 不超过 \(d - c + 1\))。与后缀 \(c\) 的 LCP 至少为 \(mid\) 的后缀在 \(sa\) 上形如一段区间,于是问题转化为二维数点,直接主席树即可做到 \(O(n \log^2 n)\)。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7, LOGN = 17;

char str[N];

int n, m;

namespace SA {
// sa[r]: start of the rank-r suffix; rk[i]: rank of suffix i; ht[r]: LCP of ranks r - 1 and r.
int sa[N], rk[N], cnt[N], id[N], oldrk[N], ht[N];

// Builds sa, rk and ht for the global string str[1..n] (prefix doubling + counting sort).
inline void prework() {
    // cnt is not cleared before the first counting sort; this relies on it being all
    // zeros on entry (zero-initialized global, and main calls prework once).
    int m = 1 << 9;

    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = str[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // id = suffixes ordered by second key; an empty second half (i + k > n) sorts first.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort of id by the first key.
        memset(cnt, 0, sizeof(int) * (m + 1));

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), tot = 0;

        // Re-rank by (first, second) key; oldrk[0] and oldrk[n + 1] are expected to be 0.
        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++tot;

            rk[sa[i]] = tot;
        }

        // All ranks distinct: the order is final.
        if (tot == n)
            break;

        m = tot; // next key range = number of distinct ranks
    }

    // Height array: h[i] = ht[rk[i]] >= h[i - 1] - 1, so k drops by at most one per step.
    for (int i = 1, k = 0; i <= n; ++i) {
        if (k)
            --k;

        int j = sa[rk[i] - 1]; // suffix ranked just before suffix i (sa[0] = 0 for rank 1)

        while (i + k <= n && j + k <= n && str[i + k] == str[j + k])
            ++k;

        ht[rk[i]] = k;
    }
}
} // namespace SA

namespace ST {
int f[LOGN][N];

// Sparse table over SA::ht[1..n] for O(1) range-minimum queries.
inline void prework() {
    memcpy(f[0] + 1, SA::ht + 1, sizeof(int) * n);
    const int levels = __lg(n);

    for (int j = 1; j <= levels; ++j) {
        const int half = 1 << (j - 1);

        for (int i = 1; i + 2 * half - 1 <= n; ++i)
            f[j][i] = min(f[j - 1][i], f[j - 1][i + half]);
    }
}

// Minimum of ht[l..r]; requires 1 <= l <= r <= n.
inline int query(int l, int r) {
    const int j = __lg(r - l + 1);
    return min(f[j][l], f[j][r - (1 << j) + 1]);
}
} // namespace ST

namespace SMT {
const int S = 1e7 + 7;

int lc[S], rc[S], cnt[S], rt[N];

int tot;

// Persistent insert: returns the root of a new version equal to version x plus one
// occurrence of position p. Node 0 is the shared empty tree.
int insert(int x, int nl, int nr, int p) {
    const int node = ++tot;
    lc[node] = lc[x];
    rc[node] = rc[x];
    cnt[node] = cnt[x] + 1;

    if (nl != nr) {
        const int mid = (nl + nr) >> 1;

        if (p > mid)
            rc[node] = insert(rc[x], mid + 1, nr, p);
        else
            lc[node] = insert(lc[x], nl, mid, p);
    }

    return node;
}

// Number of positions in [l, r] present in version y but not in version x
// (x is expected to be an older version than y).
int query(int x, int y, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return cnt[y] - cnt[x];

    const int mid = (nl + nr) >> 1;

    if (r <= mid)
        return query(lc[x], lc[y], nl, mid, l, r);

    if (l > mid)
        return query(rc[x], rc[y], mid + 1, nr, l, r);

    const int leftPart = query(lc[x], lc[y], nl, mid, l, r);
    const int rightPart = query(rc[x], rc[y], mid + 1, nr, l, r);
    return leftPart + rightPart;
}
} // namespace SMT

// Returns whether some suffix starting in [ql, qr] shares at least k characters
// with the suffix starting at x.
inline bool check(int ql, int qr, int x, int k) {
    const int pos = SA::rk[x];

    // hl: smallest rank such that every rank in [hl, pos] has LCP >= k with suffix x.
    int l = 1, r = pos - 1, hl = pos;

    while (l <= r) {
        int mid = (l + r) >> 1;

        if (ST::query(mid + 1, pos) >= k)
            hl = mid, r = mid - 1;
        else
            l = mid + 1;
    }

    // hr: largest rank such that every rank in [pos, hr] has LCP >= k with suffix x.
    int hr = pos;
    l = pos + 1, r = n;

    while (l <= r) {
        int mid = (l + r) >> 1;

        if (ST::query(pos + 1, mid) >= k)
            hr = mid, l = mid + 1;
        else
            r = mid - 1;
    }

    // Count start positions in [ql, qr] among sa[hl..hr]. SMT::query(x, y, ...) returns
    // cnt[y] - cnt[x], so the older version rt[hl - 1] goes first; the previous argument
    // order produced a negative count and only worked through bool conversion.
    return SMT::query(SMT::rt[hl - 1], SMT::rt[hr], 1, n, ql, qr) > 0;
}

signed main() {
    scanf("%d%d%s", &n, &m, str + 1);
    SA::prework(), ST::prework();

    for (int i = 1; i <= n; ++i)
        SMT::rt[i] = SMT::insert(SMT::rt[i - 1], 1, n, SA::sa[i]);

    while (m--) {
        int a, b, c, d;
        scanf("%d%d%d%d", &a, &b, &c, &d);
        int l = 0, r = min(b - a + 1, d - c + 1), ans = 0;

        while (l <= r) {
            int mid = (l + r) >> 1;

            if (check(a, b - mid + 1, c, mid))
                ans = mid, l = mid + 1;
            else
                r = mid - 1;
        }

        printf("%d\n", ans);
    }

    return 0;
}

P2178 [NOI2015] 品酒大会

给定长度为 \(n\) 的字符串与权值 \(a_{1 \sim n}\) 。对于所有 \(i \in [1, n]\) ,求有多少对后缀满足 LCP 长度 \(\ge i\) 以及满足条件的两个后缀权值乘积的最大值。

\(n \le 3 \times 10^5\)

考虑倒序枚举 \(i\) ,则每次都会将 \(ht = i\) 的位置的两个后缀合并为一个连通块,两个问题的答案都可以通过并查集维护连通块来处理,时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 1e18;
const int N = 3e5 + 7;

vector<int> ht[N];

pair<ll, ll> ans[N];
int a[N];
char str[N];

pair<ll, ll> Answer = make_pair(0ll, -inf);
int n;

namespace SA {
// sa[r]: start of the rank-r suffix; rk[i]: rank of suffix i; ht[r]: LCP of ranks r - 1 and r.
int sa[N], rk[N], cnt[N], id[N], oldrk[N], ht[N];

// Builds sa, rk and ht for the global string str[1..n] (prefix doubling + counting sort).
inline void prework() {
    int m = 1 << 9; // key range; the initial keys are character codes
    memset(cnt, 0, sizeof(int) * (m + 1));

    // Counting sort by the first character.
    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = str[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // id = suffixes ordered by second key; an empty second half (i + k > n) sorts first.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort of id by the first key.
        memset(cnt, 0, sizeof(int) * (m + 1));

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), tot = 0;

        // Re-rank by (first, second) key; oldrk[0] and oldrk[n + 1] are expected to be 0.
        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++tot;

            rk[sa[i]] = tot;
        }

        // All ranks distinct: the order is final.
        if (tot == n)
            break;

        m = tot; // next key range = number of distinct ranks
    }

    // Height array: h[i] = ht[rk[i]] >= h[i - 1] - 1, so k drops by at most one per step.
    for (int i = 1, k = 0; i <= n; ++i) {
        if (k)
            --k;

        int j = sa[rk[i] - 1]; // suffix ranked just before suffix i (sa[0] = 0 for rank 1)

        while (i + k <= n && j + k <= n && str[i + k] == str[j + k])
            ++k;

        ht[rk[i]] = k;
    }
}
} // namespace SA

namespace DSU {
int fa[N], siz[N], mx[N], mn[N];

inline void prework(int n) {
    iota(fa + 1, fa + 1 + n, 1);
    fill(siz + 1, siz + 1 + n, 1);
}

inline int find(int x) {
    while (x != fa[x])
        fa[x] = fa[fa[x]], x = fa[x];

    return x;
}

inline void merge(int x, int y) {
    int fx = find(x), fy = find(y);
    Answer.first += 1ll * siz[fx] * siz[fy];
    Answer.second = max(Answer.second, max(1ll * mx[fx] * mx[fy], 1ll * mn[fx] * mn[fy]));
    mx[fx] = max(mx[fx], mx[fy]),  mn[fx] = min(mn[fx], mn[fy]);
    siz[fx] += siz[fy], fa[fy] = fx;
}
} // namespace DSU

signed main() {
    scanf("%d%s", &n, str + 1);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    SA::prework(), DSU::prework(n);

    for (int i = 1; i <= n; ++i)
        DSU::mx[i] = DSU::mn[i] = a[SA::sa[i]];

    for (int i = 2; i <= n; ++i)
        ht[SA::ht[i]].emplace_back(i);

    for (int i = n - 1; ~i; --i) {
        for (int x : ht[i])
            DSU::merge(x, x - 1);

        ans[i] = Answer;
    }

    for (int i = 0; i < n; ++i)
        printf("%lld %lld\n", ans[i].first, ans[i].second == -inf ? 0 : ans[i].second);

    return 0;
}

P7361 「JZOI-1」拜神

给定一个字符串,\(q\) 次询问区间最长的至少出现两次的字符串长度。

\(n \le 5 \times 10^4\)\(q \le 10^5\)

考虑二分答案 \(k\) ,问题转化为判定 \([l, r - k + 1]\) 内是否有两个后缀的 LCP 长度 \(\ge k\)

考虑对每个位置 \(i\) 维护 \(p_{i, j}\) 表示 \(< i\) 的最靠后的位置使得二者后缀的 LCP 长度 \(\ge j\) ,则判定条件等价于 \(l \le \max_{i = l}^{r - k + 1} p_{i, k}\)

维护 \(p_{i, j}\) 则考虑 SA-height 结构,用并查集维护两两 LCP 长度 \(\ge j\) 的位置。降序枚举 \(j\) ,则每次相当于合并若干个并查集。每次做启发式合并,枚举小集合的元素,并在大集合中找到其邻域做修改。

总共会做 \(O(n \log n)\) 次单点取 \(\max\) ,不难用主席树维护 \(p_{i, j}\) 做到时间复杂度 \(O(n \log^2 n + q \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 5e4 + 7;

vector<int> vec[N];

char str[N];

int n, q;

namespace SA {
// sa[r]: start of the rank-r suffix; rk[i]: rank of suffix i; ht[r]: LCP of ranks r - 1 and r.
// oldrk has 2N slots because oldrk[sa[i] + k] is read and sa[i] + k may exceed n.
int sa[N], rk[N], cnt[N], id[N], oldrk[N << 1], ht[N];

// Builds sa, rk and ht for the global string str[1..n] (prefix doubling + counting sort).
inline void prework() {
    int m = 1 << 8; // key range; the initial keys are character codes
    memset(cnt, 0, sizeof(int) * (m + 1));
    // Zero oldrk[n + 1 .. 2n] so that a missing second half compares as rank 0.
    memset(oldrk + n + 1, 0, sizeof(int) * n);

    // Counting sort by the first character.
    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = str[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // id = suffixes ordered by second key; an empty second half (i + k > n) sorts first.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort of id by the first key.
        memset(cnt, 0, sizeof(int) * (m + 1));

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        // Re-rank by (first, second) key; m ends as the number of distinct ranks,
        // which is the key range of the next round.
        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), m = 0;

        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++m;

            rk[sa[i]] = m;
        }

        // All ranks distinct: the order is final.
        if (m == n)
            break;
    }

    // Height array: h[i] = ht[rk[i]] >= h[i - 1] - 1, so k drops by at most one per step.
    for (int i = 1, k = 0; i <= n; ++i) {
        if (k)
            --k;

        int j = sa[rk[i] - 1]; // suffix ranked just before suffix i (sa[0] = 0 for rank 1)

        while (i + k <= n && j + k <= n && str[i + k] == str[j + k])
            ++k;

        ht[rk[i]] = k;
    }
}
} // namespace SA

namespace SMT {
const int S = 4e7 + 7;

int rt[N], lc[S], rc[S], mx[S];

int tot, root;

// Persistent point update: returns a new version equal to version x with
// value[p] = max(value[p], k). Each node stores the maximum of its range; node 0 is empty.
int update(int x, int nl, int nr, int p, int k) {
    const int y = ++tot;
    lc[y] = lc[x];
    rc[y] = rc[x];
    mx[y] = mx[x];

    if (nl == nr) {
        mx[y] = max(mx[y], k);
        return y;
    }

    const int mid = (nl + nr) >> 1;

    if (p > mid)
        rc[y] = update(rc[x], mid + 1, nr, p, k);
    else
        lc[y] = update(lc[x], nl, mid, p, k);

    mx[y] = max(mx[lc[y]], mx[rc[y]]);
    return y;
}

// Maximum stored value over positions [l, r] in version x (0 where nothing is stored).
int query(int x, int nl, int nr, int l, int r) {
    if (x == 0)
        return 0;

    if (l <= nl && nr <= r)
        return mx[x];

    const int mid = (nl + nr) >> 1;

    if (r <= mid)
        return query(lc[x], nl, mid, l, r);

    if (l > mid)
        return query(rc[x], mid + 1, nr, l, r);

    const int leftBest = query(lc[x], nl, mid, l, r);
    const int rightBest = query(rc[x], mid + 1, nr, l, r);
    return max(leftBest, rightBest);
}
} // namespace SMT

namespace DSU {
set<int> st[N];

int fa[N];

// Each position 1..n starts as a singleton set containing itself.
inline void prework(int n) {
    for (int i = 1; i <= n; ++i) {
        fa[i] = i;
        st[i].insert(i);
    }
}

// Root lookup without path compression (sets are merged small-into-large).
inline int find(int x) {
    for (; fa[x] != x; x = fa[x]) {}

    return x;
}

// Small-to-large merge of the sets containing x and y. For every element v of the smaller
// set, record the nearest neighbours v gains in the larger set: its successor gets v as an
// earlier partner, and v gets its predecessor as an earlier partner.
inline void merge(int x, int y) {
    int big = find(x), small = find(y);

    if (big == small)
        return;

    if (st[big].size() < st[small].size())
        swap(big, small);

    fa[small] = big;

    for (int v : st[small]) {
        auto succ = st[big].lower_bound(v);

        if (succ != st[big].end())
            SMT::root = SMT::update(SMT::root, 1, n, *succ, v);

        if (succ != st[big].begin())
            SMT::root = SMT::update(SMT::root, 1, n, v, *prev(succ));
    }

    // Insert only after all neighbour lookups, so they see the larger set alone.
    st[big].insert(st[small].begin(), st[small].end());
}
} // namespace DSU

signed main() {
    scanf("%d%d%s", &n, &q, str + 1);
    SA::prework();

    // Bucket the adjacent rank pairs (r - 1, r) by their LCP.
    for (int rank = 2; rank <= n; ++rank)
        vec[SA::ht[rank]].push_back(rank);

    DSU::prework(n);

    // rt[len] snapshots the tree after merging every adjacent pair with LCP >= len.
    for (int len = n; len >= 1; --len) {
        for (int rank : vec[len])
            DSU::merge(SA::sa[rank], SA::sa[rank - 1]);

        SMT::rt[len] = SMT::root;
    }

    while (q--) {
        int L, R;
        scanf("%d%d", &L, &R);

        int lo = 1, hi = R - L + 1, best = 0;

        while (lo <= hi) {
            const int len = (lo + hi) >> 1;
            // Some start in [L, R - len + 1] has an earlier partner >= L with LCP >= len.
            const bool ok = SMT::query(SMT::rt[len], 1, n, L, R - len + 1) >= L;

            if (ok) {
                best = len;
                lo = len + 1;
            } else {
                hi = len - 1;
            }
        }

        printf("%d\n", best);
    }

    return 0;
}

P5284 [十二省联考 2019] 字符串问题

现有一个字符串 \(S\) ,从中划出 \(n_a\) 个子串作为 A 类串,再划出 \(n_b\) 个子串作为 B 类串。

额外给定 \(m\) 组支配关系,每组支配关系 \((x, y)\) 表示第 \(x\) 个 A 类串支配第 \(y\) 个 B 类串。

求长度最大的串 \(T\) 长度,满足存在一个串 \(T\) 的分割 \(T = t_1 + t_2 + \cdots + t_k\) 满足:

  • 分割中的每个串 \(t_i\) 均为 A 类串。
  • 对于分割中所有相邻的串 \(t_i, t_{i + 1}\) ,都存在一个 \(t_i\) 支配的 B 类串为 \(t_{i + 1}\) 的前缀。

特别地,若存在无限长的 \(T\) ,输出 \(-1\)

\(|S|, n_a, n_b \le 2 \times 10^5\)

题目转化为:给定一些从 A 类区间连向 B 类区间的边,一个 B 区间能连向一个 A 区间当且仅当前者是后者的前缀。求这张图上的最长路(或判断无限长)。

显然若有环则答案可以无限大,否则直接在 DAG 上 DP 即可。

考虑优化 B 区间到 A 区间的边,发现包含某个 B 串作为前缀的 A 串需要满足两个条件:

  • 以二者左端点开始的后缀的 LCP 长度不小于 B 串长度:转化为对应着后缀数组上的一段区间。
  • A 串长度不小于 B 串长度:按照长度从大到小的顺序加入每个串即可。

使用主席树优化建图,时间复杂度 \(O((n + n_a + n_b) \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 2e5 + 7, LOGN = 19, S = N << 5;

// A closed interval [l, r] of positions in str.
struct Interval {
    int l, r;
} a[N], b[N];

// Sweep item: x is the rank of the substring's start, len its length,
// op its kind (0 = A-type, 1 = B-type), and id its graph node.
struct Node {
    int x, len, op, id;

    inline Node() {}

    inline Node(int _x, int _len, int _op, int _id) : x(_x), len(_len), op(_op), id(_id) {}

    // Longer items first; on equal length, A-type (op = 0) before B-type (op = 1).
    inline bool operator < (const Node &rhs) const {
        if (len != rhs.len)
            return len > rhs.len;

        return op < rhs.op;
    }
};

// Adjacency lists plus in-degree counters, consumed by the topological sort.
struct Graph {
    vector<int> e[S];

    int indeg[S];

    // Resets nodes 1..n (edges and in-degrees).
    inline void clear(int n) {
        for (int v = 1; v <= n; ++v) {
            e[v].clear();
            indeg[v] = 0;
        }
    }

    // Adds the directed edge u -> v.
    inline void insert(int u, int v) {
        e[u].push_back(v);
        indeg[v]++;
    }
} G;

ll f[S];
char str[N];

int n, na, nb, m, tot;

namespace SA {
// sa[r]: start of the rank-r suffix; rk[i]: rank of suffix i; ht[r]: LCP of ranks r - 1 and r.
// oldrk needs 2N slots: prework() zeroes oldrk[n + 1 .. 2n] on every test case, and with
// oldrk[N] that write overflowed once 2n + 1 > N (n may reach 2e5 while N = 2e5 + 7).
int sa[N], rk[N], cnt[N], id[N], oldrk[N << 1], ht[N];

// Builds sa, rk and ht for the global string str[1..n]; called once per test case.
inline void prework() {
    // Clear stale ranks left by a previous, longer test case so that a missing
    // second half (position > n) compares as rank 0.
    memset(oldrk + n + 1, 0, sizeof(int) * n);
    int m = 1 << 9; // key range; the initial keys are character codes
    memset(cnt, 0, sizeof(int) * (m + 1));

    // Counting sort by the first character.
    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = str[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // id = suffixes ordered by second key; an empty second half (i + k > n) sorts first.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort of id by the first key.
        memset(cnt, 0, sizeof(int) * (m + 1));

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), tot = 0;

        // Re-rank by (first, second) key.
        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++tot;

            rk[sa[i]] = tot;
        }

        // All ranks distinct: the order is final.
        if (tot == n)
            break;

        m = tot; // next key range = number of distinct ranks
    }

    // Height array: h[i] = ht[rk[i]] >= h[i - 1] - 1, so k drops by at most one per step.
    for (int i = 1, k = 0; i <= n; ++i) {
        if (k)
            --k;

        int j = sa[rk[i] - 1]; // suffix ranked just before suffix i (sa[0] = 0 for rank 1)

        while (i + k <= n && j + k <= n && str[i + k] == str[j + k])
            ++k;

        ht[rk[i]] = k;
    }
}
} // namespace SA

namespace ST {
int f[LOGN][N];

// Sparse table over SA::ht[1..n] for O(1) range-minimum queries; rebuilt per test case.
inline void prework() {
    memcpy(f[0] + 1, SA::ht + 1, sizeof(int) * n);
    const int levels = __lg(n);

    for (int j = 1; j <= levels; ++j) {
        const int half = 1 << (j - 1);

        for (int i = 1; i + 2 * half - 1 <= n; ++i)
            f[j][i] = min(f[j - 1][i], f[j - 1][i + half]);
    }
}

// Minimum of ht[l..r]; requires 1 <= l <= r <= n.
inline int query(int l, int r) {
    const int j = __lg(r - l + 1);
    return min(f[j][l], f[j][r - (1 << j) + 1]);
}
} // namespace ST

namespace SMT {
// Persistent segment tree over SA ranks used to optimize graph construction: every tree
// node is also a graph node, and edges point from a node to what it covers.
int lc[S], rc[S];

int root, tot;

// Resets the graph nodes 1..tot from the previous test case and the tree itself.
inline void clear() {
    G.clear(tot), root = 0;

    for (; tot; --tot)
        lc[tot] = rc[tot] = 0;
}

// Persistently inserts A-node k at leaf position p (a rank) and returns the new root.
// The new node y gets an edge to the node it replaces (x), so B-nodes attached to older
// versions still see it, and an edge to its rebuilt child; the new leaf points to k.
int insert(int x, int nl, int nr, int p, int k) {
    int y = ++tot;
    lc[y] = lc[x], rc[y] = rc[x];

    if (x)
        G.insert(y, x);

    if (nl == nr) {
        G.insert(y, k);
        return y;
    }

    int mid = (nl + nr) >> 1;

    if (p <= mid)
        G.insert(y, lc[y] = insert(lc[x], nl, mid, p, k));
    else
        G.insert(y, rc[y] = insert(rc[x], mid + 1, nr, p, k));

    return y;
}

// Adds edges from B-node k to every maximal non-empty node of version x inside [l, r].
void update(int x, int nl, int nr, int l, int r, int k) {
    if (!x)
        return;

    if (l <= nl && nr <= r) {
        G.insert(k, x);
        return;
    }

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(lc[x], nl, mid, l, r, k);

    if (r > mid)
        update(rc[x], mid + 1, nr, l, r, k);
}
} // namespace SMT

// Longest weighted path over graph nodes 1..SMT::tot, where only A-nodes (ids 1..na)
// carry weight (their length). Returns -1 if the graph contains a cycle.
inline ll TopoSort() {
    const int total = SMT::tot;
    memset(f + 1, 0, sizeof(ll) * total);

    queue<int> q;

    for (int v = 1; v <= total; ++v)
        if (G.indeg[v] == 0)
            q.push(v);

    ll best = 0;
    int visited = 0;

    while (!q.empty()) {
        const int u = q.front();
        q.pop();
        ++visited;

        if (u <= na)
            f[u] += a[u].r - a[u].l + 1;

        best = max(best, f[u]);

        for (int v : G.e[u]) {
            f[v] = max(f[v], f[u]);

            if (--G.indeg[v] == 0)
                q.push(v);
        }
    }

    // A node that was never dequeued lies on a cycle: the answer is unbounded.
    return visited == total ? best : -1;
}

signed main() {
    int T;
    scanf("%d", &T);

    while (T--) {
        scanf("%s%d", str + 1, &na), n = strlen(str + 1);
        SA::prework(), ST::prework();
        // Sweep items keyed by the rank of their start: A-strings use graph node i,
        // B-strings use graph node na + i.
        vector<Node> nd;

        for (int i = 1; i <= na; ++i) {
            scanf("%d%d", &a[i].l, &a[i].r);
            nd.emplace_back(SA::rk[a[i].l], a[i].r - a[i].l + 1, 0, i);
        }

        scanf("%d", &nb);

        for (int i = 1; i <= nb; ++i) {
            scanf("%d%d", &b[i].l, &b[i].r);
            nd.emplace_back(SA::rk[b[i].l], b[i].r - b[i].l + 1, 1, na + i);
        }

        // Nodes 1..na + nb are the strings; tree nodes are numbered after them.
        sort(nd.begin(), nd.end()), SMT::clear(), SMT::tot = na + nb;

        // Longest first (A before B on ties): when a B-string is processed, every A-string
        // at least as long is already in the tree.
        for (Node it : nd) {
            if (it.op) {
                // Binary lifting for the maximal rank range [l, r] around it.x whose
                // suffixes share at least it.len characters with this B-string.
                int l = it.x, r = it.x;

                for (int j = LOGN - 1; ~j; --j) {
                    if (l - (1 << j) >= 1 && ST::query(l - (1 << j) + 1, it.x) >= it.len)
                        l -= 1 << j;

                    if (r + (1 << j) <= n && ST::query(it.x + 1, r + (1 << j)) >= it.len)
                        r += 1 << j;
                }

                SMT::update(SMT::root, 1, n, l, r, it.id);
            } else
                SMT::root = SMT::insert(SMT::root, 1, n, it.x, it.id);
        }

        scanf("%d", &m);

        // Dominance edges: A-node u -> B-node na + v.
        for (int i = 1; i <= m; ++i) {
            int u, v;
            scanf("%d%d", &u, &v);
            G.insert(u, na + v);
        }

        printf("%lld\n", TopoSort());
    }

    return 0;
}

相似子串

定义字符串 \(A\) 与字符串 \(B\) 相似当且仅当 \([A_i = A_j] \iff [B_i = B_j]\)

给定长度为 \(n\) 的数字串 \(S\)\(q\) 次询问与某个子串相似的子串数量,强制在线。

\(n \le 10^5\)\(q \le 5 \times 10^5\)

考虑对于一个子串,每个位置的权值定义为与上一个相同字符的下标差,若在该子串中首次出现则为 \(0\) ,则两个字符串相似当且仅当它们每个位置的权值相同。

首先对原权值(相对于整个串计算)构建出 \(ht\) 数组。然后注意到截取一段子串时,最多只会有 \(10\) 个位置(每种数字在子串中首次出现的位置)的权值变为 \(0\),因此求两个子串的 LCP 时按这些位置分段比较即可。

#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 1e5 + 7, LOGN = 17, S = 10;

// Sparse table for range-minimum queries; the base level is filled through operator[].
struct ST {
    int f[LOGN][N];

    // Reference to the base-level entry f[0][x].
    inline int &operator [] (const int &x) {
        return f[0][x];
    }

    // Builds the higher levels from f[0][1..n].
    inline void prework(int n) {
        for (int j = 1, len = 2; j <= __lg(n); ++j, len <<= 1)
            for (int i = 1; i + len - 1 <= n; ++i)
                f[j][i] = min(f[j - 1][i], f[j - 1][i + len / 2]);
    }

    // Minimum over [l, r]; an empty range (l > r) yields inf.
    inline int query(int l, int r) {
        if (r < l)
            return inf;

        const int j = __lg(r - l + 1);
        return min(f[j][l], f[j][r - (1 << j) + 1]);
    }
} st;

vector<int> place[N];

int nxt[N][S], a[N], id[N], rid[N];
char str[N];

int n, q;

namespace SA {
int sa[N], rk[N], cnt[N], id[N], oldrk[N], ht[N];

ST st;

// Builds the suffix array of the integer sequence a[1..n], its height array ht,
// and the sparse table st over ht used by LCP().
inline void prework() {
    // Ranks must be >= 1: oldrk[0] (in front of sa[1]) and oldrk[n + 1] (an empty second
    // half) are 0 and act as sentinels. a[i] itself can be 0, which made real suffixes
    // compare equal to the sentinels (e.g. "12" gave both suffixes rank 0 and then read
    // sa[-1]). Shift the keys by one; since a[i] <= n - 1, they lie in [1, n].
    int m = n;
    memset(cnt, 0, sizeof(int) * (m + 1));

    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = a[i] + 1];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // id = suffixes ordered by second key; an empty second half (i + k > n) sorts first.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort of id by the first key.
        memset(cnt, 0, sizeof(int) * (m + 1));

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), tot = 0;

        // Re-rank by (first, second) key.
        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++tot;

            rk[sa[i]] = tot;
        }

        // All ranks distinct: the order is final.
        if (tot == n)
            break;

        m = tot; // next key range = number of distinct ranks
    }

    // Height array: h[i] = ht[rk[i]] >= h[i - 1] - 1, so k drops by at most one per step.
    for (int i = 1, k = 0; i <= n; ++i) {
        if (k)
            --k;

        // The smallest suffix has no predecessor (sa[0] = 0). Comparing it with a[0..]
        // could match because a[0] = 0 is also a valid value, and the bogus k would
        // carry into the next suffix. Its height is defined as 0.
        if (rk[i] == 1) {
            ht[1] = k = 0;
            continue;
        }

        int j = sa[rk[i] - 1];

        while (i + k <= n && j + k <= n && a[i + k] == a[j + k])
            ++k;

        ht[rk[i]] = k;
    }

    for (int i = 1; i <= n; ++i)
        st[i] = ht[i];

    st.prework(n);
}

// LCP of the suffixes of a[] starting at x and y.
inline int LCP(int x, int y) {
    if (x == y)
        return n - x + 1;

    x = rk[x], y = rk[y];

    if (x > y)
        swap(x, y);

    return st.query(x + 1, y);
}
} // namespace SA

// LCP of the "local weight" sequences of the substrings starting at x and y, where a
// position whose digit first appears in the substring has weight 0 instead of a[pos].
// place[x] lists those positions in order, so the suffixes are compared segment by
// segment between consecutive special positions, using the global LCP inside segments.
inline int querylcp(int x, int y) {
    int i = 0, curx = x, cury = y, len = 0;
    
    while (curx <= n && cury <= n) {
        // Next special position of each side (n + 1 if none), and the length of the
        // common stretch before either side reaches one.
        int nxtx = (i < place[x].size() ? place[x][i] : n + 1),
            nxty = (i < place[y].size() ? place[y][i] : n + 1),
            nlen = min(nxtx - curx, nxty - cury), lcp = SA::LCP(curx, cury);
        
        // Mismatch inside the stretch, where local weights equal the global a[].
        if (lcp < nlen)
            return len + lcp;
        
        len += nlen;
        
        // Only one side is at a special position (weight 0 vs nonzero), or the end is reached.
        if (nxtx != curx + nlen || nxty != cury + nlen || nxtx > n || nxty > n)
            return len;
        
        // Both sides are at their i-th special position: both weights are 0 and match.
        ++len, curx = nxtx + 1, cury = nxty + 1, ++i;
    }
    
    return len;
}

// Number of substring starts similar to the length-len substring starting at x, i.e. the
// size of the maximal block of the id[] order around x whose LCP with x is >= len.
inline int query(int x, int len) {
    const int pos = rid[x];
    int lo = pos, hi = pos;

    // Binary lifting: grow [lo, hi] while the range minimum of st stays >= len.
    for (int step = 1 << (LOGN - 1); step > 0; step >>= 1) {
        if (lo > step && st.query(lo - step + 1, pos) >= len)
            lo -= step;

        if (hi + step <= n && st.query(pos + 1, hi + step) >= len)
            hi += step;
    }

    return hi - lo + 1;
}

inline bool cmp(const int &x, const int &y) {
    int len = querylcp(x, y),
        cmpx = (x + len <= n ? a[x + len] : -1),
        cmpy = (y + len <= n ? a[y + len] : -1);
    
    if (cmpx == -1 || cmpy == -1)
        return cmpx < cmpy;
    
    for (int it : place[x])
        if (it == x + len) {
            cmpx = 0;
            break;
        }
    
    for (int it : place[y])
        if (it == y + len) {
            cmpy = 0;
            break;
        }
    
    return cmpx < cmpy;
}

// Reads n, q and the string, then builds:
//   a[i]     : distance to the previous equal character class (0 if none),
//   place[i] : sorted positions of the first occurrence of each class at or
//              after i.
// All suffixes are sorted by their transformed form (cmp), and q queries are
// answered. Each query is decoded by XOR with the previous answer.
signed main() {
    scanf("%d%d%s", &n, &q, str + 1);
    // lst[c]: last position seen of class c = str[i] & 15.
    // NOTE(review): 11 slots assumes str[i] & 15 <= 10 (e.g. digits) — confirm.
    vector<int> lst(11);
    
    for (int i = 1; i <= n; ++i) {
        int idx = str[i] & 15;
        a[i] = lst[idx] ? i - lst[idx] : 0, lst[idx] = i;
    }

    for (int i = n; i; --i) {
        // nxt[i][c]: first position >= i holding class c (0 means none).
        memcpy(nxt[i], nxt[i + 1], sizeof(int) * S), nxt[i][str[i] & 15] = i;
        
        for (int j = 0; j < S; ++j)
            if (nxt[i][j])
                place[i].emplace_back(nxt[i][j]);
        
        sort(place[i].begin(), place[i].end());
    }
    
    // Plain suffix array over a[]; SA::LCP is used inside querylcp.
    SA::prework();
    // id: positions sorted by transformed suffix (cmp); rid: inverse permutation.
    iota(id + 1, id + 1 + n, 1), stable_sort(id + 1, id + 1 + n, cmp);
    
    for (int i = 1; i <= n; ++i)
        rid[id[i]] = i;

    // st.f[0][i]: transformed LCP of the rank-adjacent suffixes id[i - 1], id[i].
    for (int i = 2; i <= n; ++i)
        st.f[0][i] = querylcp(id[i], id[i - 1]);

    st.prework(n);
    int lstans = 0;
    
    while (q--) {
        int l, r;
        scanf("%d%d", &l, &r);
        // Online queries: both endpoints are XOR-ed with the previous answer.
        l ^= lstans, r ^= lstans;
        printf("%d\n", lstans = query(l, r - l + 1));
    }
    
    return 0;
}

P4070 [SDOI2016] 生成魔咒

给定 \(a_{1 \sim n}\) ,求每个前缀本质不同子串数量。

\(n \le 10^5\)

若直接求解,每次在末尾加数字时,ht 数组的变化是难以维护的。

考虑将字符串反转,则每次相当于在前面加一个数字,相当于加入一个前缀,对 \(ht\) 数组的影响是 \(O(1)\) 的。

使用 set 维护前驱后继,时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7, LOGN = 17;

int a[N];

int n;

// Suffix array of a[1..n] (values already compressed to 1..m), built by prefix
// doubling with counting sort, plus the height array and a sparse table on it.
namespace SA {
// sa[i]: start of the i-th smallest suffix; rk[i]: rank of the suffix at i;
// ht[i]: LCP of suffixes sa[i - 1] and sa[i] (sa[0] is 0, a[0] is 0).
int sa[N], rk[N], cnt[N], id[N], oldrk[N << 1], ht[N];

// Sparse table for range-minimum queries over ht[1..n].
namespace ST {
int f[LOGN][N];

// Builds the table from ht; must run after ht is filled.
inline void prework() {
    memcpy(f[0] + 1, ht + 1, sizeof(int) * n);

    for (int j = 1; j <= __lg(n); ++j)
        for (int i = 1; i + (1 << j) - 1 <= n; ++i)
            f[j][i] = min(f[j - 1][i], f[j - 1][i + (1 << (j - 1))]);
}

// Minimum of ht[l..r] (requires l <= r), i.e. the LCP of suffixes sa[l - 1] and sa[r].
inline int query(int l, int r) {
    int k = __lg(r - l + 1);
    return min(f[k][l], f[k][r - (1 << k) + 1]);
}
} // namespace ST

// Builds sa, rk, ht and the sparse table for a[1..n] with values in [1, m].
inline void prework(int n, int m) {
    memset(cnt, 0, sizeof(int) * (m + 1));

    // Initial ranks are the values themselves; counting-sort by them.
    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = a[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // Order by the second key (rank of i + k): suffixes with i + k > n come first,
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        // then the rest in the current sa order.
        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort by the first key rk[].
        memset(cnt, 0, sizeof(int) * (m + 1));

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        // Re-rank: two adjacent suffixes tie iff both halves' old ranks match.
        // oldrk past n is only ever 0 here (globals are zero-initialised and only
        // 1..n is written).
        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), m = 0;

        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++m;

            rk[sa[i]] = m;
        }

        // All ranks distinct: the order is final.
        if (m == n)
            break;
    }

    // Kasai: the LCP with the rank predecessor drops by at most 1 from i to i + 1.
    for (int i = 1, k = 0; i <= n; ++i) {
        if (k)
            --k;

        int j = sa[rk[i] - 1];

        while (i + k <= n && j + k <= n && a[i + k] == a[j + k])
            ++k;

        ht[rk[i]] = k;
    }

    ST::prework();
}
} // namespace SA

signed main() {
    scanf("%d", &n);
    vector<int> vec;

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i), vec.emplace_back(a[i]);

    sort(vec.begin(), vec.end()), vec.erase(unique(vec.begin(), vec.end()), vec.end());

    for (int i = 1; i <= n; ++i)
        a[i] = lower_bound(vec.begin(), vec.end(), a[i]) - vec.begin() + 1;

    reverse(a + 1, a + n + 1), SA::prework(n, vec.size());
    set<int> st;
    ll ans = 0;

    for (int i = n; i; --i) {
        auto it = st.emplace(SA::rk[i]).first;
        int mxlcp = 0;

        if (it != st.begin())
            mxlcp = max(mxlcp, SA::ST::query(*prev(it) + 1, SA::rk[i]));
        
        if (next(it) != st.end())
            mxlcp = max(mxlcp, SA::ST::query(SA::rk[i] + 1, *next(it)));

        printf("%lld\n", ans += n - i + 1 - mxlcp);
    }

    return 0;
}

P1117 [NOI2016] 优秀的拆分

给出字符串 \(S\) ,求有多少形如 \(AABB\) 的子串,其中 \(A, B\) 为字符串。

\(n \le 3 \times 10^4\)

\(f_i\) 表示以 \(i\) 结尾的 \(AA\) 串数量,\(g_i\) 表示以 \(i\) 开头的 \(BB\) 串数量,则答案为 \(\sum f_i \times g_{i + 1}\)

考虑枚举 \(A\) 串的长度 \(len\) ,统计所有长度为 \(len\)\(A\) 串贡献。称下标为 \(len\) 倍数的位置为关键点,则任意一个 \(AA\) 串都覆盖相邻关键点,于是可以枚举相邻的关键点统计。

对于相邻的关键点 \(i, i + len\) ,若被一个 \(AA\) 串覆盖,则这个 \(AA\) 串去掉中间 \(i \sim i + len\) 的部分后必须是 \(A\) 串。

设这个 \(AA\) 串的分界点为 \(k\) ,那么 \(i \sim k\) 应该与 \(i + len\) 开始的一段后缀相等,\(k + 1 \sim i + len\) 应该与 \(i\) 结尾的一段前缀相等。

因此若两个后缀 \(i, i + len\) 的最长公共前缀 \(l_1\) 与两个前缀 \(i - 1, i + len - 1\) 的最长公共后缀 \(l_2\) 的和 \(\ge len\) ,则存在 \(l_1 + l_2 - len + 1\) 个 \(AA\) 串覆盖 \(i, i + len\) 。注意 \(l_1\) 要对 \(len\) 取 \(\min\) ,\(l_2\) 要对 \(len - 1\) 取 \(\min\) 。

不难发现这是一个区间加一的形式,使用差分即可。求两个后缀的最长公共前缀和求两个前缀的最长公共后缀可以对正反串各建出 SA 解决。

时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 3e4 + 7, LOGN = 15;

int f[N], g[N];
char str[N];

int n;

// Suffix array of the global str[1..n] with height array and a sparse table,
// answering LCP queries between two suffixes. Two instances are kept: A for
// S and B for the reversed S.
struct SA {
    // f: sparse table over ht. sa/rk/ht: suffix array, ranks, heights.
    int f[LOGN][N];
    int sa[N], rk[N], cnt[N], id[N], oldrk[N << 1], ht[N];

    // Builds everything for the current str[1..n]; safe to call once per test case.
    inline void prework() {
        // Clear oldrk[n + 1..2n], which may hold ranks from a previous test case.
        memset(oldrk + n + 1, 0, sizeof(int) * n);
        int m = 1 << 8;
        memset(cnt, 0, sizeof(int) * (m + 1));

        // Initial ranks are the character codes; counting-sort by them.
        for (int i = 1; i <= n; ++i)
            ++cnt[rk[i] = str[i]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[i]]--] = i;

        for (int k = 1; k < n; k <<= 1) {
            int tot = 0;

            // Second key: suffixes with i + k > n first, then by the current sa order.
            for (int i = n; i > n - k; --i)
                id[++tot] = i;

            for (int i = 1; i <= n; ++i)
                if (sa[i] > k)
                    id[++tot] = sa[i] - k;

            // Stable counting sort by the first key rk[].
            memset(cnt, 0, sizeof(int) * (m + 1));

            for (int i = 1; i <= n; ++i)
                ++cnt[rk[id[i]]];

            for (int i = 1; i <= m; ++i)
                cnt[i] += cnt[i - 1];

            for (int i = n; i; --i)
                sa[cnt[rk[id[i]]]--] = id[i];

            // Re-rank by the (first half, second half) pair of old ranks.
            memcpy(oldrk + 1, rk + 1, sizeof(int) * n), m = 0;

            for (int i = 1; i <= n; ++i) {
                if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                    ++m;

                rk[sa[i]] = m;
            }

            // All ranks distinct: done.
            if (m == n)
                break;
        }

        // Kasai: ht[rk[i]] = LCP(suffix i, its rank predecessor).
        for (int i = 1, k = 0; i <= n; ++i) {
            if (k)
                --k;

            int j = sa[rk[i] - 1];

            while (i + k <= n && j + k <= n && str[i + k] == str[j + k])
                ++k;

            ht[rk[i]] = k;
        }

        // Sparse table for range-min over ht.
        memcpy(f[0] + 1, ht + 1, sizeof(int) * n);

        for (int j = 1; j <= __lg(n); ++j)
            for (int i = 1; i + (1 << j) - 1 <= n; ++i)
                f[j][i] = min(f[j - 1][i], f[j - 1][i + (1 << (j - 1))]);
    }

    // LCP of the suffixes starting at x and y. Requires x != y; otherwise the
    // rank range (l, r] below is empty.
    inline int query(int x, int y) {
        int l = rk[x], r = rk[y];

        if (l > r)
            swap(l, r);

        ++l;
        int k = __lg(r - l + 1);
        return min(f[k][l], f[k][r - (1 << k) + 1]);
    }
} A, B;

// For each test case, counts the AABB splittings over all substrings. f[i] is
// the number of AA ending at i and g[i] the number starting at i; the answer
// is the sum of f[i] * g[i + 1].
signed main() {
    int T;
    scanf("%d", &T);

    while (T--) {
        scanf("%s", str + 1), n = strlen(str + 1);
        // A indexes S. B indexes reversed S: prefix p of S is suffix n - p + 1 of B.
        A.prework(), reverse(str + 1, str + n + 1), B.prework();
        // Reset the difference arrays for this test case.
        memset(f + 1, 0, sizeof(int) * n), memset(g + 1, 0, sizeof(int) * n);

        // Enumerate |A| = len and adjacent key points i, i + len (multiples of len).
        for (int len = 1; len <= n / 2; ++len)
            for (int i = len; i + len <= n; i += len) {
                // l1: LCP of suffixes i and i + len, capped at len.
                // l2: longest common suffix of prefixes i - 1 and i + len - 1, capped at len - 1.
                int l1 = min(A.query(i, i + len), len), 
                    l2 = min(i == 1 ? 0 : B.query(n - (i - 1) + 1, n - (i + len - 1) + 1), len - 1);

                if (l1 + l2 >= len) {
                    // t squares of half-length len cover both key points.
                    int t = l1 + l2 - len + 1;
                    // Their end positions: [i + 2 * len - 1 - l2, i + len + l1 - 1].
                    ++f[i + len + l1 - t], --f[i + len + l1];
                    // Their start positions: [i - l2, i - l2 + t - 1].
                    ++g[i - l2], --g[i - l2 + t];
                }
            }

        // Difference arrays -> counts.
        for (int i = 1; i <= n; ++i)
            f[i] += f[i - 1], g[i] += g[i - 1];

        ll ans = 0;

        // Split point between AA ending at i and BB starting at i + 1.
        for (int i = 1; i < n; ++i)
            ans += f[i] * g[i + 1];

        printf("%lld\n", ans);
    }

    return 0;
}

P5115 Check,Check,Check one two!

给定一个长为 \(n\) 的字符串 \(S\) ,定义:

  • \(lcp(i, j)\)\(S[i, n]\)\(S[j, n]\) 的最长公共前缀。
  • \(lcs(i, j)\)\(S[1, i]\)\(S[1, j]\) 的最长公共后缀。

求:

\[\left( \sum_{1 \le i < j \le n} [lcp(i, j) \le k_1] [lcs(i, j) \le k_2] \times lcp(i, j) \times lcs(i, j) \right) \bmod 2^{64} \]

\(n \le 10^5\)

发现 \(lcs(i, j)\)\(lcp(i, j)\) 拼起来就是一个长为 \(lcs(i, j) + lcp(i, j) - 1\) 的极长公共子串,考虑在 \((i - lcs(i, j) + 1, j - lcs(i, j) + 1)\) 处统计贡献。

枚举 \(i, j\) ,若 \(S_{i - 1} \ne S_{j - 1}\) ,则 \(S[i, i + lcp(i, j) - 1]\) 产生贡献。进一步可以发现贡献仅与 \(L = lcp(i, j)\) 有关,记:

\[\begin{align} f(L) &= \sum_{i = 1}^L i \times (L - i + 1) [i \le k_1] [L - i + 1 \le k_2] \\ &= \sum_{i = \max(1, L - k_2 + 1)}^{\min(L, k_1)} \left( (L + 1) \times i - i^2 \right) \end{align} \]

不难 \(O(n)\) 预处理 \(f(1 \sim n)\)

但是 \(S_{i - 1} \ne S_{j - 1}\) 的位置还是很多,补集转化为无限制的答案减去 \(S_{i - 1} = S_{j - 1}\) 的答案即可。剩下的根据 \(lcp\) 即为区间 \(ht\) 最小值的性质,对 \(ht\) 数组做扫描线,维护单调栈即可。

时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef unsigned long long ull;
using namespace std;
const int N = 1e5 + 7, S = 26;

ull f[N];
int sta[N], g[N];
char str[N];

int n, k1, k2;

// Triangular number 1 + 2 + ... + n, computed in unsigned 64-bit arithmetic.
inline ull S1(int n) {
    const ull m = n;
    return m * (m + 1) / 2;
}

// Sum of squares 1^2 + 2^2 + ... + n^2 for n >= 0, as an unsigned 64-bit value
// (the same type as ull).
// The divisions by 2 and 3 are applied to the factors *before* multiplying.
// The exact triple product can exceed 2^64 (n > ~2.6e6); dividing a wrapped
// product by 6 would give a wrong answer. Doing the factor divisions first keeps
// the result exact whenever the sum itself fits, and correct mod 2^64 otherwise.
// Every factor is held in 64 bits, so 2n + 1 cannot overflow int.
inline unsigned long long S2(int n) {
    unsigned long long a = n, b = a + 1, c = a * 2 + 1;

    // Exactly one of n, n + 1 is even.
    if (a % 2 == 0)
        a /= 2;
    else
        b /= 2;

    // Exactly one of n, n + 1, 2n + 1 is a multiple of 3. Halving keeps
    // divisibility by 3, since gcd(2, 3) = 1.
    if (a % 3 == 0)
        a /= 3;
    else if (b % 3 == 0)
        b /= 3;
    else
        c /= 3;

    return a * b * c;
}

// Suffix array and height array of the global str[1..n] (prefix doubling with
// counting sort, then Kasai).
namespace SA {
// sa[i]: start of the i-th smallest suffix; rk[i]: rank of the suffix at i;
// ht[i]: LCP of suffixes sa[i - 1] and sa[i].
int sa[N], rk[N], cnt[N], id[N], oldrk[N << 1], ht[N];

inline void prework() {
    // oldrk[n + 1..2n] is read as "empty second half"; make sure it is 0.
    memset(oldrk + n + 1, 0, sizeof(int) * n);
    int m = 1 << 8;
    memset(cnt + 1, 0, sizeof(int) * m);

    // Initial ranks are the character codes.
    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = str[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // Second key order: suffixes with i + k > n first, then by current sa.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort by the first key.
        memset(cnt + 1, 0, sizeof(int) * m);

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        // Re-rank by the pair of old ranks.
        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), m = 0;

        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++m;

            rk[sa[i]] = m;
        }

        // All ranks distinct: done.
        if (m == n)
            break;
    }

    // Kasai: ht[rk[i]] = LCP(suffix i, its rank predecessor).
    for (int i = 1, k = 0; i <= n; ++i) {
        if (k)
            --k;

        int j = sa[rk[i] - 1];

        while (i + k <= n && j + k <= n && str[i + k] == str[j + k])
            ++k;

        ht[rk[i]] = k;
    }
}
} // namespace SA

// Sums f[lcp(sa[p], sa[q])] over rank pairs p < q. Here lcp is the minimum of
// ht over (p, q]. With c == 0 every pair counts. Otherwise only pairs where the
// character right before both suffixes equals c count (str[0] is never written,
// so a suffix starting at 1 never matches).
// A monotonic stack over ht maintains, for the current rank i,
// res = sum over eligible p < i of f[min ht[p + 1..i]].
inline ull calc(char c) {
    ull ans = 0, res = 0;

    for (int i = 2, top = 0; i <= n; ++i) {
        // 1 if suffix sa[i - 1] is eligible as the left element of a pair.
        int now = c ? str[SA::sa[i - 1] - 1] == c : 1;

        // Entries with ht >= ht[i] now have ht[i] as their minimum: merge them.
        while (top && SA::ht[sta[top]] >= SA::ht[i])
            res -= f[SA::ht[sta[top]]] * g[top], now += g[top--];

        // g[top]: number of eligible left suffixes whose current minimum is ht[i].
        sta[++top] = i, g[top] = now, res += f[SA::ht[i]] * now;

        // Count the pairs ending at rank i if suffix sa[i] is eligible.
        if (!c || str[SA::sa[i] - 1] == c)
            ans += res;
    }

    return ans;
}

signed main() {
    scanf("%s%d%d", str + 1, &k1, &k2), n = strlen(str + 1);

    for (int i = 1; i <= n; ++i) {
        int l = max(1, i - k2 + 1), r = min(i, k1);

        if (l <= r)
            f[i] = (S1(r) - S1(l - 1)) * (i + 1) - (S2(r) - S2(l - 1));
    }

    SA::prework();
    ull ans = calc(0);

    for (int i = 0; i < S; ++i)
        ans -= calc('a' + i);

    printf("%llu", ans);
    return 0;
}

CF653F Paper task

给出一个长度为 \(n\) 的括号串,求本质不同的合法括号串数量。

\(n \le 5 \times 10^5\)

先不考虑本质不同的限制,将左右括号记为 \(1, -1\) ,记 \(s_i\) 为前缀和,则 \(S[l, r]\) 合法当且仅当 \([s_r = s_{l - 1}] \land [s_{l - 1} \le \min_{i = l}^r s_i]\) 。

对于一个 \(l\) ,可以二分出满足 \(s_{l - 1} \le \min_{i = l}^r s_i\) 的最大右端点,则合法的右端点只要对 \(s\) 的每个值记录位置然后二分查找即可。

接下来考虑本质不同的限制:后缀排序后,对每对排名相邻的后缀,只要减去它们的公共前缀(长度为 \(ht\))中合法括号串的数量即可,求法是类似的。

时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 5e5 + 7, LOGN = 19;

vector<int> vec[N << 1];

int s[N];
char str[N];

int n;

// Suffix array and height array of the bracket string str[1..n] (prefix
// doubling with counting sort, then Kasai).
namespace SA {
// sa[i]: start of the i-th smallest suffix; rk[i]: rank of the suffix at i;
// ht[i]: LCP of suffixes sa[i - 1] and sa[i].
int sa[N], rk[N], cnt[N], id[N], oldrk[N << 1], ht[N];

inline void prework() {
    // oldrk[n + 1..2n] is read as "empty second half"; make sure it is 0.
    memset(oldrk + n + 1, 0, sizeof(int) * n);
    int m = 1 << 8;
    memset(cnt + 1, 0, sizeof(int) * m);

    // Initial ranks are the character codes.
    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = str[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // Second key order: suffixes with i + k > n first, then by current sa.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort by the first key.
        memset(cnt + 1, 0, sizeof(int) * m);

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        // Re-rank by the pair of old ranks.
        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), m = 0;

        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++m;

            rk[sa[i]] = m;
        }

        // All ranks distinct: done.
        if (m == n)
            break;
    }

    // Kasai: ht[rk[i]] = LCP(suffix i, its rank predecessor).
    for (int i = 1, k = 0; i <= n; ++i) {
        if (k)
            --k;

        int j = sa[rk[i] - 1];

        while (i + k <= n && j + k <= n && str[i + k] == str[j + k])
            ++k;

        ht[rk[i]] = k;
    }
}
} // namespace SA

// Sparse table answering range-minimum queries over the prefix sums s[1..n].
namespace ST {
int f[LOGN][N];

// Builds the table: f[j][i] = min of s[i .. i + 2^j - 1].
inline void prework() {
    for (int i = 1; i <= n; ++i)
        f[0][i] = s[i];

    const int top = __lg(n);

    for (int j = 1; j <= top; ++j) {
        const int half = 1 << (j - 1);

        for (int i = 1; i + 2 * half - 1 <= n; ++i)
            f[j][i] = min(f[j - 1][i], f[j - 1][i + half]);
    }
}

// Minimum of s[l..r]; requires 1 <= l <= r <= n.
inline int query(int l, int r) {
    const int k = __lg(r - l + 1);
    return min(f[k][r - (1 << k) + 1], f[k][l]);
}
} // namespace ST

inline int query(int L, int R) {
    int l = L, r = R;
    R = L - 1;

    while (l <= r) {
        int mid = (l + r) >> 1;

        if (ST::query(L, mid) >= s[L - 1])
            R = mid, l = mid + 1;
        else
            r = mid - 1;
    }

    return upper_bound(vec[s[L - 1]].begin(), vec[s[L - 1]].end(), R) - 
        lower_bound(vec[s[L - 1]].begin(), vec[s[L - 1]].end(), L);
}

// Counts distinct balanced substrings: all balanced substrings, minus those that
// are also a prefix of the rank-previous suffix (a common prefix of length ht).
signed main() {
    scanf("%d%s", &n, str + 1);
    // Prefix sums of +1 / -1 start at n so they index vec[0..2n].
    s[0] = n;
    vec[s[0]].emplace_back(0);

    for (int i = 1; i <= n; ++i) {
        s[i] = s[i - 1] + (str[i] == '(' ? 1 : -1);
        vec[s[i]].emplace_back(i);
    }

    SA::prework();
    ST::prework();

    ll ans = 0;

    for (int i = 1; i <= n; ++i)
        ans += query(i, n);

    for (int r = 2; r <= n; ++r) {
        const int p = SA::sa[r];
        ans -= query(p, p + SA::ht[r] - 1);
    }

    printf("%lld", ans);
    return 0;
}

P9482 [NOI2023] 字符串

给定长度为 \(n\) 的字符串 \(S\)\(q\) 次询问,每次给出 \(i, r\) ,求有多少 \(l \in [1, r]\) 满足 \(S[i, i + l - 1] < S[i + l, i + 2l - 1]^R\)

\(n, q \le 10^5\)

子串的结构不好处理,考虑转化为前后缀结构。设 \(A_i, B_i\) 表示 \(i\) 向左右扩展到边界得到的字符串,将限制弱化为 \(A_i < B_{i + 2l - 1}\) ,则需要减去 \(S[i, i + 2l - 1]\) 为回文串且 \(A_{i + 2l} < B_{i - 1}\)\(l\) 的数量。

先考虑求 \(A_i < B_{i + 2l - 1}\) 的数量。对 \(S + x + S^R + y\) 建立 SA 结构,其中 \(x\)\(y\) 为小于所有字符的任意非字符集内的字符,需要保证 \(x > y\) 。离线询问后按排名从大到小扫描线,遇到 \(B\) 则标记对应位置,遇到 \(A\) 则回答询问,查询形如区间内被标记的奇数或偶数位置数量,开两棵 BIT 维护即可。

再考虑求 \(S[i, i + 2l - 1]\) 为回文串且 \(A_{i + 2l} < B_{i - 1}\)\(l\) 的数量。首先用 Manacher 求出每个 \((i, i + 1)\) 分隔处的回文半径 \(p_i\) ,则 \(S[i - p_i + 1, i + p_i]\) 为极长回文串。而 \(A_{i + 2l} < B_{i - 1}\) 这个限制对于某个固定中心的所有回文串都是等价的,即对于 \((i, i + 1)\) 为中心的回文串都等价于判定 \(A_{i + p_i + 1} < B_{i - p_i}\) 。而 \(S_{i + p_i + 1} \ne S_{i - p_i}\) ,因此判定条件即为 \(S_{i + p_i + 1} < S_{i - p_i}\) ,那么可以线性预处理出每个合法的回文中心。

考虑对于一个以 \((i_0, i_0 + 1)\) 为中心、半径为 \(r_0\) 的极长回文串,在什么时候会被 \((i, r)\) 的询问统计到,条件即为 \(i \in [i_0 - r_0 + 1, i_0]\) 且 \(i + r > i_0\) ,此时二维数点的形式就很明显了。

时间复杂度 \(O((n + q) \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 7;

// Fenwick tree of ints with point add and prefix-sum query on indices 1..n.
struct BIT {
    int c[N];

    int n;

    // Resets the tree to size _n with all values zero.
    inline void prework(int _n) {
        n = _n;
        memset(c + 1, 0, sizeof(int) * n);
    }

    // Adds k at position x.
    inline void update(int x, int k) {
        while (x <= n) {
            c[x] += k;
            x += x & -x;
        }
    }

    // Sum of positions 1..x.
    inline int query(int x) {
        int sum = 0;

        while (x) {
            sum += c[x];
            x -= x & -x;
        }

        return sum;
    }
};

pair<int, int> ask[N];

int ans[N];
char str[N];

int n, q;

// Counts, for every query (i, r), the l in [1, r] with A_i < B_{i + 2l - 1}.
// A_i is the suffix of S from i; B_j is the reversed prefix S[1..j].
// All are ranked in one suffix array over t = S + x + S^R + y, with x = 'a' - 1
// and y = 'a' - 2.
namespace Method1 {
// bit[p]: marked original positions of parity p.
BIT bit[2];

// qry[rank]: (lo, hi, query id) attached to the rank of the query's start position.
vector<tuple<int, int, int> > qry[N];

char t[N];

namespace SA {
// oldrk is read at sa[i] + k, up to about twice the text length, and prework
// clears oldrk[n + 1..2n]. It therefore needs twice the text's capacity. The
// text here has length 2 * |S| + 2 (up to ~2e5 + 2), so oldrk[N] with
// N = 2e5 + 7 would be overrun.
int sa[N], rk[N], cnt[N], id[N], oldrk[N << 1];

// Builds sa and rk for str[1..n] by prefix doubling with counting sort.
inline void prework(char *str, int n) {
    // oldrk[n + 1..2n] is read as "empty second half" and may hold ranks from
    // an earlier test case.
    memset(oldrk + n + 1, 0, sizeof(int) * n);
    int m = 1 << 8;
    memset(cnt + 1, 0, sizeof(int) * m);

    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = str[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i; --i)
        sa[cnt[rk[i]]--] = i;

    for (int k = 1; k < n; k <<= 1) {
        int tot = 0;

        // Second key order: suffixes with i + k > n first, then by current sa.
        for (int i = n; i > n - k; --i)
            id[++tot] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > k)
                id[++tot] = sa[i] - k;

        // Stable counting sort by the first key.
        memset(cnt + 1, 0, sizeof(int) * m);

        for (int i = 1; i <= n; ++i)
            ++cnt[rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i; --i)
            sa[cnt[rk[id[i]]]--] = id[i];

        // Re-rank by the pair of old ranks.
        memcpy(oldrk + 1, rk + 1, sizeof(int) * n), m = 0;

        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] != oldrk[sa[i - 1]] || oldrk[sa[i] + k] != oldrk[sa[i - 1] + k])
                ++m;

            rk[sa[i]] = m;
        }

        // All ranks distinct: done.
        if (m == n)
            break;
    }
}
} // namespace SA

inline void solve() {
    int m = 0;

    // t = S, 'a' - 1, reverse(S), 'a' - 2 (separators smaller than any letter, x > y).
    for (int i = 1; i <= n; ++i)
        t[++m] = str[i];

    t[++m] = 'a' - 1;

    for (int i = n; i; --i)
        t[++m] = str[i];

    t[++m] = 'a' - 2, SA::prework(t, m = n * 2 + 2);

    // Query (i, r) asks about positions i + 2l - 1 for l = 1..r, i.e. the
    // positions in [i + 1, i + 2r - 1] that have the parity of i + 1.
    for (int i = 1; i <= q; ++i)
        qry[SA::rk[ask[i].first]].emplace_back(ask[i].first + 1, ask[i].first + ask[i].second * 2 - 1, i);

    bit[0].prework(n), bit[1].prework(n);

    // Scan ranks from largest to smallest. A reversed-part suffix marks its
    // original position, and a forward suffix answers its queries against
    // everything ranked above it.
    for (int i = m; i; --i) {
        int p = SA::sa[i];

        if (p <= n) {
            for (auto it : qry[i])
                ans[get<2>(it)] += bit[get<0>(it) & 1].query(get<1>(it)) -
                    bit[get<0>(it) & 1].query(get<0>(it) - 1);

            qry[i].clear();
        } else if (p != n + 1 && p != m)
            // t position p (in the reversed part) corresponds to original index m - p.
            bit[(m - p) & 1].update(m - p, 1);
    }
}
} // namespace Method1

// Subtracts the l counted by Method1 where S[i, i + 2l - 1] is a palindrome.
// For those l the original condition is false (equality), but the weakened
// comparison may still have counted them.
namespace Method2 {
BIT bit;

// upd[start]: (position, delta) events; qry[start]: (bound, query id).
vector<pair<int, int> > upd[N], qry[N];

// p[i]: Manacher radius in t; t: S interleaved with '#'.
int p[N];
char t[N];

inline void solve() {
    int m = 0;

    // t = # s1 # s2 # ... # sn #, so str[i] is t[2i] and the gap (i, i + 1) is t[2i + 1].
    for (int i = 1; i <= n; ++i)
        t[++m] = '#', t[++m] = str[i];

    t[++m] = '#';

    // Manacher's algorithm.
    for (int i = 1, mid = 0, r = 0; i <= m; ++i) {
        p[i] = (i <= r ? min(p[mid * 2 - i], r - i + 1) : 1);

        while (i - p[i] && i + p[i] <= m && t[i - p[i]] == t[i + p[i]])
            ++p[i];

        if (i + p[i] - 1 > r)
            mid = i, r = i + p[i] - 1;
    }

    for (int i = 1; i < n; ++i) {
        // Half-length of the maximal even palindrome centred between i and i + 1.
        int r = p[i * 2 + 1] / 2;

        // Valid centre: the palindrome does not reach the left end, and the
        // right mismatch char is smaller than the left one. str[n + 1] is the
        // '\0' written by scanf, which is smaller than any letter. Such a
        // centre is hit by starts in [i - r + 1, i], encoded as +1 at
        // i - r + 1 and -1 at i + 1, both stored at BIT position i + 1.
        if (r < i && str[i + r + 1] < str[i - r])
            upd[i - r + 1].emplace_back(i + 1, 1), upd[i + 1].emplace_back(i + 1, -1);
    }

    // Query (i, r) needs centres c with c + 1 <= i + r.
    for (int i = 1; i <= q; ++i)
        qry[ask[i].first].emplace_back(ask[i].first + ask[i].second, i);

    bit.prework(n);

    // Sweep start positions; the BIT holds the centres active for this start.
    for (int i = 1; i <= n; ++i) {
        for (auto it : upd[i])
            bit.update(it.first, it.second);

        for (auto it : qry[i])
            ans[it.second] -= bit.query(it.first);

        upd[i].clear(), qry[i].clear();
    }
}
} // namespace Method2

// Reads the test id and T test cases. For each one it reads n, q, S and the
// queries (i, r). The answer is Method1's count minus Method2's palindrome
// correction.
signed main() {
    int testid, T;
    scanf("%d%d", &testid, &T);

    while (T--) {
        scanf("%d%d%s", &n, &q, str + 1);

        for (int i = 1; i <= q; ++i)
            scanf("%d%d", &ask[i].first, &ask[i].second);

        // Both methods accumulate into ans[], so reset it per test case.
        memset(ans + 1, 0, sizeof(int) * q);
        Method1::solve(), Method2::solve();

        for (int i = 1; i <= q; ++i)
            printf("%d\n", ans[i]);
    }

    return 0;
}
posted @ 2024-12-22 17:14  wshcl  阅读(23)  评论(0)    收藏  举报