【bzoj3277/bzoj3473】串/字符串 广义后缀自动机

题目描述

字符串是oi界常考的问题。现在给定你n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串(注意包括本身)。

输入

第一行两个整数n,k。
接下来n行每行一个字符串。

输出

输出一行n个整数,第i个整数表示第i个字符串的答案。

样例输入

3 1
abc
a
ab

样例输出

6 1 3


题解

广义后缀自动机

建立广义后缀自动机,统计一下每个节点属于多少个字符串中。

在每次插入新节点np后,parent树上np的父亲节点fa[np]一定有着np的身份,即一定包含当前串(这里将root节点视作在所有串中)。

那么沿着parent树一直向上更新即可。更新到已经被更新过的节点就停止。

这样时间复杂度是O(n)的。

求出每个节点在几个子串中以后,还要知道出现次数。

而出现次数有一个定理:x结尾的不重复字符串的出现次数=dis[x]-dis[fa[x]]。

简单证明一下:从root走到parent树上x到root的链上的点共有dis[x]种方法,走到fa[x]到root的链上的点共有dis[fa[x]]种方法,相减可推出结论。

而仅仅求出不重复字符串的出现次数还不够,我们需要统计所有的出现位置。而出现位置一定是在parent树中从x到叶子结点均出现的。

那么可以用类似于树形dp的方法求解。

总结一下:建立广义后缀自动机,同时统计每个节点在多少个字符串中出现过。保留出现在≥k个字符串中的,令出现次数为dis[x]-dis[fa[x]]。然后树形dp,c[x]+=c[fa[x]],并更新答案求解。

说了这么多,其实代码也就几十行~

#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 200010
using namespace std;
int a[N][26] , fa[N] , kind[N] , type[N] , si[N] , last , tot = 1;
int head[N] , to[N] , next[N] , cnt;
long long dis[N] , c[N] , ans[N];
char str[N];
void add(int x , int y)
{
    to[++cnt] = y , next[cnt] = head[x] , head[x] = cnt;
}
void ins(int t , int c)
{
    int p = last , np = last = ++tot;
    dis[np] = dis[p] + 1 , kind[np] = t;
    while(p && !a[p][c]) a[p][c] = np , p = fa[p];
    if(!p) fa[np] = 1;
    else
    {
        int q = a[p][c];
        if(dis[q] == dis[p] + 1) fa[np] = q;
        else
        {
            int nq = ++tot;
            memcpy(a[nq] , a[q] , sizeof(a[q])) , dis[nq] = dis[p] + 1 , fa[nq] = fa[q] , fa[np] = fa[q] = nq , si[nq] = si[q] , type[nq] = type[q];
            while(p && a[p][c] == q) a[p][c] = nq , p = fa[p];
        }
    }
    for(p = np ; p && type[p] != t ; p = fa[p]) si[p] ++ , type[p] = t;
}
void dfs(int x)
{
    int i;
    c[x] += c[fa[x]] , ans[kind[x]] += c[x];
    for(i = head[x] ; i ; i = next[i])
        dfs(to[i]);
}
int main()
{
    int n , m , k , i , j;
    scanf("%d%d" , &n , &k);
    for(i = 1 ; i <= n ; i ++ )
    {
        scanf("%s" , str + 1) , m = strlen(str + 1) , last = 1;
        for(j = 1 ; j <= m ; j ++ ) ins(i , str[j] - 'a');
    }
    for(i = 2 ; i <= tot ; i ++ )
    {
        add(fa[i] , i);
        if(si[i] >= k) c[i] = dis[i] - dis[fa[i]];
    }
    dfs(1);
    for(i = 1 ; i <= n ; i ++ ) printf("%lld " , ans[i]);
    return 0;
}

 

 

posted @ 2017-06-06 15:32  GXZlegend  阅读(649)  评论(0编辑  收藏  举报