「模拟赛20180306」回忆树 memory LCA+KMP+AC自动机+树状数组

题目描述

回忆树是一棵树,树边上有小写字母。

一次回忆是这样的:你想起过往,触及心底……唔,不对,我们要说题目。

这题中我们认为回忆是这样的:给定 \(2\) 个点 \(u,v\) (\(u\) 可能等于 \(v\))和一个非空字符串 \(s\) ,问从 \(u\)\(v\) 的简单路径上的所有边按照到 \(u\) 的距离从小到大的顺序排列后,询问边上的字符依次拼接形成的字符串中给定的串 \(s\) 出现了多少次。

输入

第一行 \(2\) 个整数,依次为树中点的个数 \(n\) 和回忆的次数 \(m\)
接下来 \(n-1\) 行,每行 \(2\) 个整数 \(u,v\)\(1\) 个小写字母 \(c\) ,表示回忆树的点\(u,v\)之间有一条边,边上的字符为\(c\)
接下来 \(2m\) 行表示 \(m\) 次回忆,每次回忆 \(2\) 行:第 \(1\)\(2\) 个整数 \(u,v\),第 \(2\) 行给出回忆的字符串 \(s\)

输出

对于每次回忆,输出串 \(s\) 出现的次数。

样例

样例输入

12 3
1 2 w
2 3 w
3 4 x
4 5 w
5 6 w
6 7 x
7 8 w
8 9 w
9 10 x
10 11 w
11 12 w
1 7
wwx
1 12
www
1 12
w

样例输出

2
0
8

数据范围

\(1≤n,m≤10^5\)
询问字符串的总长度不超过\(3\times10^5\)

题解

这是一道神题,做法优美而且巧妙(同时也很恶心)。

既然是树链上的询问,就不能不让人想到利用\(LCA\)\(u\xrightarrow{}v\)的路径转化成\(u\xrightarrow{}lca\)\(lca\xrightarrow{}v\)的两条路径了。

那么我们就可以把询问分成三部分。

  1. \(lca\xrightarrow{}u\)\(s\)的反串出现了多少次
  2. \(lca\xrightarrow{}v\)\(s\)出现了多少次
  3. 跨越\(lca\)时,\(s\)出现了多少次

可以发现,第一部分和第二部分其实是类似的问题,我们先放一放。


那么我们考虑第三个问题,好像没有什么很简单的方法,于是我们考虑暴力。
很容易发现这一种情况下涉及的字符串不长,只有\(u\xrightarrow{}lca\)路径上的\(\left|s\right|\)个和\(v\xrightarrow{}lca\)路径上的\(\left|s\right|\)个。我们可以暴力取出这一段字符,然后做一次\(KMP\),这样一次的复杂度是\(O(\left|s\right|)\),总时间复杂度就是\(O(\sum\left|s\right|)\),完全可以过。


现在就剩前两个问题了。我们发现询问串太多,一个个做显然很吃力,这时,\(AC\)自动机的方法就呼之欲出了。我们把所有询问串做成一个\(AC\)自动机,把整棵树带进去匹配即可。

匹配的过程很简单,模拟字符串匹配的时候即可,从根开始,依次访问子树,进栈的时候答案加,出栈的时候答案减即可,然后把询问的区间标记一下,到达合适的区间就计算答案。

但是这样还有一个问题,\(AC\)自动机上的答案是要给\(fail\)链上的所有点增加的,暴力加显然会超时。于是我们修改一下做法,预处理出\(fail\)树的先序遍历序列,然后建立树状数组(一个比较显然的性质,同一颗子树的遍历序列是连续的)。于是修改的时候单点修改,查询的时候查询\(fail\)树上的子树和即可。


然而,这道题说起来很轻巧,却是一道码农题……并且还卡常数……卡常数!!!
所以,我还是把我\(250\)行的代码拿出来吧……
\(Code:\)

