一个自己编写的简单AC自动机代码-----AC automata get √

最近一直在优化项目中字符串匹配的问题,于是就想起了自动机,之前也看过一些文章,一直没有实现,现在项目中要用,然后又看了一些关于AC自动机的文章,这里实现了一个简单的AC自动机的小接口,我是实现自动机状态结构采用了trie树,实现起来简单一些,但在一定程度上造成空间复杂度的增加,欢迎大家纠错和一起交流经验,下面是代码。

  1 #include <stdio.h>
  2 #include <stdlib.h>
  3 #include <string.h>
  4 
  5 #define CHAR_EXTERN 26 //trie的next指针的最大数目,全字符时应设置为256然后和0xff进行&运算
  6 
  7 typedef struct ac_node ac_node;
  8 typedef struct ac_node* ac_node_p;
  9 
 10 #define QUEUE_TYPE ac_node_p //进行此宏定义,就可以把queue封装成接口使用
 11 #define FREE_QUEUE_VALUE     //queue->value为动态申请内存需要做此操作    
 12 
 13 int queue_node_num = 0;
 14 
 15 //定义队列结构用于计算失效指针
 16 typedef struct queue_node
 17 {
 18     QUEUE_TYPE value;
 19     struct queue_node *next;
 20 }queue_node;
 21 
 22 
 23 typedef struct queue
 24 {
 25     queue_node *head;
 26     queue_node *tail;
 27 }queue;
 28 
 29 /*
 30  * queue创建结点
 31  */
 32 queue_node *queue_build_node(QUEUE_TYPE value)
 33 {
 34     queue_node *node = (queue_node *)malloc(sizeof(queue_node));
 35 
 36     if(node == NULL)
 37     {
 38 #ifdef AC_DEBUG
 39         printf("queue node bulid error memory is full !!\n\n");
 40 #endif        
 41         return NULL;
 42     }
 43     
 44     node->value = value;
 45 
 46     node->next = NULL;
 47     
 48     return node;
 49 }
 50 
 51 /*
 52  * queue初始化
 53  * return -1 失败 else 成功
 54  */
 55 queue *queue_init()
 56 {
 57     queue *ac_queue = (queue *)malloc(sizeof(queue));
 58     
 59     if(ac_queue == NULL)
 60     {
 61 #ifdef AC_DEBUG
 62         printf("queue build failed memory is full\n\n");
 63 #endif        
 64         return NULL;
 65     }    
 66 
 67     ac_queue->head = ac_queue->tail = queue_build_node(NULL);
 68     
 69     if(ac_queue->head == NULL)
 70     {
 71 #ifdef AC_DEBUG
 72         printf("queue build head error memory is full\n\n");
 73 #endif
 74         return NULL;
 75     }
 76 
 77     //ac_queue->head->next = ac_queue->tail;
 78 
 79     return ac_queue;
 80 }
 81 
 82 /*
 83  *queue 为空判断
 84  *return 1 为空 0 不为空
 85  */
 86 int queue_is_empty(queue *ac_queue)
 87 {
 88     if(ac_queue->head == ac_queue->tail)
 89     {
 90         return 1;
 91     }
 92 
 93     return 0;
 94 }
 95 
 96 /*
 97  * queue 向队尾添加结点
 98  */
 99 void queue_insert(queue *ac_queue,queue_node *node)
