P3808 【模板】AC自动机(简单版)

\(\color{#0066ff}{ 题目描述 }\)

给定n个模式串和1个文本串,求有多少个模式串在文本串里出现过。

\(\color{#0066ff}{输入格式}\)

第一行一个n,表示模式串个数;

下面n行每行一个模式串;

下面一行一个文本串。

\(\color{#0066ff}{输出格式}\)

一个数表示答案

\(\color{#0066ff}{输入样例}\)

2
a
aa
aa

\(\color{#0066ff}{输出样例}\)

2

\(\color{#0066ff}{数据范围与提示}\)

subtask1[50pts]:∑length(模式串)<=106,length(文本串)<=106,n=1;

subtask2[50pts]:∑length(模式串)<=106,length(文本串)<=106;

\(\color{#0066ff}{ 题解 }\)

AC自动机板子

近期得知有一个Trie图优化,要不然容易被卡虽然从来没被卡过

而且优化后的及其好写

正常来说,一个节点的fail等于父亲的fail链上最近的这个儿子

我们需要不断跳fail,这样很慢

优化:如果一个节点没有c儿子,就让它连上fail链上最近的c儿子

这个可以递推出来

这样,一个点的fail就是父亲fail的这个儿子,就没了,也不用暴力跳了

而且匹配的时候,可以发现,不会有空节点,一直跳就行了

#include<bits/stdc++.h>
#define LL long long
LL in() {
	char ch; LL x = 0, f = 1;
	while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
	for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
	return x * f;
}
struct AC {
protected:
	struct node {
		node *ch[26];
		int num;
		node *fail;
		node() {
			memset(ch, 0, sizeof ch);
			fail = NULL;
			num = 0;
		}
		void *operator new (size_t) {
			static node *S = NULL, *T = NULL;
			return (S == T) && (T = (S = new node[1024]) + 1024), S++;
		}
	};
	node *root;
public:
	AC() { root = new node(); }
	void ins(char *s) {
		node *o = root;
		for(char *p = s; *p; p++) {
			int pos = *p - 'a';
			if(!o->ch[pos]) o->ch[pos] = new node();
			o = o->ch[pos];
		}
		o->num++;
	}
	void build() {
		std::queue<node*> q;
		q.push(root);
		root->fail = root;
		while(!q.empty()) {
			node *tp = q.front(); q.pop();
			for(int i = 0; i <= 25; i++) {
				if(tp == root) {
					if(tp->ch[i]) tp->ch[i]->fail = root, q.push(tp->ch[i]);
					else tp->ch[i] = root;}
				else {
					if(tp->ch[i]) tp->ch[i]->fail = tp->fail->ch[i], q.push(tp->ch[i]);
					else tp->ch[i] = tp->fail->ch[i];
				}
			}
		}
	}
	int query(char *s) {
		node *o = root;
		int ans = 0;
		for(char *p = s; *p; p++) {
			int pos = *p - 'a';
			o = o->ch[pos];
			if(~o->num) ans += o->num, o->num = -1;
		}
		return ans;
	}
}b;
const int maxn = 1e6 + 100;
char s[maxn];
int main() {
	int n = in();
	for(int i = 1; i <= n; i++) scanf("%s", s), b.ins(s);
	b.build();
	scanf("%s", s);
	printf("%d\n", b.query(s));
	return 0;
}
posted @ 2019-01-10 14:15  olinr  阅读(158)  评论(0编辑  收藏  举报