BZOJ3172[Tjoi2013]单词——AC自动机(fail树)

题目描述

某人读论文,一篇论文是由许多单词组成。但他发现一个单词会在论文中出现很多次,现在想知道每个单词分别在论文中出现多少次。

输入

第一个一个整数N,表示有多少个单词,接下来N行每行一个单词。每个单词由小写字母组成,N<=200,单词长度不超过10^6

输出

输出N个整数,第i行的数字表示第i个单词在文章中出现了多少次。

样例输入

3
a
aa
aaa

样例输出

6
3
1
 
这道题题干真是言简意赅,看了半天愣是没看明白。为了防止有人也像我一样没看懂,在这里解释下题目及样例:文章由输入的几个单词组成,但并不是把这几个字符串连一起。对于询问的第i个单词出现几次是指这个单词在每个单词中出现次数加和(包括自己)。例如样例中a在第一个单词中出现1次,在第二个中出现2次,在第三个中出现3次;aa在第一个中没有,第二个第三个中分别出现1次、2次。aaa只在第三个中出现1次。对于第i个单词在第j个单词中出现几次就相当于问j单词中有几个节点直接或间接指向i单词的终止节点,也就是问在fail树中以i单词终止节点为根的子树中有几个节点是j单词串上的点。fail树是什么?fail树就是由每个点失配标记连向这个点所形成的树。在建AC自动机时要记录每个点被遍历几次作为这个点的权值表示这个点是几个单词串上的点,最后dfs一遍fail树就好了。
最后附上代码。
#include<cmath>
#include<queue>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
int n;
int num;
int tot;
int cnt;
int g[300];
char s[1000010];
int to[1000010];
int sum[1000010];
int fail[1000010];
int next[1000010];
int head[1000010];
int a[1000010][26];
void add(int x,int y)
{
    tot++;
    next[tot]=head[x];
    head[x]=tot;
    to[tot]=y;
}
void build(char *s)
{
    int now=0;
    int len=strlen(s);
    for(int i=0;i<len;i++)
    {
        if(!a[now][s[i]-'a'])
        {
            a[now][s[i]-'a']=++cnt;
        }
        now=a[now][s[i]-'a'];
        sum[now]++;
    }
    g[++num]=now;
}
void getfail()
{
    queue<int>q;
    for(int i=0;i<26;i++)
    {
        if(a[0][i])
        {
            fail[a[0][i]]=0;
            q.push(a[0][i]);
        }
    }
    while(!q.empty())
    {
        int now=q.front();
        q.pop();
        for(int i=0;i<26;i++)
        {
            if(a[now][i])
            {
                fail[a[now][i]]=a[fail[now]][i];
                q.push(a[now][i]);
            }
            else
            {
                a[now][i]=a[fail[now]][i];
            }
        }
    }
}
void dfs(int x)
{
    for(int i=head[x];i;i=next[i])
    {
        dfs(to[i]);
        sum[x]+=sum[to[i]];
    }
    return ;
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%s",s);
        build(s);
    }
    getfail();
    for(int i=1;i<=cnt;i++)
    {
        add(fail[i],i);
    }
    dfs(0);
    for(int i=1;i<=n;i++)
    {
        printf("%d\n",sum[g[i]]);
    }
}
posted @ 2018-06-10 21:43  The_Virtuoso  阅读(343)  评论(3编辑  收藏  举报