100 {
101     ac_queue->tail->next = node;
102     ac_queue->tail = node;
103     queue_node_num++;
104 #ifdef AC_DEBUG
105     printf("after insert the queue node num is %d :\n",queue_node_num);    
106 #endif
107 }
108 
109 /*
110  *queue 提取队首结点value值
111  */
112 QUEUE_TYPE queue_first(queue *ac_queue)
113 {
114     if(queue_is_empty(ac_queue))
115     {
116 #ifdef AC_DEBUG
117         printf("the queue is empty can not return head!!\n");
118 #endif
119         return NULL;
120     }
121     
122     return ac_queue->head->next->value; //队首不存值,从队首的下一个结点开始取值
123 }
124 
125 /*
126  *queue 队首结点出队列
127  */
128 void queue_delete(queue *ac_queue)
129 {
130     
131     if(queue_is_empty(ac_queue))
132     {
133 #ifdef AC_DEBUG
134         printf("the queue is empty we can not delete head\n\n");
135 #endif
136         return;
137     }
138 
139     queue_node *head = ac_queue->head->next; //队首不存值,从队首的下一个结点开始出队列
140     
141     ac_queue->head->next = head->next;
142     
143     if(head == ac_queue->tail)
144     {//出队列为最后一个元素时将队列置空
145         ac_queue->tail = ac_queue->head;    
146     }
147 
148     free(head); //释放队首结点内存
149 
150 #ifdef AC_DEBUG
151     queue_node_num--;    
152     printf("after delete the queue node num is %d :\n",queue_node_num);    
153 #endif
154 }
155 
156 /*
157  *queue 释放queue内存
158  */
159 void queue_destroy(queue *ac_queue)
160 {
161     queue_node *p = NULL;
162     p = ac_queue->head;
163     
164     while(p != NULL)
165     {
166 
167 #ifdef FREE_QUEUE_VALUE
168         if(p->value != NULL)
169             free(p->value); //value为动态申请内存的情况下做此操作    
170 #endif
171         queue_node *tmp = p->next;
172         
173         if(p != NULL)
174             free(p);
175         
176         p = tmp;
177     }
178 }
179 
180 //ac状态节点
181 struct ac_node
182 {
183     int final; //是否为一个模式串结尾的表示
184     int model; //标识该模式串为哪个模式串(如果考虑后缀子模式,此处应该改为整型链表)
185     ac_node *fail; //该状态节点的失效指针
186 
187     struct ac_node *next[CHAR_EXTERN];
188 };
189 
190 /*
191  * 创建状态节点
192  */
193 ac_node *ac_node_build()
194 {
195     int i;
196     ac_node *node = (ac_node *)malloc(sizeof(ac_node));
197 
198     if(node == NULL)
199     {
200 #ifdef AC_DEBUG
201         printf("bulid node error the memory is full !! \n\n");
202 #endif        
203         return NULL;
204     }
205 
206     node->final = 0;
207     node->model = -1;
208     node->fail = NULL;
209 
210     for(i = 0; i < CHAR_EXTERN; i++)
211     {
212         node->next[i] = NULL;
213     }
214 
215     return node;
216 }
217 
218 
219 /*
220  * 创建trie树
221  * return -1 失败 else 成功
222  */
223 int ac_trie_build(ac_node *root,char *str,int len,int model)
224 {
225     int i;
226     ac_node *tmp = root;
227     
228     if(tmp == NULL)
229     {
230 #ifdef AC_DEBUG
231         printf("root has not been init!!! \n\n");
232 #endif
233         return -1;
234     }    
235 
236     for(i = 0; i < len; i++)
237     {
238 
239         /*
240          ac_node *next_node = tmp->next[str[i] - 'a'];
241          这样写然后对next_node操作会造成trie树建立失败,由于next_node一直不为NULL
242          */
243 
244         int index = str[i] - 'a'; // if CHAR_EXTERN=256 index = str[i]&0xff
245         if(tmp->next[index] == NULL)
246         {
247             tmp->next[index] = ac_node_build();
248             
249             if(tmp->next[index] == NULL)
250             {
251 #ifdef AC_DEBUG
252                 printf("build node error in ac_trie_build !!\n");
253 #endif
254                 return -1;
255             }
256 
257         }
258         
259         tmp = tmp->next[index];
260     }
261 
262     tmp->final = 1;
263     tmp->model = model;
264 
265     return 0;
266 }
267 
268 
269 /*
270 * 创建失效指针函数
271 */
272 
273 void ac_build_fail(ac_node *root,queue *ac_queue)
274 {
275     if(root == NULL || ac_queue == NULL)
276     {
277 #ifdef AC_DEBUG
278         printf("build ac fail pointer error -- input\n");
279 #endif    
280         return ;
281     }
282 
283     int i;
284     queue_node *q_node = NULL;
285     ac_node *tmp_node = NULL;
286     ac_node *fail_node = NULL;
287 
288     q_node = queue_build_node(root);
289     queue_insert(ac_queue,q_node);
290 
291     while(!queue_is_empty(ac_queue))
292     {
293         tmp_node = queue_first(ac_queue);
294 #ifdef AC_DEBUG
295         printf("out the queue the ac node pointer is %p \n",tmp_node);
296 #endif
297         queue_delete(ac_queue);//队首元素出队列
298 
299         for(i = 0; i < CHAR_EXTERN; i++)
300         {//通过队列采用BFS(广度优先)的遍历顺序来计算当前状态每个字符的失效函数
301 
302             if(tmp_node->next[i] != NULL) // if CHART_EXTERN=255 tmpnode->next[i&0xff]
303             {
304                 if(tmp_node == root)
305                 {
306                     tmp_node->next[i]->fail = root; //第一层节点的失效指针指向根结点
307                 }
308                 else
309                 {
310                     fail_node = tmp_node->fail; //父结点的失效指针
311 
312                     while(fail_node != NULL)
313                     {//当直到回到根结点时
314                         if(fail_node->next[i] != NULL)
315                         {//在父结点的失效指针中找到当前字符的外向边
316                             tmp_node->next[i]->fail = fail_node->next[i];
317                             break;
318                         }
319 
320                         fail_node = fail_node->fail; //继续递归
321                     }
322 
323                     if(fail_node == NULL)
324                     {//找不到失效指针,则失效指针为根结点
325                         tmp_node->next[i]->fail = root;
326                     }
327                     
328                 }
329 
330                 q_node = queue_build_node(tmp_node->next[i]);
331                 queue_insert(ac_queue,q_node); //将当前层的结点插入队列继续进行广度优先遍历
332 #ifdef AC_DEBUG
333                 printf("insert into a ac node into queue the state is : %c \n\n", i + 'a');
334                 printf("insert the ac node pointer is %p\n",tmp_node->next[i]);
335 #endif
336             }
337         }        
338     }
339 }
340 
341 /*
342  *模式匹配函数
343  *return -1 未匹配到任何模式 else 匹配到的当前的模式串的值
344  */
345 int ac_query(ac_node *root,char *str,int len)
346 {
347     if(root == NULL || str == NULL)
348     {
349         return -1;
350     }
351 
352     int i,match_num = 0;
353     int index = 0;
354 
355     ac_node *tmp_node = NULL;
356     ac_node *p = root;
357     
358     for(i = 0; i < len; i++)
359     {
360         index = str[i] - 'a'; // if CHAR_EXTERN=256 index = str[i]&0xff    
361 
362         while(p->next[index] == NULL && p != root)
363         {//状态在当前字符不进行转移以后的失效转移
364             p = p->fail;
365         }
366         
367         p = p->next[index]; //将状态进行下移
368 
369         if(p == NULL)
370         {//未找到失效指针,从根结点开始继续
371             p = root;
372         }
373         
374         tmp_node = p; //计算当前状态下的匹配情况(在局部定义指针型变量经常出现内存指向问题)
375 
376         while(tmp_node != root)
377         {
378             if(tmp_node->final == 1)
379             {
380                 match_num++;
381                 
382                 //tmp_node = tmp_node->fail; //匹配子模式串
383 
384                 return tmp_node->model; //此处不进行return 可以计算多个模式串
385             }
386 
387             tmp_node = tmp_node->fail; //匹配子模式串
388         }
389 
390     }
391 
392     return -1;
393 }
394 
395 /*
396  *打印trie树
397  */
398 void ac_trie_print(ac_node *root)
399 {//利用队列进行广度优先遍历并打印trie树
400     ac_node *ac_queue[256];
401     int i,head,tail;
402     head = tail = 0;
403 
404     memset(ac_queue,0,sizeof(ac_queue));
405 
406     ac_node *tmp = root;
407 
408     ac_queue[tail++] = tmp;
409     while(head != tail)
410     {
411         tmp = ac_queue[head++]; //出队列
412         for(i = 0; i < CHAR_EXTERN; i++)
413         {
414             if(tmp->next[i] != NULL)
415             {
416                 printf("%c  ",i+'a');
417                 ac_queue[tail++] = tmp->next[i]; //将当前层的所有节点入队列
418             }
419         }
420         
421         printf("\n");
422             
423     }
424 }
425 
426 
427 #define LIB_MODEL_NUM 10
428 
429 int main()
430 {
431 #if 1
432     char *lib_model_str[10] = {"wwwgooglecom",\
433                  "wwwbaiducom",\
434                  "www",\
435                  "googlecom",\
436                  "guoxin",\
437                          "she",\
438                          "her",\
439                          "shr",\
440                          "yes",\
441                  "say"};
442 
443 #endif
444 
445 #if 0
446 
447     char *lib_model_str[LIB_MODEL_NUM] = {
448         "she",\
449         "he",\
450         "say",\
451         "shr",\
452         "her"\
453     };
454 #endif
455 
456     int i;
457     char str[1024];
458     queue *ac_queue = NULL;
459     ac_node *root = NULL; 
460 
461     ac_queue = queue_init();
462     root = ac_node_build();
463 
464     for(i = 0; i < LIB_MODEL_NUM; i++)
465     {
466         ac_trie_build(root,lib_model_str[i],strlen(lib_model_str[i]),i);    
467     }
468     
469 
470     ac_trie_print(root);
471         
472     ac_build_fail(root,ac_queue);
473     
474     while(1)
475     {
476         printf("input the match string:\n\n");
477         
478         scanf("%s",str);
479 
480         if(strcmp(str,"quit") == 0)
481         {
482             return -1;
483         }
484 
485         int model = ac_query(root,str,strlen(str));
486 
487         if(model == -1)
488         {
489             printf("not match!!\n\n");
490         }
491         else
492         {
493             printf("match the model '%d' and the model string is %s \n\n",model,lib_model_str[model]);
494         }
495     }
496 
497     return 0;
498 }

 

posted @ 2015-09-09 11:07  代码的搬运工  阅读(598)  评论(0编辑  收藏  举报