考研路茫茫――单词情结

AC自动机+矩阵快速幂好题
题目描述

发现 L 很大,考虑矩阵快速幂。先考虑朴素做法:求出答案的补集,即不包含任何单词的串的个数。设 \(f_{i,j}\) 表示已经填了 \(i\) 个字符、当前位于 AC 自动机(Trie 图)上的节点 \(j\),并且途中没有经过任何带单词结尾标记的节点的方案数(非得用答案补集的原因是直接考虑答案太麻烦,得要高深的容斥),可得转移式 if (!t[t[j].son[c]].end) f[i + 1][t[j].son[c]] += f[i][j];

最终答案即为:\(Ans=\sum\limits_{k=0}^{L}26^k-\sum\limits_{0\le i\le L,\,0\le j\le cnt} f_{i,j}\)(两个和式都把长度为 0 的空串算了一次,正好相互抵消)

前者可以用动态规划求出:令 \(g_0=1\),\(g_{i}=26g_{i-1}+1\),则 \(g_L=\sum\limits_{k=0}^{L}26^k\)

那么可以有这个代码:

#include<bits/stdc++.h>
// Contest-template shorthands. Several of them are not referenced by the
// program that follows; they are kept as part of the author's template.
#define LL long long
//#define int LL
// per(i, a, b): i runs from a DOWN to b (inclusive).
// rep(i, a, b): i runs from a UP to b (inclusive).
// In both, the bound b is evaluated once and cached in a variable named END<i>.
#define per(i, a, b) for (int i = a, END##i = b; i >= END##i; i--)
#define rep(i, a, b) for (int i = a, END##i = b; i <= END##i; i++)
// repn / repm: loop x over 1..n / 1..m, using whatever `n` / `m` is in scope
// at the point of use.
#define repn(x) rep(x, 1, n)
#define repm(x) rep(x, 1, m)
#define pb push_back
#define PII pair<int, int>
// i64 is UNSIGNED: arithmetic on it wraps around instead of overflowing
// (modulo 2^64 where unsigned long long is 64 bits wide).
#define i64 unsigned long long
// YY / NN: print a verdict and terminate the program.
#define YY puts("Yes"), exit(0)
#define NN puts("No"), exit(0)
using namespace std;
const int Mod = 1e9 + 7;
const int Inf = 0x3f3f3f3f;
const LL InfLL = 0x3f3f3f3f3f3f3f3f;
// Reads one decimal integer from stdin and returns it.
//
// Every character before the first digit is skipped; if a '-' is seen among
// the skipped characters the result is negated. Reading stops at the first
// non-digit after the number (that character is consumed).
//
// Returns 0 when end of input is reached before any digit. The character is
// kept in an int so that EOF can be told apart from real input; storing it in
// a char made the skip loop spin forever once stdin was exhausted.
inline long long read() {
    long long value = 0, sign = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}

// Capacity of the node pool t[], of the word buffer str[], and of the DP tables.
const int N = 110;
// n: number of words left to read for the current test case (counted down by Main).
// L: maximum string length to count.
int n, L;
// One automaton node.
struct Node {
    // son[c]: child reached by letter 'a' + c; 0 means "no child" until getFail()
    //         replaces missing children with fail transitions.
    // fail:   fail link (stays 0, the root, for the root's direct children).
    int son[26], fail;
    // end: set when a word ends at this node; getFail() also sets it when a
    //      word ends at the node's fail target.
    bool end;
}t[N]; int cnt; // t[0] is the root; cnt is the index of the most recently allocated node.
// Input buffer for a single word.
char str[N];
// Adds the NUL-terminated word `str` to the trie rooted at t[0], allocating a
// new node (index ++cnt) for every missing edge, and marks the node of the
// last letter as a word end. Letters are mapped with `ch - 'a'`; no range or
// capacity check is performed.
void Insert(char *str) {
    int cur = 0;
    for (const char *p = str; *p != '\0'; ++p) {
        int &child = t[cur].son[*p - 'a'];
        if (child == 0) child = ++cnt;
        cur = child;
    }
    t[cur].end = 1;
}
void getFail() {
    queue<int> q;
    rep(i, 0, 25) if (t[0].son[i])
        q.push(t[0].son[i]);
    while (!q.empty()) {
        int now = q.front();
        q.pop(); t[now].end |= t[t[now].fail].end;
        rep(i, 0, 25) {
            if (t[now].son[i]) {
                t[t[now].son[i]].fail = t[t[now].fail].son[i];
                q.push(t[now].son[i]);
            } else t[now].son[i] = t[t[now].fail].son[i];
        }
    }
}
// Global (hence zero-initialised) unsigned work arrays; arithmetic on their
// elements wraps around instead of overflowing.
i64 f[N][N], g[N];
// Returns a raised to the power b by binary exponentiation. The arithmetic is
// unsigned, so the result is reduced modulo 2^64 by wraparound.
//
// b <= 0 yields 1. The loop tests `b > 0` because right-shifting a negative
// exponent never reaches zero, which previously made the loop run forever.
unsigned long long qpow(unsigned long long a, int b) {
    unsigned long long result = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) result *= a;
        a *= a;
    }
    return result;
}

inline void Main() {
    while (scanf("%d%d", &n, &L) != -1) {
        cnt = 0; memset(t, 0, sizeof(t));
        while (n--) {
            scanf("%s", str);
            Insert(str);
        }
        getFail();
        memset(f, 0, sizeof(f));
        f[0][0] = 1;
        rep(i, 0, L - 1) rep(j, 0, cnt) rep(c, 0, 25)
            if (!t[t[j].son[c]].end) f[i + 1][t[j].son[c]] += f[i][j];
        g[0] = 1; i64 ans = 0;
        rep(i, 1, L) g[i] = 26 * g[i - 1] + 1;
        rep(i, 0, L) rep(j, 0, cnt)
            ans += f[i][j];
        // cout << g[1] << " " << g[2] << " " << g[3] << "\n";
        cout << g[L] - ans << "\n";
    }
}

// Program entry point: runs the solver for a fixed batch count of one.
signed main() {
    // freopen("input.in", "r", stdin);
    for (int remaining = 1; remaining > 0; --remaining) {
        Main();
    }
    return 0;
}

考虑优化,首先优化简单的部分,即 \(\sum26^k\) 可得矩阵:

\[\begin{bmatrix} G(n) \\ 1 \end{bmatrix} = \begin{bmatrix} 26 & 1 \\ 0 & 1 \end{bmatrix} \times \begin{bmatrix} G(n-1) \\ 1 \end{bmatrix} \]

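观察转移式 if (!t[t[j].son[c]].end) f[i + 1][t[j].son[c]] += f[i][j]; ,可以发现它等价于在 AC 自动机(Trie 图)上连边:只要 t[j].son[c] 不是单词结尾节点,就连一条 j → t[j].son[c] 的边。于是整个转移过程可以改成:从根节点出发,沿着这些边走 L 步,问走到每个节点的路径有多少条。设 \(g_{i,j}\) 表示从 \(i\) 到 \(j\) 的路径数,两段路径拼接后的路径数为 \(g'_{i,j}=\sum\limits_{k} g_{i,k}g_{k,j}\),可以发现这正是矩阵乘法的公式,所以可以用矩阵快速幂。代码中还给矩阵额外加了一列全 1(下标为 cnt),用来把长度 \(0\sim L-1\) 的方案数累加起来,这样第 0 行之和就是长度 \(0\sim L\) 的方案总数。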

\(\mathscr{Code:}\)

#include <iostream>
#include <cstdio>
#include <queue>
#include <cstring>
// Contest-template shorthands. Several of them are not referenced by the
// program that follows; they are kept as part of the author's template.
#define LL long long
//#define int LL
// per(i, a, b): i runs from a DOWN to b (inclusive).
// rep(i, a, b): i runs from a UP to b (inclusive).
// In both, the bound b is evaluated once and cached in a variable named END<i>.
#define per(i, a, b) for (int i = a, END##i = b; i >= END##i; i--)
#define rep(i, a, b) for (int i = a, END##i = b; i <= END##i; i++)
// repn / repm: loop x over 1..n / 1..m, using whatever `n` / `m` is in scope
// at the point of use.
#define repn(x) rep(x, 1, n)
#define repm(x) rep(x, 1, m)
#define pb push_back
#define PII pair<int, int>
// i64 is UNSIGNED: arithmetic on it wraps around instead of overflowing
// (modulo 2^64 where unsigned long long is 64 bits wide).
#define i64 unsigned long long
// YY / NN: print a verdict and terminate the program.
#define YY puts("Yes"), exit(0)
#define NN puts("No"), exit(0)
using namespace std;
const int Mod = 1e9 + 7;
const int Inf = 0x3f3f3f3f;
const LL InfLL = 0x3f3f3f3f3f3f3f3f;
// Reads one decimal integer from stdin and returns it.
//
// Every character before the first digit is skipped; if a '-' is seen among
// the skipped characters the result is negated. Reading stops at the first
// non-digit after the number (that character is consumed).
//
// Returns 0 when end of input is reached before any digit. The character is
// kept in an int so that EOF can be told apart from real input; storing it in
// a char made the skip loop spin forever once stdin was exhausted.
inline long long read() {
    long long value = 0, sign = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}

// Capacity of the node pool t[] and of the word buffer str[]; also the fixed
// storage dimension of Matrix::M.
const int N = 50;
// Number of words left to read for the current test case (counted down by Main).
int n;
// Maximum string length to count; used as the exponent of the matrix powers.
i64 L;
// One automaton node.
struct Node {
    // son[c]: child reached by letter 'a' + c; 0 means "no child" until getFail()
    //         replaces missing children with fail transitions.
    // fail:   fail link (stays 0, the root, for the root's direct children).
    int son[26], fail;
    // end: set when a word ends at this node; getFail() also sets it when a
    //      word ends at the node's fail target.
    bool end;
}t[N]; int cnt; // t[0] is the root; cnt is the next free node index (Main resets it to 1).
// Input buffer for a single word.
char str[N];
// Adds the NUL-terminated word `str` to the trie rooted at t[0], taking the
// next free index (cnt, then incrementing it) for every missing edge, and
// marks the node of the last letter as a word end. Letters are mapped with
// `ch - 'a'`; no range or capacity check is performed.
void Insert(char *str) {
    int cur = 0;
    for (const char *p = str; *p != '\0'; ++p) {
        int &child = t[cur].son[*p - 'a'];
        if (child == 0) child = cnt++;
        cur = child;
    }
    t[cur].end = 1;
}
void getFail() {
    queue<int> q;
    rep(i, 0, 25) if (t[0].son[i])
        q.push(t[0].son[i]);
    while (!q.empty()) {
        int now = q.front();
        q.pop(); t[now].end |= t[t[now].fail].end;
        rep(i, 0, 25) {
            if (t[now].son[i]) {
                t[t[now].son[i]].fail = t[t[now].fail].son[i];
                q.push(t[now].son[i]);
            } else t[now].son[i] = t[t[now].fail].son[i];
        }
    }
}
struct Matrix {
    i64 M[N][N], n;
    Matrix() {};
    Matrix(int _) {
        n = _;
        memset(M, 0, sizeof(M));
    }
    Matrix operator* (const Matrix& b) const {
        Matrix res = Matrix(n);
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                for (int k = 0; k < n; k++)
                    res.M[i][j] += M[i][k] * b.M[k][j];
        return res;
    } 
};
// Returns a raised to the power b (b = 0 gives the identity of size a.n) by
// binary exponentiation, using Matrix::operator*.
//
// The identity is sized from a.n only. An extra loop over the global word
// counter `n` used to set diagonal entries here as well; it had nothing to do
// with the matrix size and was harmless only because the caller happens to
// leave that counter at -1.
Matrix qpow(Matrix a, i64 b) {
    Matrix res(static_cast<int>(a.n));
    for (i64 i = 0; i < a.n; i++) res.M[i][i] = 1;
    while (b) {
        // res and a are both powers of the original matrix, so they commute.
        if (b & 1) res = a * res;
        a = a * a;
        b >>= 1;
    }
    return res;
}

inline void Main() {
    while (cin >> n >> L) {
        cnt = 1; memset(t, 0, sizeof(t));
        while (n--) {
            scanf("%s", str);
            Insert(str);
        }
        getFail();
        Matrix M1 = Matrix(2);
        M1.M[0][0] = 26, M1.M[1][1] = M1.M[0][1] = 1;
        M1 = qpow(M1, L);
        i64 ans = M1.M[0][0] + M1.M[0][1];

        Matrix M2 = Matrix(cnt + 1);
        rep(i, 0, cnt - 1) rep(j, 0, 25) {
            if (t[t[i].son[j]].end) continue;
            M2.M[i][t[i].son[j]]++;
        }
        for (int i = 0; i < cnt + 1; i++)
            M2.M[i][cnt] = 1;
        M2 = qpow(M2, L);

        rep(i, 0, M2.n - 1) ans -= M2.M[0][i];
        cout << ans << "\n";
    }
}

// Program entry point: runs the solver for a fixed batch count of one.
signed main() {
    // freopen("input.in", "r", stdin);
    for (int remaining = 1; remaining > 0; --remaining) {
        Main();
    }
    return 0;
}
posted @ 2025-12-13 21:20  wh2011  阅读(1)  评论(0)    收藏  举报