加载中...

ABC 419 F(AC自动机+dp)

F

初见杀的一道好题,头一次遇到将字典树结点作为 \(dp\) 的其中一维状态的题目。

状态定义:\(dp[i][state][last]\): 考虑前 \(i\) 个字符,拥有的字符串集合的二进制表示为 \(state\),且当前末尾的字符串前缀为 \(last\)\(last\) 是字典树中的某个结点编号,表示的一定是模式串集合中的某个前缀。特殊地,\(last==0\) 时表示空串),方案数。

\(ans = dp[L][(1<<n)-1][0 \backsim sum]\)
\(init: dp[0][0][0] = 1\)

转移过程见代码即可。

这里特别讲一下 \(ACAM\) 中的 \(id\) 数组:代码中的 \(id[u]\)字典树初始化后 (未 \(build\)) 的意思是:结点 \(u\) 表示的前缀 包含原模式串集合的二进制状态。而实际上,在当前构建的字符串末尾添加一个新字符时,新产生的 在模式串集合中 的字符串应当考虑所有以该新字符结尾的后缀。因此在 \(dp\) 状态转移过程中,包含模式串集合的二进制状态 那一维的转移若想正确,则 \(id[u]\) 实际上应该表示的是 结点 \(u\) 对应前缀字符串中的所有后缀在模式串集合中的二进制状态表示。而 \(fail[u]\) 链上的所有结点恰好可以表示这一点。因此只需要在 \(ac\) 自动机的 \(build\) 过程中,对每个结点的 \(id[u]\)\(fail\) 链上的前缀 \(or\),即可得到 能正确实现 \(dp\) 转移 的 \(id\) 数组。

// 最朴素的做法(未改变词频表ch[u][j],需要fail指针一直绕圈,复杂度可能很高,但还是能过,并且理解起来更直观)。
#include <bits/stdc++.h>
// #define int long long 
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#define fr first
#define se second
#define endl '\n'
#define pb push_back
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
mt19937_64 rnd(time(0));

const int N = 2e5 + 10;

struct ACAM
{
    vector<vector<int>> ch; // 字典树表,注意在执行build()函数后会改变

    vector<int> fail; // 每个结点均有一个fail指针,指向当前结点表示的字符串匹配失败后,下一步需要找的最长可匹配的后缀对应的终止结点

    vector<int> id;  // id[p]: 字典树结点编号 -> 模式串编号 的映射
    int idx;         // 字典树中实际结点数量


    void init(int n) // n为所有模式串的字符数量之和
    {
        ch.resize(n + 1);
        ch[0].resize(26, 0);
        fail.resize(n + 1, -1);
        id.resize(n + 1, 0);
        idx = 0;
    }

    void insert(string& str, int x) // 将所有模式串插入,形成的只是普通的字典树
    {
        int p = 0;
        for (int i = 0; i < str.size(); i++){
            int j = str[i] - 'a';
            if (!ch[p][j]){
                ch[p][j] = ++idx;
                ch[idx].resize(26, 0);
            }
            p = ch[p][j];
        }
        // cnt[p]++;
        id[p] |= (1 << x);
    }

    void build() // 在普通字典树的基础上建立AC自动机 -> 添加fail指针
    {
        queue<int> q; // 进入队列的点都是已知fail指针的点,每个点最多只进入一次队列
        for (int i = 0; i < 26; i++){ // 先将根结点0的所有儿子加入队列(它们的fail指针均为0)
            if (ch[0][i]){
                fail[ch[0][i]] = 0;
                q.push(ch[0][i]);
            }
        }
        while (q.size()){
            int u = q.front();
            q.pop();
            id[u] |= id[fail[u]];
            for (int j = 0; j < 26; j++){
                int v = ch[u][j];
                if(v){ // ! 注意,在不改变词频表的情况下,建立fail指针需要一直绕圈
                    int f = fail[u];
                    while(f != -1 && !ch[f][j]){
                        f = fail[f];
                    }
                    if(f == -1){
                        fail[v] = 0;
                    }
                    else{
                        fail[v] = ch[f][j];
                    }
                    q.push(v);
                }
            }
        }
    }
};
const int MOD = 998244353;

int n, L;
string s[10];
int dp[105][1<<8][85]; 

