BZOJ1030 [JSOI2007]文本生成器 (AC自动机+DP)

题目大意:求长度为$M$的用$'A'-'Z'$组成的字符串中至少含有给出的N个模式串之一的个数。

传送门

似乎不那么裸的AC自动机题目都是在自动机上跑DP。。。

直接求解不好算,我们可以用$26^M$减去不含模式串的个数。

 


建好自动机,设f[i][j]表示走i步,现在在j号节点的路径条数。

那么f[i][j]可以转移f[i+1][son[j][k]]。
就是i+1个字符为k的状态。
最后把所有f[m][i]累和就是不可读的串。


 

这是网上题解的说法,大概就是这么个意思。

我写的是指针版本的trie,比较丑,看上面的解释就好啦。

 1 #include<cstring>
 2 #include<cstdio>
 3 #include<cmath>
 4 #include<algorithm>
 5 #define foru(i,x,y) for(int i=x;i<=y;i++)
 6 using namespace std;
 7 const int mod=10007;
 8 struct trie{
 9     int v,f[201];
10     trie *nxt[26],*fail;
11     void init(){v=0;fail=NULL;foru(i,0,25)nxt[i]=NULL;memset(f,0,sizeof(f));}
12 }*r,*q[10000],*ind[10000];
13 
14 int n,m,ans,ans2,cnt;
15 char ch[200];
16 
17 void add(char *s){
18     trie *k=r,*p;
19     int l=strlen(s);
20     foru(i,0,l-1){
21         int id=s[i]-'A';
22         if(!k->nxt[id]){
23             p=(trie*)malloc(sizeof(trie));
24             p->init();
25             k->nxt[id]=p;
26             k=p;
27             ind[++cnt]=p;
28         }else{
29             k=k->nxt[id];
30         }
31     }
32     k->v=1;
33 }
34 
35 void setfail(){
36     trie *k=r,*p;
37     int s=1,t=0;
38     q[++t]=r;
39     while(s<=t){
40         k=q[s++];
41         foru(i,0,25){
42             if(k->nxt[i]){
43                 p=k->fail;
44                 while(p&&!p->nxt[i])p=p->fail;
45                 k->nxt[i]->fail=(p?p->nxt[i]:r);
46                 q[++t]=k->nxt[i];
47             }else
48                 k->nxt[i]=(k!=r?k->fail->nxt[i]:r);
49         }
50         if(k!=r)k->v|=k->fail->v;//如果一个串的后缀可以被理解,那么这个串也是可以被理解的 
51     }
52 }
53 
54 void work(){
55     r->f[0]=1;
56     foru(i,1,m){
57         foru(j,0,cnt){
58             if(ind[j]->v)continue;
59             foru(k,0,25){
60                 ind[j]->nxt[k]->f[i]+=ind[j]->f[i-1];
61                 ind[j]->nxt[k]->f[i]%=mod;
62             }
63         }
64     }
65     foru(i,0,cnt)
66         if(!ind[i]->v){
67             ans2+=ind[i]->f[m];
68             ans2%=mod;
69         }
70     ans=1;
71     foru(i,1,m)ans=(ans*26)%mod;
72     printf("%d\n",((ans-ans2)%mod+mod)%mod);
73 }
74 
75 int main(){
76     r=(trie*)malloc(sizeof(trie));
77     r->init();
78     ind[0]=r;
79     scanf("%d%d",&n,&m);
80     foru(i,1,n){
81         scanf("%s",ch);
82         add(ch);
83     }
84     setfail();
85     work();
86     return 0;
87 }

 

posted @ 2017-09-12 13:35  羊毛羊  阅读(140)  评论(0编辑  收藏