【Luogu P4218】「CTSC2010」珠宝商
Description
给出一棵包含 \(n\) 个节点的树,每个节点上都有一个字符。除此之外,还会给出一个长度为 \(m\) 的文本串。
求树上所有路径,经过的所有字符按顺序连接形成的字符串,在文本串中的出现次数之和。
数据范围:\(1 \leq n, m \leq 5 \times 10^4\),给出的字符均为小写字母。
时空限制:\(8000 \ \mathrm{ms} / 500 \ \mathrm{MiB}\)。
Solution
运用点分治统计路径答案。对于当前大小为 \(\mathbf{size}\) 的分治块,有两种内核不同的算法。
算法一:\(\mathcal{O}(\mathbf{size}^2)\)
将原串的 SAM 建出。
直接计算当前分治块的所有路径对答案的贡献。
枚举当前分治块的每一个点,计算以当前点为起点的所有路径对答案的贡献。
可以从该起点开始搜索,维护起点到当前点所形成的字符串在 SAM 上的状态,则当前路径对答案的贡献即为该状态所对应的 \(\mathrm{endpos}\) 集合大小。
算法二:\(\mathcal{O}(\mathbf{size} + m)\)
将原串、反串的 SAM 建出。
考虑计算跨越当前分治重心 \(u\) 的所有路径对答案的贡献,其他情况递归求解。
对于一个跨越当前分治重心 \(u\) 的路径 \(x \to u \to y\),可以考虑将其拆成 \(x \to u\) 的前缀段与 \(u \to y\) 的后缀段。
考虑计算出两个数组 \(\mathrm{le}(i), \mathrm{rg}(i)\) 分别表示,在模式串中以位置 \(i\) 为结尾的前缀段个数、以位置 \(i\) 为开头的后缀段个数。则跨越当前分治重心的所有路径对答案的贡献即为 \(\sum_{i = 1}^m \mathrm{le}(i) \cdot \mathrm{rg}(i)\)。
但由于这样做可能会额外计算到来自同一颗子树的前缀串、后缀串组合的贡献。所以还需要对所有的子分治块进行相似的统计,减去错误的贡献即可。
可以从分治重心开始搜索,维护当前点到分治重心所形成的字符串在正串、反串 SAM 上的状态,则在正串 SAM 里的状态所对应 \(\mathrm{endpos}\) 集合里的 \(\mathrm{le}(i)\),反串 SAM 里的状态所对应 \(\mathrm{endpos}\) 集合里的 \(\mathrm{rg}(i)\) 都要加一。在对应状态上打上一个标记,最后自上而下传递一遍即可。
在维护当前点到分治重心所形成的字符串在 SAM 上的状态时,涉及到在一个状态的开头插入字符,SAM 也是可以做的。具体地,在加字符之前:
- 若当前串的长度等于当前状态的 \(\mathrm{maxl}\),则相当于要在 parent 树上向一个儿子节点走。在建 parent 树的时候,预处理每一个状态前加一个字符,能走到哪个儿子即可。
- 若当前串的长度小于当前状态的 \(\mathrm{maxl}\),则相当于要考虑当前状态是否可以容纳新串,只需判断新加的字符是否与原串对应位置上的字符匹配即可。
算法一 & 二
在下文的探讨中,视 \(n, m\) 同阶。
注意到算法一在 \(\mathbf{size}\) 较小时表现较好,算法二在 \(\mathbf{size}\) 较大时表现较好。不妨将两种算法相结合。
针对当前分治块的大小 \(\mathbf{size}\) 进行平衡规划。取阈值 \(B\),当 \(\mathbf{size} \leq B\) 时使用算法一,\(\mathbf{size} > B\) 时使用算法二。
注意到在点分治的过程中,只需要在点分树上遍历 \(\mathcal{O}(\frac{n}{B})\) 个节点,即可保证分治出的最大子树大小为 \(\mathcal{O}(B)\) 的,并且这些叶子节点的个数也为 \(\mathcal{O}(\frac{n}{B})\) 的。因此:
- 算法一开销:\(\mathcal{O}(B^2) \cdot \mathcal{O}(\frac{n}{B}) = \mathcal{O}(nB)\)。
- 算法二开销:\(\mathcal{O}(n) \cdot \mathcal{O}(\frac{n}{B}) + \mathcal{O}(n \log \frac{n}{B}) = \mathcal{O}(\frac{n^2}{B})\)。
取 \(B = \sqrt{n}\),可以取得理论最优复杂度 \(\mathcal{O}(n \sqrt{n})\)。取得最优运行效率,可能需要进行微调。
需要注意的是,对满足 \(\mathbf{size} > B\) 的分治块运用了算法二后,要对满足 \(\mathbf{size}' \leq B\) 的子分治块进行类算法一的容斥。否则会被菊花图卡掉。
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
using std::vector;
template <class T>
inline void read(T &x) {
static char s;
while (s = getchar(), s < '0' || s > '9');
x = s - '0';
while (s = getchar(), s >= '0' && s <= '9') x = x * 10 + s - '0';
}
const int N = 50100;
int n, m, Bn;
int tot, head[N], ver[N * 2], Next[N * 2];
void add_edge(int u, int v) {
ver[++ tot] = v; Next[tot] = head[u]; head[u] = tot;
}
char a[N], tex[N];
int A_sz, A_rt;
int sz[N], mp[N];
bool ban[N];
int temp_sz[N];
int Get_sz(int u, int fu) {
int cnt = 1;
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
if (v == fu || ban[v]) continue;
cnt += Get_sz(v, u);
}
return cnt;
}
void Get_rt(int u, int fu) {
sz[u] = 1, mp[u] = 0;
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
if (v == fu || ban[v]) continue;
Get_rt(v, u);
sz[u] += sz[v];
if (sz[v] > mp[u]) mp[u] = sz[v];
}
if (A_sz - sz[u] > mp[u]) mp[u] = A_sz - sz[u];
if (A_rt == 0 || mp[u] < mp[A_rt]) A_rt = u;
}
struct SAM {
static const int SIZE = N * 2;
int strL, s[N];
int cT = 1, Last = 1;
struct node {
int trans[26];
int link, maxl;
} t[SIZE];
int pos[SIZE], cnt[SIZE];
vector<int> go[SIZE];
int net[SIZE][26];
int extend(int c) {
int p = Last,
np = Last = ++ cT;
s[pos[np] = ++ strL] = c;
cnt[np] ++;
t[np].maxl = t[p].maxl + 1;
for (; p && t[p].trans[c] == 0; p = t[p].link) t[p].trans[c] = np;
if (!p) {
t[np].link = 1;
} else {
int q = t[p].trans[c];
if (t[q].maxl == t[p].maxl + 1) {
t[np].link = q;
} else {
int nq = ++ cT;
t[nq] = t[q], t[nq].maxl = t[p].maxl + 1, pos[nq] = pos[q];
t[np].link = t[q].link = nq;
for (; p && t[p].trans[c] == q; p = t[p].link) t[p].trans[c] = nq;
}
}
return np;
}
void scout(int u) {
for (int v : go[u]) {
net[u][s[pos[v] - t[u].maxl]] = v;
scout(v);
cnt[u] += cnt[v];
}
}
void build() {
for (int u = 2; u <= cT; u ++) go[t[u].link].push_back(u);
scout(1);
}
int tag[SIZE];
void remake() {
for (int i = 1; i <= cT; i ++) tag[i] = 0;
}
void mark(int u, int fu, int p, int curL) {
if (curL == t[p].maxl) p = net[p][a[u] - 'a'];
else if (s[pos[p] - curL] != a[u] - 'a') p = 0;
if (p == 0) return;
tag[p] ++;
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
if (v == fu || ban[v]) continue;
mark(v, u, p, curL + 1);
}
}
void spread(int u) {
tag[u] += tag[t[u].link];
for (int v : go[u]) spread(v);
}
} A, B;
int posA[N], posB[N];
long long ans;
void path_calc(int u, int fu, int p, int opt) {
p = A.t[p].trans[a[u] - 'a'];
if (p == 0) return;
ans += opt * A.cnt[p];
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
if (v == fu || ban[v]) continue;
path_calc(v, u, p, opt);
}
}
void work_sm(int u, int fu) {
path_calc(u, 0, 1, 1);
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
if (v == fu || ban[v]) continue;
work_sm(v, u);
}
}
void work_bg(int u, int fu, int opt) {
A.remake(), B.remake();
if (fu == 0)
A.mark(u, 0, 1, 0),
B.mark(u, 0, 1, 0);
else
A.mark(u, fu, A.t[1].trans[a[fu] - 'a'], 1),
B.mark(u, fu, B.t[1].trans[a[fu] - 'a'], 1);
A.spread(1), B.spread(1);
for (int i = 1; i <= m; i ++) ans += 1ll * opt * A.tag[posA[i]] * B.tag[posB[i]];
}
int A_son;
void plus_bg(int u, int fu, int p, int curL) {
if (curL == A.t[p].maxl) p = A.net[p][a[u] - 'a'];
else if (tex[A.pos[p] - curL] != a[u]) p = 0;
if (p == 0) return;
path_calc(A_son, A_rt, p, -1);
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
if (v == fu || ban[v]) continue;
plus_bg(v, u, p, curL + 1);
}
}
void solve(int u) {
if (A_sz <= Bn) {
work_sm(u, 0);
ban[u] = 1;
return;
} else {
work_bg(u, 0, 1);
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
if (ban[v]) continue;
temp_sz[v] = Get_sz(v, u);
if (temp_sz[v] <= Bn) {
A_son = v;
plus_bg(v, u, A.t[1].trans[a[u] - 'a'], 1);
} else {
work_bg(v, u, -1);
}
}
ban[u] = 1;
}
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
if (ban[v]) continue;
A_sz = temp_sz[v], A_rt = 0;
Get_rt(v, 0), solve(A_rt);
}
}
int main() {
read(n), read(m), Bn = 1200;
for (int i = 1, u, v; i < n; i ++) {
read(u), read(v);
add_edge(u, v), add_edge(v, u);
}
scanf("%s", a + 1);
scanf("%s", tex + 1);
for (int i = 1; i <= m; i ++) posA[i] = A.extend(tex[i] - 'a');
A.build();
for (int i = m; i >= 1; i --) posB[i] = B.extend(tex[i] - 'a');
B.build();
A_sz = n, A_rt = 0;
Get_rt(1, 0), solve(A_rt);
printf("%lld\n", ans);
return 0;
}