UESTCPC 2025 决赛补题

I. 双端魔咒

题意

给出大小分别为 \(l\)\(r\) 的两个字符串数组 \(pre\)\(suf\) ,以及一个字符串 \(s\) ,问 \(s\) 中有多少个子串,满足:它的至少一个前缀包含在 \(pre\) 中,且它的至少一个后缀包含在 \(suf\) 中。
\(s\) 的长度不超过1e6,\(pre\)\(suf\) 各自的总长度不超过1e6。

解法

1. 模式串去重

问题的核心就是多模式串匹配,看完不难想到是AC自动机;也不难想到,如果构造一组数据, \(pre\)\(suf\) 都分别是 \(a, aa, aaa, ...\) ,而询问的串(后文中记为 \(s\) )为1e6个 \(a\) ,时间复杂度几乎无论如何都会退化为 \(O(n^2)\) ,所以必须去掉一部分 \(pre\)\(suf\) 。例如,如果 \(pre_i\)\(pre_j\) 的前缀,那就可以当 \(pre_j\) 不存在,这是容易想到的。
可以用字典树实现。比如对于 \(pre\) ,把所有 \(pre\) 插入一个字典树中,然后对字典树 dfs ,当走到某个 \(pre\) 的终止节点,就记录这个字符串是有效的,然后砍去当前节点的所有后续边,直接 return 。

2. 获得模式串的起始位置

将有效的 \(pre\) 全部倒序插入字典树,然后构建 fail 指针(建AC自动机)。将 \(s\) 倒序地在AC自动机上匹配,获得所有 \(pre\) 的起始位置。由于需要知道 \(s\) 中匹配到的位置,又不能在匹配过程中跳 fail ,可以在匹配之前在 fail 树中 pushdown 模式串结束标记。因为 \(pre\) 之间没有前缀关系,所以结束标记不会发生覆盖。
同理,对所有 \(suf\) 倒序建字典树从而去后缀化之后,正序插入另一个字典树然后构建 fail , \(s\) 在上面正序匹配,每当匹配到某个 \(suf\) 的终止位置,就减去它的长度获得起始位置。

3. 统计答案

\(s\) 中每个作为 \(pre\) 起点的位置,看有多少个 \(suf\) 的起始位置大于等于它(用后缀和实现),加到答案上。
此时统计的答案包含了一些错误的情况,例如 \(suf_j\) 确实不先于 \(pre_i\) 开始,但它却在 \(pre_i\) 之前结束。于是,对于每个有效的 \(pre\) ,统计:将它去掉最后一个字符之后,它包含了多少个有效的 \(suf\) (也就是此时在字典树中的 \(suf\) ,所以直接用 \(pre\) 跑AC自动机就行)。
答案减去每个有效 \(pre\)出现次数 × 包含有效 \(suf\) 的数量 ,就做完了。

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>
#define int long long
#define ld long double
const int N = 1e6 + 5;

int l, r;
std::string pre[N], suf[N], s;

int nxt[N][26], tot, isend[N], fail[N];
bool usepre[N], usesuf[N];
int prestart[N], sufstart[N], cntpre[N];
std::vector<int> ve[N];

void _insert(std::string t, int id) {
    int cur = 0;
    for (auto ch : t) {
        int tmp = ch - 'a';
        if (!nxt[cur][tmp]) nxt[cur][tmp] = ++tot;
        cur = nxt[cur][tmp];
    }
    isend[cur] = id;
}

void _predicdfs(int cur) {
    if (isend[cur]) {
        usepre[isend[cur]] = 1;
        for (int i = 0; i < 26; i++) nxt[cur][i] = 0;
        return;
    }
    for (int i = 0; i < 26; i++) {
        if (nxt[cur][i]) _predicdfs(nxt[cur][i]);
    }
}