void solve()
{
    cin >> n >> L;
    ACAM ac;
    int tot_len = 0;
    for(int i = 0; i < n; i ++){
        cin >> s[i];
        tot_len += s[i].length();
    }

    ac.init(tot_len);
    for(int i = 0; i < n; i ++){
        ac.insert(s[i], i);
    }
    ac.build();

    int node_sum = ac.idx;
    dp[0][0][0] = 1;
    for(int i = 0; i < L; i ++){
        for(int state = 0; state < (1 << n); state ++){
            for(int u = 0; u <= node_sum; u ++){ // 枚举前缀
                vector<bool> vis(26, false);
                for(int j = u; j != -1; j = ac.fail[j]){
                    for(int jj = 0; jj < 26; jj ++){
                        int v = ac.ch[j][jj];
                        if(v && !vis[jj]){
                            vis[jj] = true;
                            int n_state = state | ac.id[v];
                            dp[i + 1][n_state][v] = (0ll + dp[i + 1][n_state][v] + dp[i][state][u]) % MOD;
                        }
                    }
                }
                for(int jj = 0; jj < 26; jj ++){
                    if(!vis[jj]){
                        dp[i + 1][state][0] = (0ll + dp[i + 1][state][0] + dp[i][state][u]) % MOD;
                    }
                }
            }
        }
    }

    int sum = 0;
    for(int i = 0; i <= node_sum; i ++){
        sum = (0ll + sum + dp[L][(1<<n)-1][i]) % MOD;
    }
    cout << sum << endl;
}


signed main()
{
    ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr);
    // int T=1; cin>>T; while(T--)
    solve();
    return 0;
}
// 正常做法(在正常字典树基础上改变了ch[u][j](可以理解为并查集中的路径压缩),转移时也无需跳fail指针,直接从当前状态转移即可)
#include <bits/stdc++.h>
// #define int long long 
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#define fr first
#define se second
#define endl '\n'
#define pb push_back
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
mt19937_64 rnd(time(0));

const int N = 2e5 + 10;

struct ACAM
{
    vector<vector<int>> ch; // 字典树表,注意在执行build()函数后会改变

    vector<int> fail; // 每个结点均有一个fail指针,指向当前结点表示的字符串匹配失败后,下一步需要找的最长可匹配的后缀对应的终止结点

    vector<int> id;  // id[p]: 字典树结点编号 -> 模式串编号 的映射
    int idx;         // 字典树中实际结点数量


    void init(int n) // n为所有模式串的字符数量之和
    {
        ch.resize(n + 1);
        ch[0].resize(26, 0);
        fail.resize(n + 1, 0);
        id.resize(n + 1, 0);
        idx = 0;
    }

    void insert(string& str, int x) // 将所有模式串插入,形成的只是普通的字典树
    {
        int p = 0;
        for (int i = 0; i < str.size(); i++){
            int j = str[i] - 'a';
            if (!ch[p][j]){
                ch[p][j] = ++idx;
                ch[idx].resize(26, 0);
            }
            p = ch[p][j];
        }
        id[p] |= (1 << x);
    }

    void build() // 在普通字典树的基础上建立AC自动机 -> 添加fail指针
    {
        queue<int> q; // 进入队列的点都是已知fail指针的点,每个点最多只进入一次队列
        for (int i = 0; i < 26; i++){ // 先将根结点0的所有儿子加入队列(它们的fail指针均为0)
            if (ch[0][i]){
                q.push(ch[0][i]);
            }
        }
        while (q.size()){
            int u = q.front();
            q.pop();
            id[u] |= id[fail[u]];
            for (int j = 0; j < 26; j++){
                int v = ch[u][j];
                if(v){
                    fail[v] = ch[fail[u]][j];
                    q.push(v);
                }
                else{
                    ch[u][j] = ch[fail[u]][j];
                }
            }
        }
    }
};
const int MOD = 998244353;

int n, L;
string s[10];
int dp[105][1<<8][85]; 

void solve()
{
    cin >> n >> L;
    ACAM ac;
    int tot_len = 0;
    for(int i = 0; i < n; i ++){
        cin >> s[i];
        tot_len += s[i].length();
    }

    ac.init(tot_len);
    for(int i = 0; i < n; i ++){
        ac.insert(s[i], i);
    }
    ac.build();

    int node_sum = ac.idx;
    dp[0][0][0] = 1;
    for(int i = 0; i < L; i ++){
        for(int state = 0; state < (1 << n); state ++){
            for(int u = 0; u <= node_sum; u ++){ // 枚举前缀
                for(int j = 0; j < 26; j ++){
                    int v = ac.ch[u][j];
                    int n_state = state | ac.id[v];
                    dp[i + 1][n_state][v] = (0ll + dp[i + 1][n_state][v] + dp[i][state][u]) % MOD;
                }
            }
        }
    }

    int sum = 0;
    for(int i = 0; i <= node_sum; i ++){
        sum = (0ll + sum + dp[L][(1<<n)-1][i]) % MOD;
    }
    cout << sum << endl;
}


signed main()
{
    ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr);
    // int T=1; cin>>T; while(T--)
    solve();
    return 0;
}
posted @ 2025-08-19 12:48  jxs123  阅读(12)  评论(0)    收藏  举报