#4730. 匹配

题目描述

树大小为 $n$, 第 $i$ 边有字符集 $S_i$. 给定 $m$ 个模式串 $t_1,t_2,\dots,t_m$。

$Q$ 次询问 $(u,v)$, 设 $u\to v$ 经过的边为 $e_1,e_2,\dots,e_k$,求串 $s$ 的方案数,满足:
- $|s|=k$
- $\forall i\in[1,k],s_i\in S_{e_i}$
- $\exist j,t_j \text{ is a substring of }s$

题解

考虑用总方案数减去不合法的方案数,建立 $\text{AC}$ 自动机,即 $\text{dp}$ : $f[i][j]$ 表示前 $i$ 个字符,目前在自动机上的 $j$ 号节点上的方案数,考虑每一步都不能走到结束节点上。然后发现可以写成矩阵的形式,考场上写的分块做法过不去似乎也优化不了,于是我们可以预处理 $nlogn$ 个矩阵,然后每次就往上跳,用向量乘上矩阵即可,这样效率是 $O(40^3nlogn+40^2Qlogn)$ ,过不去,瓶颈在于预处理部分。

考虑另类的跳 $\text{lca}$ 的方式:每次跳 $\le lowbit(dp[x])$ 步,这样跳的次数也不会超过 $O(logn)$ 而且对于每个点只需要预处理 $log(lowbit(dp[x]))+1$ 个矩阵即可。

于是我们对于 $\text{deep}$ 的每一位,如果第 $i$ 位上 $1$ 的个数比 $0$ 的来的少的话,我们就全体的 $\text{deep}$ 加上 $2^i$ 即可,这样预处理的最多次数就是 $\frac{n}{2} \times 1+\frac{n}{4} \times 2+ \frac{n}{8} \times 3+...<2n$ ,这样效率就是 $O(40^3n+40^2Qlogn)$ 。

代码

