poj2778 求构造长度为n的字符串不包含给定的m个字符串的个数(矩阵乘法+ac自动机)

题:http://poj.org/problem?id=2778

题意:给定m个模式串,问长度为n的字符串不包含这些模式串的有几种可能

分析:因为n很大,所以考虑矩阵ksm来解决,构造一个矩阵res[i][j]表示从i到j有多少种方案数,我们先考虑只走1步后的res数组的构造,i节点能走到j节点当且仅当i节点和j节点都是安全的点,这个安全的点就是用m个模式串构成的trie树上的end[],显然根结点是安全结点。 一个非根结点是危险结点的充要条件是: 它的路径字符串本身就是一个不良单词 ,或 它的路径字符串的后缀对应的结点(即fail[i])是危险结点。预处理完ac自动机后,就可以处理res数组,这个res数组就相当于在学离散数学时的矩阵;剩下的n步就交给ksm这个res数组即可,答案就是sum(res[0][i])

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<queue>
#include<cmath>
using namespace std;
typedef long long ll;
const int mod=1e5;
const int maxn=4;///只有4个字母
const int N=1e3+3;
struct ac{
    int trie[N][maxn],fail[N];
    int tot,root;
    bool end[N];
    ll res[N][N],ans[N][N],tmp[N][N];
    int newnode(){
        for(int i=0;i<maxn;i++){
            trie[tot][i]=-1;
        }
        end[tot++]=0;
        return tot-1;
    }
    void init(){
        tot=0;
        root=newnode();
        memset(res,0,sizeof(res));
        memset(ans,0,sizeof(ans));
        memset(end,false,sizeof(end));
    }
    int getid(char c){
        if(c=='A')
            return 0;
        if(c=='C')
            return 1;
        if(c=='T')
            return 2;
        if(c=='G')
            return 3;
    }
    void insert(char *buf,int id){
        int now=root,len=strlen(buf);
        for(int i=0;i<len;i++){
            int x=getid(buf[i]);
            if(trie[now][x]==-1)
                trie[now][x]=newnode();
            now=trie[now][x];
        }
        end[now]=true;//它的路径字符串本身就是一个不良单词 
    }
    void getfail(){
    
        queue<int>que;
        while(!que.empty())
            que.pop();
        fail[root]=root;
        for(int i=0;i<maxn;i++)
            if(trie[root][i]==-1)
                trie[root][i]=root;
            else{
                fail[trie[root][i]]=root;
                que.push(trie[root][i]);
            }
        while(!que.empty()){
            int now=que.front();
            que.pop();
            if(end[fail[now]])//它的路径字符串的后缀对应的结点(即fail[i])是危险结点 
                end[now]=true;
            for(int i=0;i<maxn;i++){
                if(trie[now][i]!=-1){
                    fail[trie[now][i]]=trie[fail[now]][i];
                    que.push(trie[now][i]);
                }
                else
                    trie[now][i]=trie[fail[now]][i];
            }
        }
    }
    
    void path(){
        for(int i=0;i<tot;i++){
            for(int j=0;j<maxn;j++)
                if(!end[i]&&!end[trie[i][j]]){
                //    cout<<i<<"!!"<<j<<endl;
                    res[i][trie[i][j]]++;
                }
                    
        }
    }
    void mul(ll a[][N],ll b[][N]){
        for(int i=0;i<tot;i++)
            for(int j=0;j<tot;j++){
                tmp[i][j]=0;
                for(int k=0;k<tot;k++)
                    tmp[i][j]=(tmp[i][j]+a[i][k]*b[k][j])%mod;
            }
        for(int i=0;i<tot;i++)
            for(int j=0;j<tot;j++)
                a[i][j]=tmp[i][j];
    }
    ll solve(int n){
        
        for(int i=0;i<tot;i++)
            ans[i][i]=1;
        while(n){
            if(n&1){
                mul(ans,res);
            }
            n>>=1;
            mul(res,res);
        }
        ll ANS=0;
        for(int i=0;i<tot;i++)
            ANS=(ANS+ans[0][i])%mod;
        return ANS;
    }
}AC;
char s[110];
int main(){
    int m,n;
    scanf("%d%d",&m,&n);
    AC.init();
    for(int i=1;i<=m;i++){
        scanf("%s",s);
        AC.insert(s,i);
    }
    AC.getfail();
    AC.path();
    printf("%lld\n",AC.solve(n));
    return 0;
}
View Code

 

posted @ 2020-02-17 16:01  starve_to_death  阅读(331)  评论(0编辑  收藏  举报