洛谷P6088 [JSOI2015] 字符串树 题解 LCA+可持久化字典树

题目链接:https://www.luogu.com.cn/problem/P6088

解题思路:

每个节点维护根节点到它的路径中所有字符串对应的字典树。

使用可持久化字典树维护。

\((x, y)\) 路径上以 \(s\) 为前缀的字符串个数 =

结点 \(x\) 维护的字典树中前缀 \(s\) 的个数 \(+\) 结点 \(y\) 维护的字典树中前缀 \(s\) 的个数 \(- 2 \times\) 结点 \(lca(x,y)\) 维护的字典树中前缀 \(s\) 的个数。

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5;

struct Node {
    int son[26], cnt;
} tr[maxn*10];
int rt[maxn] = { 0, 1 }, idx = 1;

int cpy_node(int u) {
    tr[++idx] = tr[u];
    return idx;
}

int n, q, fa[maxn][17], dep[maxn];
struct Edge {
    int v;
    string s;
};
vector<Edge> g[maxn];

void ins(int pu, int u, string &s) {
    for (int i = 0; s[i]; i++) {
        for (int j = 0; j < 26; j++)
            tr[u].son[j] = tr[pu].son[j];
        int x = s[i] - 'a';
        tr[u].son[x] = cpy_node(tr[pu].son[x]);
        pu = tr[pu].son[x];
        u = tr[u].son[x];
        tr[u].cnt++;
    }
}

void dfs(int u, int p) {
    fa[u][0] = p;
    dep[u] = dep[p] + 1;
    for (auto e : g[u]) {
        int v = e.v;
        string s = e.s;
        if (v != p) {
            rt[v] = ++idx;
            ins(rt[u], rt[v], s);
            dfs(v, u);
        }
    }
}

int lca(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = 16; i >= 0; i--)
        if (dep[ fa[x][i] ] >= dep[y])
            x = fa[x][i];
    if (x == y) return x;
    for (int i = 16; i >= 0; i--)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

int query(int u, string &s) {
    for (int i = 0; s[i]; i++) {
        int x = s[i] - 'a';
        assert(0 <= x && x < 26);
        u = tr[u].son[x];
    }
    return tr[u].cnt;
}

int main() {
    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v;
        string s;
        cin >> u >> v >> s;
        g[u].push_back({v, s});
        g[v].push_back({u, s});
    }
    dfs(1, 0);
    for (int i = 1; i <= 16; i++)
        for (int u = 1; u <= n; u++)
            fa[u][i] = fa[ fa[u][i-1] ][i-1];
    cin >> q;
    while (q--) {
        int x, y, z;
        string s;
        cin >> x >> y >> s;
        z = lca(x, y);
        int ans = query(rt[x], s) + query(rt[y], s) - 2 * query(rt[z], s);
        printf("%d\n", ans);
    }
    return 0;
}
posted @ 2025-05-16 21:06  quanjun  阅读(13)  评论(0)    收藏  举报