#include <queue> 
#include <vector> 
#include <cstdio> 
#include <cstring> 
#include <algorithm> 
using namespace std; 
#define M 600005 
queue<int>q; 
int n, m; 
int f[25][M], dep[M], fa[M]; 
int L[M], R[M], ans[M], ens[M], plc[M]; 
vector<int>B[M], E[M]; 
char len[M], top[M], S[M]; 
struct node 
{ 
    int fir[M], tar[M], nex[M], cnt; 
}T1, T2; 
void add(int a, int b, char c) 
{ 
    ++T1.cnt; 
    T1.tar[T1.cnt] = b; 
    len[T1.cnt] = c; 
    T1.nex[T1.cnt] = T1.fir[a]; 
    T1.fir[a] = T1.cnt; 
} 
void add(int a, int b) 
{ 
    ++T2.cnt; 
    T2.tar[T2.cnt] = b; 
    T2.nex[T2.cnt] = T2.fir[a]; 
    T2.fir[a] = T2.cnt; 
} 
//dfs-begin 
void dfs(int r) 
{ 
    for (int i = T1.fir[r]; i; i = T1.nex[i]) 
    { 
        int v = T1.tar[i]; 
        if (v != fa[r]) 
        { 
            fa[v] = r; 
            dep[v] = dep[r] + 1; 
            top[v] = len[i]; 
            dfs(v); 
        } 
    } 
} 
//dfs-end 
//LCA-begin 
int LCA(int u, int v) 
{ 
    if (dep[u] < dep[v]) 
        swap(u, v); 
    int k = dep[u] - dep[v]; 
    for (int i = 20; i >= 0; i--) 
        if (k & 1 << i) 
            u = f[i][u]; 
    if (u == v) 
        return u; 
    for (int i = 20; i >= 0; i--) 
        if (f[i][u] != f[i][v]) 
            u = f[i][u], v = f[i][v]; 
    return f[0][u]; 
} 
int getk(int u, int k) 
{ 
    for (int i = 0; i <= 20; i++) 
        if (k & 1 << i) 
            u = f[i][u]; 
    return u; 
} 
//LCA-end 
//KMP-begin 
char K[M]; 
int nex[M]; 
void KMP(int a, int b, int c, int ls, int w) 
{ 
    int len = 0; 
    while (a != c) 
        K[len++] = top[a], a = fa[a]; 
    int z = dep[b] - dep[c]; 
    len += dep[b] - dep[c]; 
    while (b != c) 
        K[--len] = top[b], b = fa[b]; 
    len += z; 
    K[len] = 0; 
    nex[0] = -1; 
    int i = 0, j = -1, ans = 0; 
    while(i < ls) 
    { 
        if (j == -1 || S[i] == S[j]) 
            nex[++i] = ++j; 
        else
            j = nex[j]; 
    } 
    i = 0, j = 0; 
    while(i < len) 
    { 
        if (j == ls) 
        { 
            ans++; 
            j = nex[j]; 
            continue; 
        } 
        if(j == -1 || K[i] == S[j]) 
            i++, j++; 
        else
            j = nex[j]; 
    } 
    if (j == ls) 
        ans++; 
    ens[w] += ans; 
} 
//KMP-end 
//ACTrie-begin 
struct ACTrie 
{ 
    int nex[M][30], fail[M], in[M], out[M]; 
    int root, cnt, tim, dfn[M], id[M]; 
    int tree[M]; 
    ACTrie(){root = cnt = 1;} 
    void Insert(char *S, int w) 
    { 
        int r = root, len = strlen(S); 
        for (int i = 0; i < len; i++) 
        { 
            int val = S[i] - 'a'; 
            if (!nex[r][val]) 
                nex[r][val] = ++cnt; 
            r = nex[r][val]; 
        } 
        plc[w] = r; 
    } 
    void Build() 
    { 
        int r = root; 
        fail[r] = r; 
        q.push(root); 
        while (!q.empty()) 
        { 
            r = q.front(); 
            q.pop(); 
            for (int i = 0; i < 26; i++) 
            { 
                if (nex[r][i]) 
                { 
                    int tmp = nex[fail[r]][i]; 
                    if (tmp && tmp != nex[r][i]) 
                        fail[nex[r][i]] = tmp; 
                    else
                        fail[nex[r][i]] = root; 
                    q.push(nex[r][i]); 
                } 
                else
                { 
                    int tmp = nex[fail[r]][i]; 
                    if (tmp) 
                        nex[r][i] = tmp; 
                    else
                        nex[r][i] = root; 
                } 
            } 
            if (r != root) 
                add(fail[r], r); 
        } 
    } 
    void DFS(int r) 
    { 
        dfn[r] = ++tim; 
        in[r] = tim; 
        id[tim] = r; 
        for (int i = T2.fir[r]; i; i = T2.nex[i]) 
        { 
            int v = T2.tar[i]; 
            DFS(v); 
        } 
        out[r] = tim; 
    } 
    void Update(int x, int v) 
    { 
        for (int i = x; i <= cnt; i += i & -i) 
            tree[i] += v; 
    } 
    int Getsum(int x) 
    { 
        int ans = 0; 
        for (int i = x; i; i -= i & -i) 
            ans += tree[i]; 
        return ans; 
    } 
}AC; 
//ACTrie-end 
void dfs2(int r, int now) 
{ 
    AC.Update(AC.dfn[now], 1); 
    int s = B[r].size(); 
    for (int i = 0; i < s; i++) 
        ens[(B[r][i] + 1)/ 2] -= AC.Getsum(AC.out[plc[B[r][i]]]) - AC.Getsum(AC.in[plc[B[r][i]]] - 1); 
    s = E[r].size(); 
    for (int i = 0; i < s; i++) 
        ens[(E[r][i] + 1)/ 2] += AC.Getsum(AC.out[plc[E[r][i]]]) - AC.Getsum(AC.in[plc[E[r][i]]] - 1); 
    for (int i = T1.fir[r]; i; i = T1.nex[i]) 
    { 
        int v = T1.tar[i]; 
        if (v != fa[r]) 
            dfs2(v, AC.nex[now][len[i] - 'a']); 
    } 
    AC.Update(AC.dfn[now], -1); 
} 
int main() 
{ 
    //freopen("memory.in", "r", stdin); 
    //freopen("memory.out", "w", stdout); 
    scanf("%d%d", &n, &m); 
    for (int i = 1; i < n; i++) 
    { 
        int a, b; 
        char c[5]; 
        scanf("%d%d%s", &a, &b, c); 
        add(a, b, c[0]); 
        add(b, a, c[0]); 
    } 
    dfs(1); 
    for (int i = 1; i <= n; i++) 
        f[0][i] = fa[i]; 
    for (int i = 1; i <= 20; i++) 
        for (int j = 1; j <= n; j++) 
            f[i][j] = f[i - 1][f[i - 1][j]]; 
    int w = 0; 
    for (int i = 1; i <= m; i++) 
    { 
        int u, v, c; 
        scanf("%d%d%s", &u, &v, S); 
        c = LCA(u, v); 
        int l1 = dep[u] - dep[c], l2 = dep[v] - dep[c], ls = strlen(S); 
        int a = getk(u, max(0, l1 - ls + 1)); 
        int b = getk(v, max(0, l2 - ls + 1)); 
        KMP(a, b, c, ls, i); 
        w++; 
        AC.Insert(S, w); 
        B[b].push_back(w); 
        E[v].push_back(w); 
        w++; 
        for (int i = 0; i < ls / 2; i++) 
            swap(S[i], S[ls - i - 1]); 
        AC.Insert(S, w); 
        B[a].push_back(w); 
        E[u].push_back(w); 
    } 
    AC.Build(); 
    AC.DFS(AC.root); 
    dfs2(1, AC.root); 
    for (int i = 1; i <= m; i++) 
        printf("%d\n", ens[i]); 
} 

posted @ 2018-03-09 14:24  ModestStarlight  阅读(272)  评论(0编辑  收藏  举报