Aho-Corasick automaton(AC自动机)解析及其在算法竞赛中的典型应用举例

摘要:

  本文主要讲述了AC自动机的基本思想和实现原理,如何构造AC自动机,着重讲解AC自动机在算法竞赛中的一些典型应用。

  • 什么是AC自动机?
  • 如何构造一个AC自动机?
  • AC自动机在算法竞赛中的典型应用有哪些?
  • 例题解析

什么是AC自动机?

  什么是AC自动机,不是自动AC的机器(想的美),而是一种多模匹配算法,英文名称Aho-Corasick automaton(前面的一串据说是一位科学家的名字),于1975年诞生于贝尔实验室。

  回忆之前的KMP算法解决的一类问题是给出一个模板和一个文本串,问这一个模板在该文本串中的存在情况(包括是否存在、存在几次、哪些位置等等)。现在如果是多个模板呢?可能你会想到一个一个拿出来用KMP算法进行匹配,但是如果文本串很长,模板又很多的话,KMP算法就不适合了(不满足于能解决问题,而追求又快又好的解决问题是算法研究的源动力)。而AC自动机正是为了解决这类问题而生的。

  基本思想

  不得不重提的是KMP算法之所以能够在高效的处理单模匹配问题,主要得益于next数组的建立,能够使匹配的状态在线性的字符串上进行转移,使得失配后副串能够尽可能的“滑的远一些“。而AC自动机也有类似功能的工具那就是fail指针。

  应该能想到的是单模匹配的KMP算法的状态转移图是线性的字符串加上失配边组成的,那么多模匹配的AC自动机算法的状态转移图是字典树加上失配边组成的。

  为了说明实际问题,直接看一个例子如下:

  

  问题很明确,我们需要只遍历一遍文本串就找出所有单词表中存在的单词(只遍历一遍的想法和KMP算法有异曲同工之妙)。

  我们先根据字符集合{she,he,say,shr,her}建立字典树如上图所示,然后我们拿着yasherhs去匹配,发现前两个字符无法匹配,跳过,第三个字符开始,she可以匹配,记录下来,继续往后走发现没有匹配了,结果就是该文本串只存在一个单词,很明显,答案是错的,因为存在she、he、her三个单词。

  可以发现的是使用文本串在字典树上进行匹配的时候,找到了一个单词结点后还应该看看有没有以该单词结点的后缀为前缀的其他单词,比如she的后缀he是单词he和her的前缀。因此就需要一个fail指针在发现失配的时候指向其他存在e的结点,来“安排”之后应该怎么办。

  总的来说,AC自动机中fail指针和KMP中next数组的作用是一致的,就是要想在只遍历一遍文本串的前提下,找到全部匹配模板,就必须安排好匹配过程中失配后怎么办。具体如何安排就是怎么在字典树上加失配边的问题了(也即如何构造一个AC自动机)。

