P4482(树链剖分、SAM)

我来复读 @万弘 的题解。

约定下文的 \(\text{LCS}(i,j)\) 指前缀 \(i,j\) 的最长公共后缀。

考虑建出 \(\text{parent tree}\),那么对于单次询问 \([l,r]\),只需要查询位置 \(x\),使得 \(x\in[l,r),\text{LCS}(x,r)\geq x-l+1\),且 \(x\) 是满足条件中最大的即可。容易将 \(\text{LCS}\) 转化成树上的 \(\text{LCA}\)\(len\) 值,然后这就变成了一个树上点对问题。由于点分治做不了 \(\text{LCA}\),考虑链分治,那么将 \(\text{parent tree}\) 重链剖分,并把询问挂在进入每条重链的点上,设为 \(S\)。此时所有点被分成了三部分:

  • \(S\) 子树中的点。
  • \(S\) 同一重链并在 \(S\) 上方的点。
  • 第二类型的点的轻子树中的点。

对于第一类点,我们线段树合并维护 edp,每次线段树上二分查询最大的 \(x\) 使 \(x\leq \text{LCS}(x,r)+l-1\)

对于第二类和第三类点,我们在重链上从浅到深扫描维护这些点的信息。具体的,在线段树上叶子 \(x\) 处维护 \(\text{LCS}(x,r)-x\),其它点维护区间 \(\min\)。然后查询时在线段树上二分 \(x\)。由于轻子树的总大小为 \(O(n\log n)\),重链点数为 \(O(n)\),所以总复杂度为 \(O(n\log^2 n)\)

值得注意的是这种链分治处理点对问题的做法在其它问题上依然适用。