void _sufdicdfs(int cur) {
    if (isend[cur]) {
        usesuf[isend[cur]] = 1;
        for (int i = 0; i < 26; i++) nxt[cur][i] = 0;
        return;
    }
    for (int i = 0; i < 26; i++) {
        if (nxt[cur][i]) _sufdicdfs(nxt[cur][i]);
    }
}

void _getfail() {
    std::queue<int> q;
    for (int i = 0; i < 26; i++) {
        if (nxt[0][i]) q.push(nxt[0][i]);
    }
    int cur;
    while (!q.empty()) {
        cur = q.front();
        q.pop();
        for (int i = 0; i < 26; i++) {
            if (nxt[cur][i]) {
                fail[nxt[cur][i]] = nxt[fail[cur]][i];
                q.push(nxt[cur][i]);
            } else
                nxt[cur][i] = nxt[fail[cur]][i];
        }
    }
    for (int i = 0; i <= tot; i++) ve[i].clear();
    for (int i = 1; i <= tot; i++) ve[fail[i]].push_back(i);
}

void _isend_pushdown(int cur) {
    if (isend[cur]) {
        for (auto to : ve[cur]) isend[to] = isend[cur];
    }
    for (auto to : ve[cur]) _isend_pushdown(to);
}

void _getpre() {
    _isend_pushdown(0);
    int cur = 0;
    for (int i = s.length() - 1; i >= 0; i--) {
        cur = nxt[cur][s[i] - 'a'];
        if (isend[cur]) {
            prestart[i] = 1;
            cntpre[isend[cur]]++;
        }
    }
}

void _getsuf() {
    _isend_pushdown(0);
    int cur = 0, len = s.length();
    for (int i = 0; i < len; i++) {
        cur = nxt[cur][s[i] - 'a'];
        if (isend[cur]) sufstart[i - suf[isend[cur]].length() + 1]++;
    }
}

int _check(std::string t) {
    int cur = 0, lmt = t.length() - 1, res = 0;
    for (int i = 0; i < lmt; i++) {
        cur = nxt[cur][t[i] - 'a'];
        if (isend[cur]) res++;
    }
    return res;
}

inline void _init() {
    memset(nxt, 0, sizeof(nxt));
    memset(isend, 0, sizeof(isend));
    memset(fail, 0, sizeof(fail));
    tot = 0;
}
void solve() {
    std::cin >> l >> r;
    for (int i = 1; i <= l; i++) std::cin >> pre[i];
    for (int i = 1; i <= r; i++) std::cin >> suf[i];
    std::cin >> s;

    for (int i = 1; i <= l; i++) _insert(pre[i], i);
    _predicdfs(0);

    _init();
    for (int i = 1; i <= l; i++) {
        if (usepre[i]) {
            reverse(pre[i].begin(), pre[i].end());
            _insert(pre[i], i);
        }
    }
    _getfail();
    _getpre();

    _init();
    for (int i = 1; i <= r; i++) {
        reverse(suf[i].begin(), suf[i].end());
        _insert(suf[i], i);
    }
    _sufdicdfs(0);

    _init();
    for (int i = 1; i <= r; i++) {
        if (usesuf[i]) {
            reverse(suf[i].begin(), suf[i].end());
            _insert(suf[i], i);
        }
    }
    _getfail();
    _getsuf();

    int len = s.length(), ans = 0;
    for (int i = len - 2; i >= 0; i--) sufstart[i] += sufstart[i + 1];
    for (int i = 0; i < len; i++) {
        if (prestart[i]) ans += sufstart[i];
    }
    for (int i = 1; i <= l; i++) {
        if (usepre[i]) {
            reverse(pre[i].begin(), pre[i].end());
            ans -= _check(pre[i]) * cntpre[i];
        }
    }
    std::cout << ans << "\n";
}

signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    int ti = 1;
    // std::cin >> ti;
    while (ti--) solve();
    return 0;
}

posted on 2025-04-09 00:08  C12AK  阅读(15)  评论(1)    收藏  举报