如何构造一个AC自动机?

  字典树之前已经学过了(需要回顾的请点这里),关键是怎么加失配边。规则如下:

  • 根结点的fail指针为空(或者它自己);
  • 直接和根结点相连的结点,如果这些结点失配,就只能重新开始匹配,故它们的fail指针指向根结点;
  • 其他结点,设当前结点为father,其孩子结点为child。要寻找child的fail指针,需要看father的fail指针指向的结点,假设是tmp,要看tmp的孩子中有没有和child所代表的字符相同的,有则child的fail指针指向tmp的这个孩子结点,没有则继续沿着tmp的fail指针往上走,如果找到相同,就指向,如果一直找到了根结点的fail也就是空的时候,child的fail指针就指向root,表示重新从根结点开始匹配。

  其中考察father的fail指针指向的结点 有没有和child相同的结点,包括继续往上,就保证了前缀是相同的,比如刚才寻找右侧h的孩子结点e的fail指针时,找到右侧h的fail指针指向左侧的h结点,他的孩子中有e,就将右侧h的孩子e的fail指针指向它就保证了前缀h是相同的。

  这样,就用fail指针来安排好每次失配后应该跳到哪里,而fail指针跳到哪里,说明从根结点到这个结点之前的字符串已经匹配过了,从而避免了重复匹配,也就完美的解决了只遍历一次文本串就找出所有单词的问题

  具体编程实现在字典树上添加失配边有两种方法,一种是链表法,一种是转移矩阵法。

    链表法

  有了上面fail指针的计算规则,利用队列BFS顺序递推可以写出如下代码:

 1 const int maxw = 10010;      //最大单词数 
 2 const int maxwl = 61;        //最大单词长度 
 3 const int maxl = 1001000;    //最大文本长度 
 4 const int sigm_size = 26;    //字符集大小 
 5 
 6 struct Node {
 7     int sum;//>0表示以该结点为前缀的单词个数,=0表示不是单词结点,=-1表示已经经过计数 
 8     Node* chld[sigm_size];
 9     Node* fail;
10     Node() {
11         sum = 0;
12         memset(chld, 0, sizeof(chld));
13         fail = 0;
14     }
15 };
16 struct ac_automaton {
17     Node* root;
18     void init() {
19         root = new Node;
20     }    
21     int idx(char c) {
22         return c - 'a';
23     }
24     void insert(char *s) {
25         Node* u = root; 
26         for(int i = 0; i < s[i]; i++) {
27             int c = idx(s[i]);
28             if(u->chld[c] == NULL) 
29                 u->chld[c] = new Node;
30             
31             u = u->chld[c];
32         }
33         u->sum++;//以该串为前缀的单词个数++ 
34     }
35     void getfail() {
36         queue<Node*> q;
37         q.push(root);//根结点的fail指针为空 
38         while(!q.empty()) {
39             Node* u = q.front();
40             q.pop();
41             for(int i = 0; i < sigm_size; i++) {                //寻找当前结点的所有非空子结点的fail指针 
42                 if(u->chld[i] != NULL) {
43                     if(u == root)//根结点 
44                         u->chld[i]->fail = root;
45                     else {       //非根节点 
46                         Node* tmp = u->fail;                    //找到它父亲的fail指针指向的结点 
47                         while(tmp != NULL) {                    //向上只有根结点的fail指针是空,所以只要不是根结点就循环 
48                             if(tmp->chld[i] != NULL) {          //直到发现存在一个结点的子结点与其相同 
49                                 u->chld[i]->fail = tmp->chld[i];//就将它的fail指针指向该子结点然后结束循环 
50                                 break;
51                             }
52                             tmp = tmp->fail;//否则一直往上找 
53                         }
54                         if(tmp == NULL)     //如果寻找到根结点还没有找到,就指向根结点,让主串从根结点重新开始匹配 
55                             u->chld[i]->fail = root;
56                     }
57                     q.push(u->chld[i]);     //子结点入队 
58                 }
59             }
60         }
61     }
62     int query(char *t) {
63         
64         int cnt = 0;//文本中存在单词的个数 
65         Node* u = root; 
66         for(int i = 0; t[i]; i++) {//yasherhs 
67             int c = idx(t[i]);
68             while(u != root && u->chld[c] == NULL)//不是根结点而且不匹配,顺着fail指针走,直到可以匹配或者走到根结点 
69                 u = u->fail;
70             
71             u = u->chld[c];  //经过上面的循环,u要么是匹配结点要么是根结点,继续往下走 
72             if(u == NULL)    //如果结点为空,下一个字符重新从根结点开始 
73                 u = root;
74             
75             Node* tmp = u;
76             while(tmp != root) {    //只要没有返回到根结点,就证明在字典树上还存在找到单词的可能 
77                 if(tmp->sum > 0) {
78                     cnt += tmp->sum;//单词计数器加上以当前结点为前缀的单词数 
79                     tmp->sum = -1;  //表示该单词结点已经计过数,防止重复计数 
80                 } 
81                 else                //该单词结点已经匹配过了直接退出,因为后面的状态转移是确定的并且是走过的 
82                     break;
83                 tmp = tmp->fail;    //往其他子树上找 
84             } 
85         }
86         return cnt;
87     }
88 };

  上面的代码中在调用getfail方法之后就构造好了一个AC自动机,具体查询的时候就需要在字典树的状态转移图上进行匹配了。

  具体的匹配过程可分为两种情况:

  1.当前字符匹配,就沿着它的状态转移图往上找,找到单词结点就统计,直到返回到根结点,说明不存在其他单词。

  2.当前字符不匹配,就沿着它的fail指针往上走,直到找到匹配再进入while循环统计单词,或者一直到不到匹配直接跳过。

  如此两种情况交替,直到将文本串遍历完,也就完成了统计。

   

    用上图中的例子来说,统计yasherhs中几个单词表中的单词。

   当i=0,1时,由于Trie中没有对应的路径,故直接跳过;i=2,3,4时,指针u指向右下节点e。因为节点e的sum为1,所以cnt += sum,并将节点e的sum值置为-1,表示该单词已经出现过,避免重复计数,然后tmp指向e节点的失败指针所指向的节点左下e,发现是单词结点cnt  += sum,最后tmp指向root,退出while循环,这个过程中cnt增加了2,表示找到了2个单词she和he。

   当i=5时,u上次指向的是右下e,r不匹配,u指向u的fail指针指向的结点左下e,发现匹配r,u指向左下r,进入下面的while循环,因为左下r的sum为1,所以cnt += sum,表示发现了单词her;

   最后当i=6,7时,找不到任何匹配,查询过程结束(强烈建议手动模拟一下)。

   链表法可以将原理实现直观的转化成代码,不过更常见的是实现起来较为简洁也更为巧妙的转移矩阵法。

    转移矩阵法

    有了转移矩阵建立字典树的基础,然后在字典树上加失配边,代码如下:

 1 struct ac_automaton {
 2     int ch[maxnode][sigm_size];//一个结点对应一个字符集
 3     int fail[maxnode];            //每个结点的fail指针
 4     int val[maxnode];           //每个结点的权值
 5     int sz; 
 6     
 7     void init() {
 8         sz = 0;
 9         newnode();               //创建一个根结点 
10     }
11     int newnode() {
12         memset(ch[sz], -1, sizeof(ch[sz]));
13         val[sz] = 0;
14         return sz++;
15     }
16     int idx(char c) {
17         return c - 'a'; 
18     }
19     void insert(char *s) {
20         int u = 0;
21         for(int i = 0; s[i]; i++) {
22             int c = idx(s[i]);
23             if(ch[u][c] == -1)
24                 ch[u][c] = newnode();
25             
26             u = ch[u][c];
27         }
28         val[u]++;
29     }
30     void getfail() {
31         queue<int> q;
32         fail[0] = 0;            //根结点的fail指针指向它自己也就是空 
33         for(int i = 0; i < sigm_size; i++) {
34             int u = ch[0][i];
35             if(u == -1){        //根结点编号为i的结点不存在时
36                 ch[0][i] = 0;    //把不存在的边补上,将其标记为0
37             }
38             else {                //存在时 
39                 fail[u] = 0;    //失配指针指向根结点并入队 
40                 q.push(u);
41             }
42         }
43         while(!q.empty()) {
44             int u =q.front();
45             q.pop();
46             for(int i = 0; i < sigm_size; i++) { //寻找当前结点u的孩子结点的fail指针 
47                 int tmp = ch[u][i];
48                 if(tmp == -1) 
49                     ch[u][i] = ch[fail[u]][i];    //把不存在的边补上,当前结点u不存在编号为i的孩子时,
50                                  //让它指向当前结点u的fail指针指向的结点对应编号为i的孩子中存的结点编号 
51                 else {
52                     //当前孩子结点的fail指针指向 当前结点u的fail指针指向的结点对应的孩子的编号 
53                     fail[tmp] = ch[fail[u]][i];   
54                     q.push(tmp);
55                 }
56             }
57         }
58     }
59     int query(char *t) {
60         int u = 0, cnt = 0;
61         for(int i = 0; t[i]; i++) {
62             int c = idx(t[i]);
63             u = ch[u][c];      //由于之前把边补齐了,所以可以直接往下走,有匹配直接就是结点,没有匹配直接是根结点 
64              
65             int tmp = u;
66             while(tmp != 0) {  //只要不是根结点,就证明有存在继续找到单词的可能 
67                 cnt += val[tmp];
68                 val[tmp] = 0;
69                 
70                 tmp = fail[tmp];
71             }
72         }
73         return cnt;
74     } 
75 };

  之所以说实现起来较为简单,是因为使用了二维数组,不用指针指来指去;而说更为巧妙是因为当一个结点u不存在哪个编号为i的结点时 就填充为u的fail指针指向的结点对应编号为i的结点编号,如此一来查询的时候就可以直接往下走,而不是需要进入一个循环找到匹配或者根结点再继续。

   这个是根据ACM大佬bin神的AC自动机小结中学来的,仔细体会有种DP的思想在里面。

