bzoj4545

lct+SAM

bzoj4516+bzoj2555

这道题唯一的用处就是教会了我真正的广义SAM

DFS 时用 pos 数组记录每个 trie 节点在后缀自动机中对应的状态,每个节点都从其父节点对应的状态(pos[父节点])开始 extend。

用 LCT 在 parent 树(后缀链接树)上动态维护每个状态的 right 集合大小:新建(非克隆)状态权值为 1,link/cut 时把子树权值沿到根的路径整体加上/减去,查询时直接读出该状态的值即可。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Array capacity. Trie node ids index h[] and pos[] (size N);
// SAM / LCT states and directed trie edges use arrays of size 2*N.
const int N = 100005;

// One directed trie edge in a singly linked adjacency list.
struct edge {
    int nxt;  // index of the next edge out of the same node (0 = end of list)
    int to;   // target trie node
    int w;    // edge letter, stored as letter - 'a'
};
edge e[N << 1];

int n;          // node count of the initial trie
int m;          // number of operations
int cnt = 1;    // last used slot of e[]; the first edge is stored at index 2
int h[N];       // head edge index of each node's adjacency list (0 = empty)
int pos[N];     // SAM state reached by the root-to-node string of each trie node
char s[N];      // shared scratch buffer for edge letters and query strings
long long sum;  // running total of val[x] - val[par[x]], i.e. the distinct-substring count
namespace lct 
{
    struct node {
        int ch[2];
        int f, tag, reg;
    } t[N << 1];
    bool isr(int x) {
        return !t[x].f || (t[t[x].f].ch[0] != x && t[t[x].f].ch[1] != x);
    }
    int wh(int x) {
        return x == t[t[x].f].ch[1];
    }
    void paint(int x, int d) {
        t[x].tag += d;
        t[x].reg += d;
    }
    void pushdown(int x) {
        if(t[x].tag) {
            paint(t[x].ch[0], t[x].tag);
            paint(t[x].ch[1], t[x].tag);
            t[x].tag = 0;
        }
    }
    void pd(int x) {
        if(!isr(x)) {
            pd(t[x].f);
        }        
        pushdown(x);
    }
    void rotate(int x) {
        int y = t[x].f, z = t[y].f, w = wh(x);
        if(!isr(y)) {
            t[z].ch[wh(y)] = x;
        }
        t[x].f = z;
        t[y].ch[w] = t[x].ch[w ^ 1];
        t[t[x].ch[w ^ 1]].f = y;
        t[y].f = x;
        t[x].ch[w ^ 1] = y;
    }
    void splay(int x) {
        pd(x);
        for(; !isr(x); rotate(x)) {
            if(!isr(t[x].f)) {
                rotate(wh(t[x].f) == wh(x) ? t[x].f : x); 
            }
        }
    }
    void access(int x) {
        for(int y = 0; x; y = x, x = t[x].f) {
            splay(x);
            t[x].ch[1] = y;
        }
    }
    void link(int u, int v) {
        access(u);
        splay(u);
        access(v);
        splay(v);
        paint(v, t[u].reg);
        t[u].f = v;
    }
    void cut(int u, int v) {
        access(u);
        splay(u);
        t[u].ch[0] = 0;
        t[v].f = 0;
        paint(v, -t[u].reg);
    }
}
namespace SAM 
{
    struct node {
        int ch[26];
        int par, val;
    } t[N << 1];
    int root = 1, sz = 1;
    int nw(int x) {
        t[++sz].val = x;
        return sz;
    }
    int extend(int last, int c) {
        int p = last, np = nw(t[p].val + 1);
        lct::t[np].reg = 1;
        while(p && !t[p].ch[c]) {
            t[p].ch[c] = np;
            p = t[p].par;
        }
        if(!p) {
            t[np].par = root;
            lct::link(np, root);
        } else {
            int q = t[p].ch[c];
            if(t[q].val == t[p].val + 1) {
                t[np].par = q;
                lct::link(np, q);
            } else {
                int nq = nw(t[p].val + 1);
                lct::link(nq, t[q].par); 
                lct::cut(q, t[q].par);
                lct::link(np, nq);
                lct::link(q, nq);
                t[nq].par = t[q].par;
                t[np].par = t[q].par = nq;
                memcpy(t[nq].ch, t[q].ch, sizeof(t[q].ch));
                while(p && t[p].ch[c] == q) {
                    t[p].ch[c] = nq;
                    p = t[p].par;
                }
            }
        }
        sum += t[np].val - t[t[np].par].val;
        return np;
    }
    void solve(char *s) {
        int len = strlen(s), now = root;
        for(int i = 0; i < len; ++i) {
            if(!t[now].ch[s[i] - 'a']) {
                puts("0");
                return;
            }
            now = t[now].ch[s[i] - 'a'];
        }
        lct::pd(now);
        printf("%d\n", lct::t[now].reg);
        lct::splay(now);
    }
}
// Adds a directed trie edge u -> v with letter w at the front of u's adjacency list.
void link(int u, int v, int w) {
    const int id = ++cnt;
    e[id].to = v;
    e[id].w = w;
    e[id].nxt = h[u];
    h[u] = id;
}
void dfs(int u, int last) {
    for(int i = h[u]; i; i = e[i].nxt) {
        if(e[i].to == last) {
            continue;
        }
        pos[e[i].to] = SAM::extend(pos[u], e[i].w);
        dfs(e[i].to, u);
    }
    h[u] = 0;
}
// Reads `count` trie edges "u v letter" from stdin and stores each in both directions.
// Returns false if the input ends early or is malformed.
static bool read_edges(int count) {
    for (int i = 0; i < count; ++i) {
        int u, v;
        if (scanf("%d%d%s", &u, &v, s) != 3) {
            return false;
        }
        link(u, v, s[0] - 'a');
        link(v, u, s[0] - 'a');
    }
    return true;
}

// Reads the initial trie, builds the SAM, then answers operations:
//   1          -> print the number of distinct substrings
//   2 rt size  -> read size-1 new edges hanging off rt and insert them
//   otherwise  -> read a string and print its occurrence count
// Stops cleanly on truncated or malformed input instead of using uninitialized values.
int main() {
    int unused;  // first input number; read and ignored by this solution
    if (scanf("%d%d", &unused, &n) != 2) {
        return 0;
    }
    if (!read_edges(n - 1)) {
        return 0;
    }
    pos[1] = 1;  // trie node 1 maps to the SAM root (state 1)
    dfs(1, 0);
    if (scanf("%d", &m) != 1) {
        return 0;
    }
    while (m--) {
        int opt;
        if (scanf("%d", &opt) != 1) {
            break;
        }
        if (opt == 1) {
            printf("%lld\n", sum);
        } else if (opt == 2) {
            int rt, nodes;
            // nodes <= 0 reads no edges (the old `while(--sz)` would loop forever)
            if (scanf("%d%d", &rt, &nodes) != 2 || !read_edges(nodes - 1)) {
                break;
            }
            dfs(rt, 0);
        } else {
            if (scanf("%s", s) != 1) {
                break;
            }
            SAM::solve(s);
        }
    }
    return 0;
}
View Code

 

posted @ 2018-03-01 09:11  19992147  阅读(206)  评论(0编辑  收藏  举报