UESTCPC 2025 决赛补题
I. 双端魔咒
题意
给出大小分别为 \(l\) 和 \(r\) 的两个字符串数组 \(pre\) 和 \(suf\) ,以及一个字符串 \(s\) ,问 \(s\) 中有多少个子串,满足:它的至少一个前缀包含在 \(pre\) 中,且它的至少一个后缀包含在 \(suf\) 中。
\(s\) 的长度不超过1e6,\(pre\) 和 \(suf\) 各自的总长度不超过1e6。
解法
1. 模式串去重
问题的核心就是多模式串匹配,看完不难想到是AC自动机;也不难想到,如果构造一组数据, \(pre\) 和 \(suf\) 都分别是 \(a, aa, aaa, ...\) ,而询问的串(后文中记为 \(s\) )为1e6个 \(a\) ,时间复杂度几乎无论如何都会退化为 \(O(n^2)\) ,所以必须去掉一部分 \(pre\) 和 \(suf\) 。例如,如果 \(pre_i\) 是 \(pre_j\) 的前缀,那就可以当 \(pre_j\) 不存在,这是容易想到的。
可以用字典树实现。比如对于 \(pre\) ,把所有 \(pre\) 插入一个字典树中,然后对字典树 dfs ,当走到某个 \(pre\) 的终止节点,就记录这个字符串是有效的,然后砍去当前节点的所有后续边,直接 return 。
2. 获得模式串的起始位置
将有效的 \(pre\) 全部倒序插入字典树,然后构建 fail 指针(建AC自动机)。将 \(s\) 倒序地在AC自动机上匹配,获得所有 \(pre\) 的起始位置。由于需要知道 \(s\) 中匹配到的位置,又不能在匹配过程中跳 fail ,可以在匹配之前在 fail 树中 pushdown 模式串结束标记。因为 \(pre\) 之间没有前缀关系,所以结束标记不会发生覆盖。
同理,对所有 \(suf\) 倒序建字典树从而去后缀化之后,正序插入另一个字典树然后构建 fail , \(s\) 在上面正序匹配,每当匹配到某个 \(suf\) 的终止位置,就减去它的长度获得起始位置。
3. 统计答案
对 \(s\) 中每个作为 \(pre\) 起点的位置,看有多少个 \(suf\) 的起始位置大于等于它(用后缀和实现),加到答案上。
此时统计的答案包含了一些错误的情况,例如 \(suf_j\) 确实不先于 \(pre_i\) 开始,但它却在 \(pre_i\) 之前结束。于是,对于每个有效的 \(pre\) ,统计:将它去掉最后一个字符之后,它包含了多少个有效的 \(suf\) (也就是此时在字典树中的 \(suf\) ,所以直接用 \(pre\) 跑AC自动机就行)。
答案减去每个有效 \(pre\) 的 出现次数 × 包含有效 \(suf\) 的数量 ,就做完了。
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>
#define int long long
#define ld long double
const int N = 1e6 + 5;
int l, r;
std::string pre[N], suf[N], s;
int nxt[N][26], tot, isend[N], fail[N];
bool usepre[N], usesuf[N];
int prestart[N], sufstart[N], cntpre[N];
std::vector<int> ve[N];
void _insert(std::string t, int id) {
int cur = 0;
for (auto ch : t) {
int tmp = ch - 'a';
if (!nxt[cur][tmp]) nxt[cur][tmp] = ++tot;
cur = nxt[cur][tmp];
}
isend[cur] = id;
}
void _predicdfs(int cur) {
if (isend[cur]) {
usepre[isend[cur]] = 1;
for (int i = 0; i < 26; i++) nxt[cur][i] = 0;
return;
}
for (int i = 0; i < 26; i++) {
if (nxt[cur][i]) _predicdfs(nxt[cur][i]);
}
}
void _sufdicdfs(int cur) {
if (isend[cur]) {
usesuf[isend[cur]] = 1;
for (int i = 0; i < 26; i++) nxt[cur][i] = 0;
return;
}
for (int i = 0; i < 26; i++) {
if (nxt[cur][i]) _sufdicdfs(nxt[cur][i]);
}
}
void _getfail() {
std::queue<int> q;
for (int i = 0; i < 26; i++) {
if (nxt[0][i]) q.push(nxt[0][i]);
}
int cur;
while (!q.empty()) {
cur = q.front();
q.pop();
for (int i = 0; i < 26; i++) {
if (nxt[cur][i]) {
fail[nxt[cur][i]] = nxt[fail[cur]][i];
q.push(nxt[cur][i]);
} else
nxt[cur][i] = nxt[fail[cur]][i];
}
}
for (int i = 0; i <= tot; i++) ve[i].clear();
for (int i = 1; i <= tot; i++) ve[fail[i]].push_back(i);
}
void _isend_pushdown(int cur) {
if (isend[cur]) {
for (auto to : ve[cur]) isend[to] = isend[cur];
}
for (auto to : ve[cur]) _isend_pushdown(to);
}
void _getpre() {
_isend_pushdown(0);
int cur = 0;
for (int i = s.length() - 1; i >= 0; i--) {
cur = nxt[cur][s[i] - 'a'];
if (isend[cur]) {
prestart[i] = 1;
cntpre[isend[cur]]++;
}
}
}
void _getsuf() {
_isend_pushdown(0);
int cur = 0, len = s.length();
for (int i = 0; i < len; i++) {
cur = nxt[cur][s[i] - 'a'];
if (isend[cur]) sufstart[i - suf[isend[cur]].length() + 1]++;
}
}
int _check(std::string t) {
int cur = 0, lmt = t.length() - 1, res = 0;
for (int i = 0; i < lmt; i++) {
cur = nxt[cur][t[i] - 'a'];
if (isend[cur]) res++;
}
return res;
}
inline void _init() {
memset(nxt, 0, sizeof(nxt));
memset(isend, 0, sizeof(isend));
memset(fail, 0, sizeof(fail));
tot = 0;
}
void solve() {
std::cin >> l >> r;
for (int i = 1; i <= l; i++) std::cin >> pre[i];
for (int i = 1; i <= r; i++) std::cin >> suf[i];
std::cin >> s;
for (int i = 1; i <= l; i++) _insert(pre[i], i);
_predicdfs(0);
_init();
for (int i = 1; i <= l; i++) {
if (usepre[i]) {
reverse(pre[i].begin(), pre[i].end());
_insert(pre[i], i);
}
}
_getfail();
_getpre();
_init();
for (int i = 1; i <= r; i++) {
reverse(suf[i].begin(), suf[i].end());
_insert(suf[i], i);
}
_sufdicdfs(0);
_init();
for (int i = 1; i <= r; i++) {
if (usesuf[i]) {
reverse(suf[i].begin(), suf[i].end());
_insert(suf[i], i);
}
}
_getfail();
_getsuf();
int len = s.length(), ans = 0;
for (int i = len - 2; i >= 0; i--) sufstart[i] += sufstart[i + 1];
for (int i = 0; i < len; i++) {
if (prestart[i]) ans += sufstart[i];
}
for (int i = 1; i <= l; i++) {
if (usepre[i]) {
reverse(pre[i].begin(), pre[i].end());
ans -= _check(pre[i]) * cntpre[i];
}
}
std::cout << ans << "\n";
}
signed main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
int ti = 1;
// std::cin >> ti;
while (ti--) solve();
return 0;
}
浙公网安备 33010602011771号