AC自动机在算法竞赛中的典型应用有哪些?

  基本的问题是给出单词表,给一段文本串,问单词表中的单词存在于文本串中的情况。

    1、存在的单词个数 HDU 2222 Keywords Search

    2、输出存在的单词的编号 HDU 2896 病毒入侵

    3、输出存在单词及其个数 HDU 3065 病毒持续入侵中

    4、单词重叠和不重叠的个数 ZOJ 3288 Searching the String

      5、在二维矩阵中查找小的二维矩阵 UVa 11019 矩阵适配器

  复杂的问题有和DP结合起来的,有和大数结合起来的,有和最短路结合起来的

    1、修改最少次数使得文本串中不包含任何一个模板 HDU 2457 DNA repair

    2、给定n个文本串,m个病毒串,文本串重叠部分可以合并,但合并后不能含有病毒串,问所有文本串合并后最短多长  HDU 3247 Resource Archiver

    3、AC自动机+DP+高精度 POJ 1625 Censored!

例题解析

  HDU 2222 Keywords Search AC自动机入门题,给出单词表和一个文本串,问文本串中有多少个单词表中的单词。首先根据单词表构建字典树,每个单词结点的末尾++,构造AC自动机,匹配文本串统计即可,注意不要忘了将统计过的单词标记一下。

  为了体会AC自动机的基本思想最好两种构建方法都试一下。参考代码如下:

链表法

  1 #include <cstdio>
  2 #include <queue>
  3 #include <cstring> 
  4 using namespace std;
  5 
  6 const int maxw = 10010;        //最大单词数 
  7 const int maxwl = 61;        //最大单词长度 
  8 const int maxl = 1001000;    //最大文本长度 
  9 const int sigm_size = 26;    //字符集大小 
 10 
 11 struct Node {
 12     int sum;//>0表示以该结点为前缀的单词个数,=0表示不是单词结点,=-1表示已经经过计数 
 13     Node* chld[sigm_size];
 14     Node* fail;
 15     Node() {
 16         sum = 0;
 17         memset(chld, 0, sizeof(chld));
 18         fail = 0;
 19     }
 20 };
 21 struct ac_automaton {
 22     Node* root;
 23     void init() {
 24         root = new Node;
 25     }    
 26     int idx(char c) {
 27         return c - 'a';
 28     }
 29     void insert(char *s) {
 30         Node* u = root; 
 31         for(int i = 0; i < s[i]; i++) {
 32             int c = idx(s[i]);
 33             if(u->chld[c] == NULL) 
 34                 u->chld[c] = new Node;
 35             
 36             u = u->chld[c];
 37         }
 38         u->sum++;//以该串为前缀的单词个数++ 
 39     }
 40     void getfail() {
 41         queue<Node*> q;
 42         q.push(root);//根结点的fail指针为空 
 43         while(!q.empty()) {
 44             Node* u = q.front();
 45             q.pop();
 46             for(int i = 0; i < sigm_size; i++) {                //寻找当前结点的所有非空子结点的fail指针 
 47                 if(u->chld[i] != NULL) {
 48                     if(u == root)//根结点 
 49                         u->chld[i]->fail = root;
 50                     else {       //非根节点 
 51                         Node* tmp = u->fail;                    //找到它父亲的fail指针指向的结点 
 52                         while(tmp != NULL) {                    //向上只有根结点的fail指针是空,所以只要不是根结点就循环 
 53                             if(tmp->chld[i] != NULL) {          //直到发现存在一个结点的子结点与其相同 
 54                                 u->chld[i]->fail = tmp->chld[i];//就将它的fail指针指向该子结点然后结束循环 
 55                                 break;
 56                             }
 57                             tmp = tmp->fail;//否则一直往上找 
 58                         }
 59                         if(tmp == NULL)     //如果寻找到根结点还没有找到,就指向根结点,让主串从根结点重新开始匹配 
 60                             u->chld[i]->fail = root;
 61                     }
 62                     q.push(u->chld[i]);     //子结点入队 
 63                 }
 64             }
 65         }
 66     }
 67     int query(char *t) {
 68         int cnt = 0;//文本中存在单词的个数 
 69         Node* u = root; 
 70         for(int i = 0; t[i]; i++) {//yasherhs 
 71             int c = idx(t[i]);
 72             while(u != root && u->chld[c] == NULL)//不是根结点而且不匹配,顺着fail指针走,直到可以匹配或者走到根结点 
 73                 u = u->fail;
 74             
 75             u = u->chld[c]; //经过上面的循环,u要么是匹配结点要么是根结点,继续往下走 
 76             if(u == NULL)    //如果结点为空,下一个字符重新从根结点开始 
 77                 u = root;
 78             
 79             Node* tmp = u;
 80             while(tmp != root) {    //只要没有返回到根结点,就证明在字典树上还存在找到单词的可能 
 81                 if(tmp->sum > 0) {
 82                     cnt += tmp->sum;//单词计数器加上以当前结点为前缀的单词数 
 83                     tmp->sum = -1;  //表示该单词结点已经计过数,防止重复计数 
 84                 } 
 85                 else                //该单词结点已经匹配过了直接退出,因为后面的状态转移是确定的并且是走过的 
 86                     break;
 87                 tmp = tmp->fail;    //往其他子树上找 
 88             } 
 89         }
 90         return cnt;
 91     }
 92 };
 93 
 94 ac_automaton ac;
 95 char txt[maxl];
 96 int main() 
 97 {
 98     int T,n;
 99     char word[maxwl];
100     scanf("%d", &T);
101     while(T--) {
102         ac.init();
103         scanf("%d", &n);
104         for(int i = 0; i < n; i++) {
105             scanf("%s", word);
106             ac.insert(word);
107         }
108         ac.getfail();
109         
110         scanf("%s", txt);
111         printf("%d\n", ac.query(txt));
112     }
113     return 0;
114 } 
View Code