#include <bits/stdc++.h>
using namespace std;
const int N=5005,M=42,P=998244353;
int n,m,q,hd[N],V[N],nx[N],t=1,e[M],tr[M][M],fi[M],fa[N][13],dp[N],Y,Z,f[2][M],su[2],c[N],b[N];
char g[N][M],h[M];queue<int>qu;bool F[N];
struct O{
    int p[M][M];
}d[M],a[N],G,up[2501][13],dn[2501][13],S[100];
void add(int u,int v){
    nx[++t]=hd[u];V[hd[u]=t]=v;
}
int X(int x){return x>=P?x-P:x;}
void ins(){
    int p=0,l=strlen(h);
    for (int k,i=0;i<l;i++){
        k=h[i]-'a';
        if (!tr[p][k])
            tr[p][k]=++t;
        p=tr[p][k];
    }
    e[p]=1;
}
void build(){
    for (int i=0;i<26;i++)
        if (tr[0][i]) qu.push(tr[0][i]);
    while(!qu.empty()){
        int k=qu.front();qu.pop();
        for (int i=0;i<26;i++)
            if (tr[k][i])
                fi[tr[k][i]]=tr[fi[k]][i],
                e[tr[k][i]]=e[tr[k][i]]|e[fi[tr[k][i]]],
                qu.push(tr[k][i]);
            else tr[k][i]=tr[fi[k]][i];
    }
}
O Add(O A,O B){
    for (int i=0;i<=t;i++)
        for (int j=0;j<=t;j++)
            A.p[i][j]=X(A.p[i][j]+B.p[i][j]);
    return A;
}
O Mul(O A,O B){
    for (int i=0;i<=t;i++)
        for (int j=0;j<=t;j++) G.p[i][j]=0;
    for (int k=0;k<=t;k++)
        for (int j=0;j<=t;j++) if (A.p[k][j])
            for (int i=0;i<=t;i++) if (B.p[i][k])
                G.p[i][j]=X(G.p[i][j]+1ll*A.p[k][j]*B.p[i][k]%P);
    return G;
}
void dfs(int u,int fr){
    dp[u]=dp[fa[u][0]=fr]+1;
    for (int v,j,i=hd[u];i;i=nx[i]){
        if ((v=V[i])==fr) continue;
        j=i>>1;b[v]=strlen(g[j]);
        for (int k=0;k<b[v];k++)
            a[v]=Add(a[v],d[g[j][k]-97]);
        dfs(v,u);
    }
}
void dfs(int u){
    for (int v,i=hd[u],w;i;i=nx[i]){
        if ((v=V[i])==fa[u][0]) continue;
        w=c[dp[v]&-dp[v]];
        up[v][0]=dn[v][0]=a[v];
        for (int j=1;j<=w;j++)
            up[v][j]=Mul(up[v][j-1],up[fa[v][j-1]][j-1]),
            dn[v][j]=Mul(dn[fa[v][j-1]][j-1],dn[v][j-1]);
        dfs(v);
    }
}
int lca(int u,int v,int &w){
    if (dp[u]<dp[v]) swap(u,v);
    while(dp[u]>dp[v]) w=1ll*w*b[u]%P,u=fa[u][0];
    while(u!=v) w=1ll*w*b[u]%P,
        w=1ll*w*b[v]%P,u=fa[u][0],v=fa[v][0];
    return u;
}
void Dp(O A){
    for (int i=0;i<=t;i++){
        f[Y][i]=0;
        for (int j=0;j<=t;j++) if (A.p[i][j])
            f[Y][i]=X(f[Y][i]+1ll*A.p[i][j]*f[Z][j]%P);
    }
    Z^=1;Y^=1;
}
int main(){
    cin>>n>>m>>q;
    for (int u,v,i=1;i<n;i++)
        scanf("%d%d%s",&u,&v,g[i]),
        add(u,v),add(v,u);t=0;
    for (int i=1;i<=m;i++)
        scanf("%s",h),ins();build();
    for (int i=0;i<26;i++)
        for (int k,j=0;j<=t;j++)
            if (!e[k=tr[j][i]]) d[i].p[k][j]++;
    dfs(1,0);fa[1][0]=1;c[1<<12]=12;
    for (int i=0;i<12;i++){
        su[0]=su[1]=0;
        for (int j=1;j<=n;j++)
            su[(dp[j]>>i)&1]++;
        if (su[0]>su[1]){
            for (int j=1;j<=n;j++)
                dp[j]+=(1<<i);
        }
        c[1<<i]=i;
    }
    for (int i=0;i<13;i++)
        for (int j=0;j<=t;j++)
            up[1][i].p[j][j]=dn[1][i].p[j][j]=1;
    for (int i=1;i<13;i++)
        for (int j=1;j<=n;j++)
            fa[j][i]=fa[fa[j][i-1]][i-1];
    dfs(1);
    for (int x,y,w,u,v,z,o;q--;){
        scanf("%d%d",&x,&y);
        Y=f[o=Z=0][0]=w=1;
        z=lca(x,y,w);
        while(x!=z){
            u=dp[x]&-dp[x];
            for (int i=u;i;i>>=1)
                if (dp[x]-i>=dp[z]){
                    S[++o]=up[x][c[i]];
                    x=fa[x][c[i]];break;
                }
        }
        v=o;
        while(y!=z){
            u=dp[y]&-dp[y];
            for (int i=u;i;i>>=1)
                if (dp[y]-i>=dp[z]){
                    S[++o]=dn[y][c[i]];
                    y=fa[y][c[i]];break;
                }
        }
        for (int i=1;i<=v;i++) Dp(S[i]);
        for (int i=o;i>v;i--) Dp(S[i]);
        for (int i=0;i<=t;i++)
            w=X(w+P-f[Z][i]),f[0][i]=f[1][i]=0;
        printf("%d\n",w);
    }
    return 0;
}

 

posted @ 2020-02-21 23:36  xjqxjq  阅读(164)  评论(0编辑  收藏  举报