bzoj2780

AC自动机+树链剖分+线段树/树状数组+dfs序+树链的并

题意:给出n个母串和q个询问串,对于每个询问串输出有多少个母串包含这个询问串 N=∑|母串|<=10^5 Q=∑|询问串|<=3.6*10^5

由于是母串包含询问串,那么我们就对询问串建自动机,然后用母串在上面跑,跑一次的复杂度是母串的长度(不确定)

利用fail树的性质求解,每次母串跑到一个节点,就把这个节点到根的路径都加上1,说明母串当前匹配单词的一段后缀包含某个前缀,也就是母串包含某个前缀

由于是求出现次数,那么询问串多次出现在母串里只算一次,所以之前到根的路径上加重了,所以我们利用树链的并去重,将节点按dfs序排序,相邻两个节点的lca到根的路径-1,这样就使得每个节点到根的路径都等于1

然后就真写了一个树链剖分+单点查询,其实直接查子树就行了,直接在节点上打上+1标记,lca上打上-1标记,树状数组维护子树和就行了

#include<bits/stdc++.h>
using namespace std;
const int N = 400010;
int n, q;
vector<char> v[N];
char s[N];
struct ac_automation {    
    int root, cnt, tot, dfs_clock;
    int child[N][27], fail[N], top[N], fa[N], son[N], dep[N], size[N], dfn[N], mir[N], tree[N << 2], pos[N], tag[N << 2];
    vector<int> G[N], p;
    void insert(char s[], int id)
    {
        int len = strlen(s), now = root;
        for(int i = 0; i < len; ++i)
        {
            int t = s[i] - 'a';
            if(child[now][t] == 0) child[now][t] = ++cnt;
            now = child[now][t];
        }
        pos[id] = now;
    }
    void build_fail()
    {
        queue<int> q;
        for(int i = 0; i < 26; ++i) if(child[root][i]) 
        {
            q.push(child[root][i]);
            G[root].push_back(child[root][i]);
        }
        while(!q.empty())
        {
            int u = q.front();
            q.pop();
            for(int i = 0; i < 26; ++i)
            {
                int &v = child[u][i];
                if(v == 0) v = child[fail[u]][i];
                else
                {
                    fail[v] = child[fail[u]][i];
                    q.push(v);        
                    G[child[fail[u]][i]].push_back(v);
                }
            }
        }
    }
    int lca(int u, int v)
    {
        while(top[u] != top[v])
        {
            if(dep[top[u]] < dep[top[v]]) swap(u, v);
            u = fa[top[u]];
        }
        return dep[u] < dep[v] ? u : v;
    }
    void dfs(int u)
    {
        size[u] = 1;
        for(int i = 0; i < G[u].size(); ++i)
        {
            int v = G[u][i];
            dep[v] = dep[u] + 1;
            fa[v] = u;
            dfs(v);
            size[u] += size[v];
            if(size[v] >= size[son[u]]) son[u] = v;
        }
    } 
    void dfs(int u, int acs)
    {
        dfn[u] = ++dfs_clock;
        mir[dfn[u]] = u;
        top[u] = acs;
        if(son[u]) dfs(son[u], acs);
        for(int i = 0; i < G[u].size(); ++i)
        {
            int v = G[u][i];
            if(v == son[u]) continue;
            dfs(v, v);
        }
    }
    void pushdown(int x, int l, int r)
    {
        if(tag[x] == 0) return;
        int mid = (l + r) >> 1;
        tag[x << 1] += tag[x];
        tag[x << 1 | 1] += tag[x];
        tree[x << 1] += tag[x] * (mid - l + 1);
        tree[x << 1 | 1] += tag[x] * (r - mid);
        tag[x] = 0;
    }
    int query(int l, int r, int x, int pos)
    {
        if(l == r) return tree[x];
        pushdown(x, l, r);
        int mid = (l + r) >> 1;
        if(pos <= mid) return query(l, mid, x << 1, pos);
        else return query(mid + 1, r, x << 1 | 1, pos);
    }
    void update(int l, int r, int x, int a, int b, int delta)
    {
        if(l > b || r < a) return;
        if(l >= a && r <= b)
        {
            tag[x] += delta;
            tree[x] += (r - l + 1) * delta;
            return;
        }
        pushdown(x, l, r);
        int mid = (l + r) >> 1;
        update(l, mid, x << 1, a, b, delta);
        update(mid + 1, r, x << 1 | 1, a, b, delta);
        tree[x] = tree[x << 1] + tree[x << 1 | 1];
    }
    void change(int u, int delta)
    {
        while(top[u])
        {
            update(1, cnt + 1, 1, dfn[top[u]], dfn[u], delta);
            u = fa[top[u]];
        }
        update(1, cnt + 1, 1, 1, dfn[u], delta);
    }
    void put_string(int id)
    {
        int len = v[id].size(), now = root;
        p.clear();
        for(int i = 0; i < len; ++i)
        {
            now = child[now][v[id][i] - 'a'];
            p.push_back(dfn[now]);
        }
        sort(p.begin(), p.end());
        p.erase(unique(p.begin(), p.end()), p.end());
        for(int i = 0; i < p.size(); ++i)
        {
            int u = p[i];
            change(mir[u], 1);
        }
        for(int i = 1; i < p.size(); ++i)
        {
            int u = p[i], v = p[i - 1];
            change(lca(mir[u], mir[v]), -1);
        }
    }
    int ask(int id)
    {
        return query(1, cnt + 1, 1, dfn[pos[id]]);
    }
} ac;
int main()
{
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; ++i) 
    {
        scanf("%s", s);
        int len = strlen(s);
        for(int j = 0; j < len; ++j) v[i].push_back(s[j]);
    }
    for(int i = 1; i <= q; ++i) 
    {
        scanf("%s", s);
        ac.insert(s, i);
    }
    ac.build_fail();
    ac.dfs(0);
    ac.dfs(0, 0);
    for(int i = 1; i <= n; ++i) ac.put_string(i);
    for(int i = 1; i <= q; ++i) printf("%d\n", ac.ask(i));
    return 0;
}
View Code

 

posted @ 2017-08-13 15:41  19992147  阅读(153)  评论(0编辑  收藏  举报