#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define db double
#define ldb long double
#define pb push_back
#define mp make_pair
#define pii pair<int, int>
using namespace std;
inline int read() {
    int x = 0; bool op = 0;
    char c = getchar();
    while(!isdigit(c))op |= (c == '-'), c = getchar();
    while(isdigit(c))x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
    return op ? -x : x;
}
const int N = 4e5 + 10;
const int INF = 1e9;
int n, Q;
char s[N];
int ans[N];
struct Endpos {
    int tot;
    struct Node {
        int ls, rs, sz;
    }nd[N * 25];
    int update(int cur, int l, int r, int x, int val) {
        int p = ++tot; nd[p] = nd[cur];
        if(l == r) {nd[p].sz += val; return p;}
        int mid = l + r >> 1;
        if(x <= mid)nd[p].ls = update(nd[cur].ls, l, mid, x, val);
        else nd[p].rs = update(nd[cur].rs, mid + 1, r, x, val);
        nd[p].sz = nd[nd[p].ls].sz + nd[nd[p].rs].sz;
        return p;
    }
    int merge(int x, int y, int l, int r) {
        if(x == 0 || y == 0)return x + y;
        int p = ++tot;
        if(l == r) {nd[p].sz = nd[x].sz + nd[y].sz; return p;}
        int mid = l + r >> 1;
        nd[p].ls = merge(nd[x].ls, nd[y].ls, l, mid);
        nd[p].rs = merge(nd[x].rs, nd[y].rs, mid + 1, r);
        nd[p].sz = nd[nd[p].ls].sz + nd[nd[p].rs].sz;
        return p;
    }
    int query(int cur, int l, int r, int bd) {
        if(bd <= 0 || cur == 0 || nd[cur].sz <= 0)return 0;
        if(l == r)return l;
        int mid = l + r >> 1, res = 0;
        if(bd <= mid)return query(nd[cur].ls, l, mid, bd);
        else {
            int res = 0;
            res = max(res, query(nd[cur].rs, mid + 1, r, bd));
            if(res == 0)
            res = max(res, query(nd[cur].ls, l, mid, bd));
            return res;
        }
    }
}edp;
struct Sgt_Tree {
    int mx[N << 2];
    void build(int k, int l, int r) {
        mx[k] = INF;
        if(l == r)return ;
        int mid = l + r >> 1;
        build(k << 1, l, mid); build(k << 1 | 1, mid + 1, r);
        return ;
    }
    void update(int k, int l, int r, int x, int val) {
        if(l == r)return mx[k] = val, void();
        int mid = l + r >> 1;
        if(x <= mid)update(k << 1, l, mid, x, val);
        else update(k << 1 | 1, mid + 1, r, x, val);
        mx[k] = min(mx[k << 1], mx[k << 1 | 1]);
        return ;
    }
    int query(int k, int l, int r, int x, int bd) {
        if(x <= 0 || mx[k] > bd)return 0;
        if(l == r)return l;
        int mid = l + r >> 1;
        if(x <= mid)return query(k << 1, l, mid, x, bd);
        else {
            int res = 0;
            res = max(res, query(k << 1 | 1, mid + 1, r, x, bd));
            if(res == 0)
            res = max(res, query(k << 1, l, mid, x, bd));
            return res;
        }
    }
}sgt;
struct SAM {
    int tot = 1, lst = 1, len[N], fa[N], son[N][26], id[N], ed[N];
    void extend(int c, int x) {
        int cur = ++tot, p = lst; lst = cur;
        len[cur] = len[p] + 1; id[x] = cur; ed[cur] = x;
        while(p && son[p][c] == 0)son[p][c] = cur, p = fa[p];
        if(p == 0)return fa[cur] = 1, void();
        int q = son[p][c];
        if(len[q] == len[p] + 1)return fa[cur] = q, void();
        int cl = ++tot; len[cl] = len[p] + 1; fa[cl] = fa[q];
        memcpy(son[cl], son[q], sizeof(son[q]));
        fa[cur] = fa[q] = cl;
        while(p && son[p][c] == q)son[p][c] = cl, p = fa[p];
        return ;
    }
    vector<int> nb[N];
    int rt[N], hson[N], sz[N], top[N];
    void dfs1(int u) {
        sz[u] = 1;
        for(int v : nb[u]) {
            dfs1(v);
            sz[u] += sz[v];
            if(hson[u] == 0 || sz[v] > sz[hson[u]])hson[u] = v;
            rt[u] = edp.merge(rt[u], rt[v], 1, n);
        }
        return ;
    }
    void dfs2(int u, int t) {
        top[u] = t;
        if(hson[u])dfs2(hson[u], t);
        for(int v : nb[u])if(v ^ hson[u])dfs2(v, v);
        return ;
    }
    void build() {
        for(int i = 2; i <= tot; i++)nb[fa[i]].pb(i);
        for(int i = 1; i <= tot; i++)if(ed[i])rt[i] = edp.update(rt[i], 1, n, ed[i], 1);
        dfs1(1); dfs2(1, 1);
        return ;
    }
    struct Qry {
        int id, l, r;
        Qry() {}
        Qry(int id, int l, int r):id(id), l(l), r(r) {}
    };
    vector<Qry> q[N];
    void push(int d, int l, int r) {
        int p = id[r]; 
        while(p)q[p].pb(Qry(d, l, r)), p = fa[top[p]];
        return ;
    }
    int hs;
    void dfslgt(int u, int l) {
        if(ed[u])sgt.update(1, 1, n, ed[u], ed[u] - l);
        for(int v : nb[u])if(v ^ hs)dfslgt(v, l);
        return ;
    }
    void solve(int u) {
        for(int p = u; p; p = hson[p]) {
            hs = hson[p]; dfslgt(p, len[p]); hs = 0;
            for(auto t : q[p]) {
                int d = t.id, res = edp.query(rt[p], 1, n, min(t.r - 1, len[p] + t.l - 1));
                if(res >= t.l)ans[d] = max(ans[d], res - t.l + 1);
                res = sgt.query(1, 1, n, t.r - 1, t.l - 1);
                if(res >= t.l)ans[d] = max(ans[d], res - t.l + 1);
            }
        }
        for(int p = u; p; p = hson[p]) {
            hs = hson[p]; dfslgt(p, -INF); hs = 0;
        }
        for(int p = u; p; p = hson[p]) {
            for(int v : nb[p])if(v ^ hson[p])solve(v);
        }
        return ;
    }
}sam;
int main() {
    scanf("%s", s + 1);
    n = strlen(s + 1); Q = read();
    for(int i = 1; i <= n; i++) {
        sam.extend(s[i] - 'a', i);
    }
    sam.build();
    for(int i = 1; i <= Q; i++) {
        int l = read(), r = read();
        sam.push(i, l, r);
    }
    sgt.build(1, 1, n); sam.solve(1);
    for(int i = 1; i <= Q; i++)printf("%d\n", ans[i]);
    return 0;
}
posted @ 2022-07-19 17:39  yllcm  阅读(71)  评论(0)    收藏  举报