转移矩阵法

  1 #include <cstdio>
  2 #include <cstring> 
  3 #include <queue>
  4 using namespace std;
  5 
  6 const int maxwl = 61;
  7 const int maxw = 10010;
  8 const int maxl = 1001000;
  9 const int sigm_size = 26;
 10 const int maxnode = maxw * maxwl;
 11 
 12 struct ac_automaton {
 13     int ch[maxnode][sigm_size];//一个结点对应一个字符集
 14     int fail[maxnode];            //每个结点的fail指针
 15     int val[maxnode];           //每个结点的权值
 16     int root, sz; 
 17     
 18     void init() {
 19         sz = 0;
 20         root = newnode();               //创建一个根结点 
 21     }
 22     int newnode() {
 23         memset(ch[sz], -1, sizeof(ch[sz]));
 24         val[sz] = 0;
 25         return sz++;
 26     }
 27     int idx(char c) {
 28         return c - 'a'; 
 29     }
 30     void insert(char *s) {
 31         int u = root;
 32         for(int i = 0; s[i]; i++) {
 33             int c = idx(s[i]);
 34             if(ch[u][c] == -1)
 35                 ch[u][c] = newnode();
 36             
 37             u = ch[u][c];
 38         }
 39         val[u]++;
 40     }
 41     void getfail() {
 42         queue<int> q;
 43         fail[root] = root;            //根结点的fail指针指向它自己也就是空 
 44         for(int i = 0; i < sigm_size; i++) {
 45             int u = ch[root][i];
 46             if(u == -1){        //根结点编号为i的结点不存在时
 47                 ch[root][i] = root;    //把不存在的边补上,将其标记为0
 48             }
 49             else {                //存在时 
 50                 fail[u] = root;    //失配指针指向根结点并入队 
 51                 q.push(u);
 52             }
 53         }
 54         while(!q.empty()) {
 55             int u =q.front();
 56             q.pop();
 57             for(int i = 0; i < sigm_size; i++) { //寻找当前结点u的孩子结点的fail指针 
 58                 int tmp = ch[u][i];
 59                 if(tmp == -1) 
 60                     ch[u][i] = ch[fail[u]][i];    //把不存在的边补上,当前结点u不存在编号为i的孩子时,
 61                                                 //让它指向当前结点u的fail指针指向的结点对应编号为i的孩子中存的结点编号 
 62                 else {
 63                     //当前孩子结点的fail指针指向 当前结点u的fail指针指向的结点对应的孩子的编号 
 64                     fail[tmp] = ch[fail[u]][i];   
 65                     q.push(tmp);
 66                 }
 67             }
 68         }
 69     }
 70     int query(char *t) {
 71         int u = root, cnt = 0;
 72         for(int i = 0; t[i]; i++) {
 73             int c = idx(t[i]);
 74             u = ch[u][c];      //由于之前把边补齐了,所以可以直接往下走,有匹配直接就是结点,没有匹配直接是根结点 
 75              
 76             int tmp = u;
 77             while(tmp != 0) { //只要不是根结点,就证明有存在继续找到单词的可能 
 78                 cnt += val[tmp];
 79                 val[tmp] = 0;
 80                 
 81                 tmp = fail[tmp];
 82             }
 83         }
 84         return cnt;
 85     } 
 86 };
 87 
 88 ac_automaton ac; 
 89 char txt[maxl];
 90 int main() 
 91 {
 92     int n, m;
 93     char word[maxwl];
 94     scanf("%d", &n);
 95     while(n--) {
 96         scanf("%d", &m);
 97         ac.init();
 98         for(int i = 0; i < m; i++) {
 99             scanf("%s", word);
100             ac.insert(word);
101         }
102         ac.getfail();
103         
104         scanf("%s", txt);
105         printf("%d\n", ac.query(txt));
106     }
107     return 0;
108 } 
View Code

  HDU 2896 病毒侵袭 给出病毒和多个文本串,输出每个文本串中存在病毒的编号。

  想到怎么记录编号和注意输出格式就没什么大问题。需要知道的是ASCII可见字符是32到126,共95个可见字符。参考代码如下(链表法,请使用C++提交,G++结果MLE,可能G++和C++的内存分配机制不同):

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <queue>
  4 #include <vector>
  5 #include <algorithm> 
  6 using namespace std;
  7 
  8 const int maxwl = 210;
  9 const int maxl = 10010;
 10 const int sigm_size = 128 - 33;
 11 
 12 struct Node {
 13     int num;
 14     Node* fail;
 15     Node* chld[sigm_size];
 16     Node() {
 17         num = 0; 
 18         fail = 0;
 19         memset(chld, 0, sizeof(chld));
 20     }
 21 };
 22 struct ac_automaton {
 23     Node* root;
 24     void init() { 
 25         root = new Node;
 26     }
 27     int idx(char c) {
 28         return c - 32;
 29     }
 30     void insert(char *s, int v) {
 31         Node* u = root;
 32         for(int i = 0; s[i]; i++) {
 33             int c = idx(s[i]);
 34             if(u->chld[c] == NULL)
 35                 u->chld[c] = new Node;
 36             
 37             u = u->chld[c];
 38         }
 39         u->num = v;
 40     }
 41     void getfail() {
 42         queue<Node*> q;
 43         q.push(root);
 44         while(!q.empty()) {
 45             Node* u = q.front();
 46             q.pop();
 47             for(int i = 0; i < sigm_size; i++) {
 48                 if(u->chld[i] != NULL) {
 49                     if(u == root)
 50                         u->chld[i]->fail = root;
 51                     else {
 52                         Node* tmp = u->fail;
 53                         while(tmp != NULL) {
 54                             if(tmp->chld[i] != NULL) {
 55                                 u->chld[i]->fail = tmp->chld[i];
 56                                 break;
 57                             }
 58                             tmp = tmp->fail;
 59                         }
 60                         if(tmp == NULL)
 61                             u->chld[i]->fail = root;
 62                     }
 63                     q.push(u->chld[i]); 
 64                 }
 65             } 
 66         }
 67     }
 68     void query(char *t, vector<int> &p) {
 69         Node* u = root;
 70         for(int i = 0; t[i]; i++) {
 71             int c = idx(t[i]);
 72             while(u != root && u->chld[c] == NULL)
 73                 u = u->fail;
 74                 
 75             u = u->chld[c];
 76             if(u == NULL)
 77                 u = root;
 78             
 79             Node* tmp = u;
 80             while(tmp != root) {
 81                 if(tmp->num > 0) 
 82                     p.push_back(tmp->num);//记录存在的病毒编号 
 83                 
 84                 tmp = tmp->fail;
 85             }
 86         }
 87     }
 88 }ac;
 89 
 90 int main()
 91 {
 92     int n, m;
 93     char word[maxwl], txt[maxl];
 94     while(scanf("%d", &n) != EOF) {
 95         ac.init();
 96         for(int i = 0; i < n; i++) {
 97             scanf("%s", word);
 98             ac.insert(word, i+1);
 99         }
100         ac.getfail(); 
101         
102         scanf("%d", &m);
103         int tot = 0;
104         for(int i = 0; i < m; i++) {
105             scanf("%s", txt);
106             vector<int> p;
107             ac.query(txt, p);
108             if(!p.empty()) {
109                 sort(p.begin(), p.end());
110                 printf("web %d:",i+1);
111                 for(int i = 0; i < p.size(); i++) 
112                     printf(" %d", p[i]);
113                 puts("");
114                 tot++;
115             }
116         }
117         printf("total: %d\n", tot);
118     }
119     return 0;
120 }
View Code

  HDU 3065 病毒侵袭持续中 给出病毒和文本串,输出每个病毒及其存在的次数。

  和上一题很像,注意使用链表法写的时候多样例要释放内存,否则可能会超内存,但是转移矩阵就不会,因此优先选择转移矩阵实现。链表法参考如下(如何递归释放内存):

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <queue>
  4 using namespace std;
  5 
  6 const int maxw = 1010;
  7 const int maxwl = 61;
  8 const int maxl = 2000010;
  9 const int sigm_size = 128 -33;
 10 
 11 char words[maxw][maxwl];
 12 char txt[maxl];
 13 
 14 struct Node {
 15     int flag;
 16     Node* fail;
 17     Node* chld[sigm_size];
 18     Node() {
 19         flag = 0;
 20         fail = 0;
 21         memset(chld, 0, sizeof(chld));
 22     }
 23 };
 24 struct ac_automaton {
 25     Node* root;
 26     int num[maxw];
 27     void init() {
 28         root = new Node;
 29     }
 30     int idx(char c) {
 31         return c - 32;
 32     }
 33     void insert(char *s, int v) {
 34         Node* u = root;
 35         for(int i = 0; s[i]; i++) {
 36             int c = idx(s[i]);
 37             if(u->chld[c] == NULL)
 38                 u->chld[c] = new Node;
 39             u = u->chld[c];
 40         }
 41         u->flag = v;
 42     }
 43     void getfail() {
 44         queue<Node*> q;
 45         q.push(root);
 46         while(!q.empty()) {
 47             Node* u = q.front();
 48             q.pop();
 49             for(int i = 0; i < sigm_size; i++) {
 50                 if(u->chld[i] != NULL) {
 51                     if(u == root)
 52                         u->chld[i]->fail = root;
 53                     else {
 54                         Node* tmp = u->fail;
 55                         while(tmp != NULL) {
 56                             if(tmp->chld[i] != NULL) {
 57                                 u->chld[i]->fail = tmp->chld[i];
 58                                 break;
 59                             }
 60                             tmp = tmp->fail;
 61                         }
 62                         if(tmp == NULL)
 63                             u->chld[i]->fail = root;
 64                     }    
 65                     q.push(u->chld[i]);
 66                 }
 67             }
 68         }
 69     }
 70     void query(char *t, int n) {
 71         Node* u = root;
 72         memset(num, 0, sizeof(num));
 73         for(int i = 0; t[i]; i++) {
 74             int c = idx(t[i]);
 75             while(u != root && u->chld[c] == NULL)
 76                 u = u->fail;
 77             
 78             u = u->chld[c];
 79             if(u == NULL)
 80                 u = root;
 81                 
 82             Node* tmp = u;
 83             while(tmp != root) {
 84                 if(tmp->flag > 0)
 85                     num[tmp->flag]++;
 86                 
 87                 tmp = tmp->fail;
 88             }
 89         }
 90         for(int i = 1; i <= n; i++) {
 91             if(num[i] > 0)
 92                 printf("%s: %d\n", words[i], num[i]);
 93         }
 94     }
 95     void freenode(Node* u) {
 96         if(u == NULL)
 97             return;
 98         for(int i = 0; i < sigm_size; i++) 
 99             freenode(u->chld[i]);
100         delete u;
101     }
102 }ac;
103 
104 int main() 
105 {
106     int n;
107     char word[maxwl];
108     while(scanf("%d", &n) != EOF) {
109         ac.init();
110         for(int i = 1; i <= n; i++) {
111             scanf("%s", words[i]);
112             ac.insert(words[i], i);
113         }
114         ac.getfail();
115         
116         scanf("%s", txt);
117         ac.query(txt, n);
118         ac.freenode(ac.root);//多样例测试时别忘记释放内存 
119     }
120     return 0;
121 }
View Code

  ZOJ 3228 Searching the String 先给出文本串,再给出多个单词,但询问方式不同,0表示可以重叠存在的次数,1表示不可重叠存在的次数。

  重叠的询问好求,一遍AC自动机解决,关键是不可重叠次数。设想,如果我们能够记录一个单词上一次在文本串中的匹配位置,那么当前单词结点的末尾在文本串中的位置 - 当前单词结点在文本串中上一次匹配的位置  大于等于以当前字符结尾的单词结点的长度时,表示不重叠。可以使用一个二维数组记录每个单词两种询问的答案,最后查询输出。参考代码如下:

  1 #include <cstdio>
  2 #include <queue>
  3 #include <cstring> 
  4 using namespace std;
  5 
  6 const int sigm_size = 26;                //字符集的大小 
  7 const int maxl = 100010;                //文本串的长度 
  8 const int maxw = 10010;                    //单词个数 
  9 const int maxwl = 7;                    //单词长度 
 10 const int maxnode = maxw * maxwl * 10;    //字典树结点数 = 单词数乘以单词长度乘以10 
 11 
 12 char txt[maxl];
 13 int node[maxl];                        //记录每个单词在字典树中单词结点的编号 
 14 int op[maxl];                        //每个单词的查询方式
 15                          
 16 struct ac_automaton {
 17     int ch[maxnode][sigm_size], fail[maxnode];
 18     int pos[maxnode];    //记录以当前字符结尾的单词的长度 
 19     int L, root;
 20     void init() {
 21         L = 0;
 22         root = newnode();
 23     } 
 24     int newnode() {
 25         memset(ch[L], -1, sizeof(ch[L]));
 26         pos[L++] = 0;    //以当前字符结尾的单词长度为0 
 27         return L-1; 
 28     }
 29     int idx(char c) {
 30         return c - 'a';
 31     }
 32     void insert(char *s, int v) {
 33         int now = root;
 34         for(int i = 0; s[i]; i++) {
 35             int c = idx(s[i]);
 36             if(ch[now][c] == -1)
 37                 ch[now][c] = newnode();
 38             now = ch[now][c];
 39             pos[now] = i+1;//以当前字符结尾的单词的长度 
 40         }
 41         node[v] = now;//编号为v的模式串在字典树中的序号 
 42     } 
 43     void getfail() {
 44         queue<int> q;
 45         fail[root] = root;
 46         for(int i = 0; i < sigm_size; i++) {
 47             if(ch[root][i] == -1)
 48                 ch[root][i] = root;
 49             else {
 50                 fail[ch[root][i]] = root;
 51                 q.push(ch[root][i]);
 52             }
 53         }
 54         while(!q.empty()) {
 55             int now = q.front();
 56             q.pop();
 57             for(int i = 0; i < sigm_size; i++) {
 58                 if(ch[now][i] == -1) 
 59                     ch[now][i] = ch[fail[now]][i];
 60                 else {
 61                     fail[ch[now][i]] = ch[fail[now]][i];
 62                     q.push(ch[now][i]);
 63                 }
 64             }
 65         } 
 66     }
 67     int ans[maxnode][2];                //标号为i的单词的重叠和不重叠的个数 
 68     int last[maxnode];                    //记录当前单词结点在文本串中的上一个匹配位置
 69     void query(char *t) {
 70         memset(last, -1, sizeof(last));
 71         memset(ans, 0, sizeof(ans));
 72         int now = root;
 73         for(int i = 0; t[i]; i++) {
 74             int c = idx(t[i]);
 75             now = ch[now][c];
 76             int tmp = now;
 77             while(tmp != root) {
 78                 ans[tmp][0] ++;
 79                 /*
 80                     当前字符的位置 - 当前单词结点在文本串中上一次匹配的位置 
 81                     大于等于以当前字符结尾的单词结点的长度时,表示不重叠 
 82                 */ 
 83                 if(i - last[tmp] >= pos[tmp]) { 
 84                     ans[tmp][1] ++;
 85                     last[tmp] = i;//记录当前单词结点在文本串中的位置 
 86                 }
 87                 tmp = fail[tmp];
 88             }
 89         } 
 90     }
 91 }ac;
 92 
 93 int main() 
 94 {
 95     int n, kase = 1;
 96     char word[maxwl];
 97     while(scanf("%s", txt) != EOF) {
 98         scanf("%d", &n);
 99         ac.init();
100         for(int i = 0; i < n; i++) {
101             scanf("%d%s", &op[i], word);
102             ac.insert(word, i); 
103         }
104         ac.getfail();
105         ac.query(txt);
106         
107         printf("Case %d\n", kase++);
108         for(int i = 0; i < n; i++) {
109             printf("%d\n", ac.ans[node[i]][op[i]]);
110         }
111         puts("");
112     }    
113     return 0;
114 }
View Code

  UVA 11019 Matrix Matcher AC自动机应用的二维推广。给出一个大的二维字符矩阵T,一个小的二维矩阵P,问P在T中存在的次数。

  思路很简单,使用一个二维矩阵cnt,如果cnt[r][c]表示T中以(r,c)为左上角、与P等大的矩形有多少个完整的行和P对应位置的行完全相同。当P的第j行出现在T的第r行、起始列号为i时,意味着cnt[r-j+1][i-y+2]++,其中具体加几和存储的起始位置有关,按照自己的规则即可。所有匹配结束后,那些cnt[r][c] == x(P的行数)的点就是一个二维匹配点。

  另外需要注意的是P中可能存在重复,存在重复的模板会导致字典树中结点编号覆盖,所以使用一个vector数组保存所有的编号。参考代码如下:

  1 #include <vector>
  2 #include <cstdio>
  3 #include <queue>
  4 #include <cstring> 
  5 using namespace std;
  6 
  7 const int maxn = 1100;
  8 const int maxw = 110;
  9 const int maxwl = 110;
 10 const int maxnode = maxw * maxwl;
 11 const int sigm_size = 26;
 12 
 13 struct ac_automaton {
 14     int cnt[maxn][maxn];
 15     int ch[maxnode][sigm_size];
 16     int fail[maxnode];
 17     vector<int> val[maxnode];  
 18     int sz, root;
 19     
 20     void init() {
 21         sz = 0;
 22         root = newnode();
 23         memset(cnt, 0, sizeof(cnt));
 24     }
 25     int newnode() {
 26         memset(ch[sz], -1, sizeof(ch[sz]));
 27         val[sz].clear();
 28         return sz++;
 29     }
 30     int idx(char c) {
 31         return c - 'a';
 32     }
 33     void insert(char *s, int v) {
 34         int u = root;
 35         for(int i = 0; s[i]; i++) {
 36             int c = idx(s[i]);
 37             if(ch[u][c] == -1)
 38                 ch[u][c] = newnode();
 39             
 40             u = ch[u][c];
 41         }
 42         val[u].push_back(v);//以该结点为末尾的p的行编号 
 43     }
 44     void getfail() {
 45         queue<int> q;
 46         fail[root] = root;
 47         for(int i = 0; i < sigm_size; i++) {
 48             if(ch[root][i] == -1)
 49                 ch[root][i] = root;
 50             else {
 51                 fail[ch[root][i]] = root;
 52                 q.push(ch[root][i]);
 53             }
 54         }
 55         while(!q.empty()) {
 56             int u = q.front();
 57             q.pop();
 58             for(int i = 0; i < sigm_size; i++) {
 59                 if(ch[u][i] == -1) 
 60                     ch[u][i] = ch[fail[u]][i];
 61                 else {
 62                     fail[ch[u][i]] = ch[fail[u]][i];
 63                     q.push(ch[u][i]);
 64                 }
 65             }
 66         }
 67     }
 68     void query(char *t, int r, int y) {
 69         int u = root;
 70         for(int i = 0; t[i]; i++) {
 71             int c = idx(t[i]);
 72             u = ch[u][c];//走到u结点 
 73             
 74             for(int k = 0; k < val[u].size(); k ++){//遍历以该结点为结尾的p的每一个行编号 
 75                 int j = val[u][k];
 76                 //如果当前行T的第r行 - P的第j行 + 1 > 0,也就是在左上(1,1)到右下(n,m)这个区域内 
 77                 if(r-j+1>0) cnt[r-j+1][i-y+2]++;
 78                 //其中+1或者+2是数据存储问题引起,二维数组从第1行第0列开始 
 79             }    
 80         }
 81     }
 82     int count(int n, int m, int x) {
 83         int ans = 0;
 84         for(int i = 1; i <= n ; i++) {
 85             for(int j = 1; j <= m; j++) {
 86                 if(cnt[i][j] == x)
 87                     ans ++;
 88             }
 89         }
 90         return ans;
 91     }
 92 }ac;
 93 
 94 char t[maxn][maxn], p[maxn/10][maxn/10];
 95 int n, m, x, y;
 96 
 97 int main()
 98 {
 99     int T;
100     scanf("%d", &T);
101     while(T--) {
102         scanf("%d%d", &n, &m);
103         for(int i = 1; i <= n; i++) {
104             scanf("%s", t[i]);
105         }
106         ac.init();
107         scanf("%d%d", &x, &y);
108         for(int i = 1; i <= x; i++) {
109             scanf("%s", p[i]);
110             ac.insert(p[i], i);
111         }
112         ac.getfail();
113         for(int i = 1; i <= n; i++) 
114             ac.query(t[i], i, y);
115         printf("%d\n", ac.count(n,m,x));
116     }
117     return 0;
118 } 
View Code

  POJ 3691 DNA repair 给出单词表和一个文本串,问最少修改几个字符使得该文本串不包含所有的单词。

  先根据单词表构建一个AC自动机,具体匹配的时候我们可以定义一个状态dp[i][j]表示长度为i、以字典树中j号结点结尾的字符串不包含所有单词所需的最少修改次数。很容易递推发现,dp[i+1][u]也就是长度为1+1、以当前结点结尾的字符串的最小修改次数等于 u的所有孩子结点ch[j][k]是否和当前s[i]相等的最小值。参考代码如下:

  1 #include <cstdio>
  2 #include <queue>
  3 #include <algorithm>
  4 #include <cstring>
  5 using namespace std;
  6 
  7 const int inf = 0x3f3f3f3f;
  8 const int maxw = 60;
  9 const int maxwl = 30;
 10 const int maxl = 1100;
 11 const int maxnode = maxw * maxwl;
 12 const int sigm_size = 4;
 13 
 14 struct ac_automaton {
 15     int ch[maxnode][sigm_size];
 16     int fail[maxnode];
 17     bool val[maxnode];
 18     int root, sz;
 19     
 20     void init() {
 21         sz = 0;
 22         root = newnode();
 23     }
 24     int newnode() {
 25         memset(ch[sz], -1, sizeof(ch[sz]));
 26         val[sz] = false;
 27         return sz++;
 28     }
 29     int idx(char c) {
 30         if(c == 'A')
 31             return 0;
 32         if(c == 'C')
 33             return 1;
 34         if(c == 'G')
 35             return 2;
 36         if(c == 'T')
 37             return 3;
 38     }
 39     void insert(char *s) {
 40         int u = root;
 41         for(int i = 0; s[i]; i++) {
 42             int c = idx(s[i]);
 43             if(ch[u][c] == -1)
 44                 ch[u][c] = newnode();
 45             u = ch[u][c];
 46         }
 47         val[u] = true;
 48     }
 49     void getfail() {
 50         queue<int> q;
 51         fail[root] = root;
 52         for(int i = 0; i < sigm_size; i++) {
 53             if(ch[root][i] == -1)
 54                 ch[root][i] = root;
 55             else {
 56                 fail[ch[root][i]] = root;
 57                 q.push(ch[root][i]);
 58             }
 59         }
 60         while(!q.empty()) {
 61             int u = q.front();
 62             q.pop();
 63             if(val[fail[u]]) val[u] = true;
 64             
 65             for(int i = 0; i < sigm_size; i++) {
 66                 if(ch[u][i] == -1) 
 67                     ch[u][i] = ch[fail[u]][i];
 68                 else {
 69                     fail[ch[u][i]] = ch[fail[u]][i];
 70                     q.push(ch[u][i]);
 71                 }
 72             }
 73         }
 74     }
 75     
 76     int dp[maxnode][maxnode];
 77     //定义状态dp[i][j]表示长度为i、以字典树上结点编号为j的字符结尾的字符串 所需的最小修改次数 
 78     int solve(char *s) {
 79         int len = strlen(s);
 80         for(int i = 0; i <= len; i ++) {//初始化大小为len * sz大小的空间 
 81             for(int j = 0; j < sz; j++) {
 82                 dp[i][j] = inf;
 83             }
 84         }
 85         dp[0][root] = 0;//初始化长度为0,以根结点结尾的字符串 所需最小修改次数为 0 
 86         for(int i = 0; i <= len; i++) {
 87             for(int j = 0; j < sz; j++) {
 88                 //之前一次拓展没有更新表示该长度以j结尾的字符串存在病毒结点,故直接跳过 
 89                 if(dp[i][j] >= inf) continue;
 90                 
 91                     for(int k = 0; k < 4; k++) {
 92                         int u = ch[j][k];
 93                         if(val[u]) continue;//当前结点j的孩子中有的是病毒结点直接跳过 
 94                         int tmp;
 95                         if(k == idx(s[i]))
 96                             tmp = dp[i][j];
 97                         else
 98                             tmp = dp[i][j] + 1;
 99                         //更新长度加一、以孩子结点u结尾的的状态 
100                         dp[i+1][u] = min(dp[i+1][u], tmp);
101                     }
102                 
103             }
104         }
105         int ans = inf;
106         for(int i = 0; i < sz; i++) 
107             ans = min(dp[len][i], ans);
108         if(ans == inf)
109             return -1;
110         return ans;
111     }
112 }ac;
113 
114 int main() 
115 {
116     int n, kase = 1;
117     char word[maxwl], txt[maxl];
118     while(scanf("%d", &n) == 1 && n) {
119         ac.init();
120         for(int i = 0; i < n; i++) {
121             scanf("%s", word);
122             ac.insert(word);
123         }
124         ac.getfail();
125         
126         scanf("%s", txt);
127         printf("Case %d: %d\n",kase++, ac.solve(txt));
128     }
129     return 0;
130 }
View Code

  还有其他综合型的题目,有兴趣的同学自行尝试(很刺激,一题坑一天的都是少的那种)。

  至此,AC自动机解析及其在竞赛中的典型应用就总结完了,算法很精妙,关键是体会算法的基本思想,加上一些具体的应用实践,才能掌握牢固。AC自动机有很多变形,要想学好,用好,还需掌握其他知识,比如矩阵加速,高精度,状压DP(省略很多我还不知道的算法)。算法学习并非易事,要坚持思考,实践,总结才行。(原创不易,转载请注明出处哦)

 

posted @ 2018-08-11 19:08  Reqaw  阅读(3506)  评论(1编辑  收藏  举报