@bzoj - 1921@ [ctsc2010]珠宝商


@description@

简述版题意:给定字符串 S 与一棵树 T,树上每个点有一个字符。求树上所有简单路径对应的字符串在 S 中的出现次数之和。

原题链接。

@solution@

一个显然的暴力:O(N^2) 枚举所有点对的字符串,建后缀自动机跑算出现次数之和。

另一个看起来比较好的算法:点分治。
每次计算经过重心的字符串,可以拆成两部分:某个点到重心的字符串 + 重心到某个点的字符串。
不过合并的时候需要在原字符串 S 上找分界点(不然无法合并),该算法执行一次复杂度为 O(M)。

接下来?发现并没有什么更好的性质可以利用。
数据范围比较小,我们不妨考虑一些玄学的操作:平衡复杂度。

在点分治时,如果当前连通块大小 < \(\sqrt{M}\) 则执行第一种暴力,否则执行第二种暴力。这样平衡下来复杂度为 \(O(N\sqrt{M})\)
看起来比较显然:连通块大小 < \(\sqrt{M}\) 时 O(size^2) 优于 O(M);否则 O(M) 优于 O(size^2)。
至于复杂度的正确性,第一种暴力的总和显然 \(O(N\sqrt{M})\)。第二种暴力由于决策树的叶子个数 <= \(O(\sqrt{M})\),而深度为 logN(点分治),所以也是 \(O(N\sqrt{M})\)

不过需要注意点分治时,容斥减去同一子树的贡献也需要根据子树大小分类讨论。

提一点细节:我们找某个点到重心的字符串是往前加字符,并以该字符串为后缀,更新 S 的前缀。也就是说我们不能跑 DAG,需要直接在 parent 树上跑。
其实也不是很难处理,不过状态需要存成两部分:所在结点与现长度。
转移时分两种情况考虑,一个是长度依然小于等于所在结点表示字符串的最大长度,直接在原字符串 S 上看加入这个字符是否仍然合法;另一种,我们需要处理出每个结点的最长字符串前面加入某个字符会转移到的点(可以根据儿子找父亲)。

@accepted code@

#include <cmath>
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long ll;

const int MAXN = 100000;

#define mp make_pair
#define fi first
#define se second

struct SAM{
    char str[MAXN + 5]; int n;
    struct node{
        int len, pos, cnt, tag;
        node *sn[26], *ch[26], *fa;
    }pl[MAXN + 5], *ncnt, *lst, *root;
    SAM() {ncnt = lst = root = pl;}
    node *extend(int x, int ps) {
        node *nw = (++ncnt), *p = lst; lst = nw;
        nw->len = p->len + 1, nw->pos = ps, nw->cnt = 1;
        while( p && p->ch[x] == NULL )
            p->ch[x] = nw, p = p->fa;
        if( !p ) nw->fa = root;
        else {
            node *q = p->ch[x];
            if( p->len + 1 == q->len )
                nw->fa = q;
            else {
                node *nq = (++ncnt); (*nq) = (*q);
                nq->len = p->len + 1, nq->cnt = 0;
                nw->fa = q->fa = nq;
                while( p && p->ch[x] == q )
                    p->ch[x] = nq, p = p->fa;
            }
        }
        return nw;
    }
    node *nd[MAXN + 5];
    int a[MAXN + 5], b[MAXN + 5];
    void build(int _n) {
        n = _n;
        for(int i=0;i<n;i++) nd[i] = extend(str[i] - 'a', i);
        for(int i=ncnt-pl;i>=1;i--) {
            node *p = &pl[i];
            p->fa->sn[str[p->pos - p->fa->len] - 'a'] = p;
        }
        for(int i=1;i<=ncnt-pl;i++) b[pl[i].len]++;
        for(int i=1;i<=n;i++) b[i] += b[i-1];
        for(int i=1;i<=ncnt-pl;i++) a[b[pl[i].len]--] = i;
        for(int i=ncnt-pl;i>=1;i--) pl[a[i]].fa->cnt += pl[a[i]].cnt;
    }
    int f[MAXN + 5];
    void clear() {
        for(int i=0;i<n;i++) f[i] = 0;
        for(int i=0;i<=ncnt-pl;i++) pl[i].tag = 0;
    }
    void get() {
        for(int i=1;i<=ncnt-pl;i++) pl[a[i]].tag += pl[a[i]].fa->tag;
        for(int i=0;i<n;i++) f[i] = nd[i]->tag;
    }
    pair<node*, int>trans(pair<node*, int>x, int ch) {
        if( x.fi == NULL ) return x;
        x.se++;
        if( x.se > x.fi->len ) {
            x.fi = x.fi->sn[ch];
            return x;
        }
        else {
            if( str[x.fi->pos - x.se + 1] - 'a' != ch )
                x.fi = NULL;
            return x;
        }
    }
    void update(node *k) {k->tag++;}
}S1, S2;

struct edge{
    int to; edge *nxt;
}edges[MAXN + 5], *adj[MAXN + 5], *ecnt = edges;
void addedge(int u, int v) {
    edge *p = (++ecnt);
    p->to = v, p->nxt = adj[u], adj[u] = p;
    p = (++ecnt);
    p->to = u, p->nxt = adj[v], adj[v] = p;
}
#define rep(x) for(edge *p=adj[x];p;p=p->nxt)

int N, M, SQ; ll ans;
char s[MAXN + 5];

bool vis[MAXN + 5]; int siz[MAXN + 5];
int get_size(int x, int fa) {
    siz[x] = 1;
    rep(x) {
        if( vis[p->to] || p->to == fa ) continue;
        siz[x] += get_size(p->to, x);
    }
    return siz[x];
}
int hvy[MAXN + 5];
int get_G(int x, int fa, int tot) {
    int ret = -1; hvy[x] = tot - siz[x];
    rep(x) {
        if( vis[p->to] || p->to == fa ) continue;
        int t = get_G(p->to, x, tot);
        hvy[x] = max(hvy[x], siz[p->to]);
        if( ret == -1 || hvy[t] < hvy[ret] ) ret = t;
    }
    if( ret == -1 || hvy[x] < hvy[ret] ) ret = x;
    return ret;
}
void dfs2(int x, int f, SAM::node *nw, int type) {
    nw = nw->ch[s[x] - 'a'];
    if( !nw ) return ;
    ans += nw->cnt * type;
    rep(x) {
        if( vis[p->to] || p->to == f ) continue;
        dfs2(p->to, x, nw, type);
    }
}
void dfs1(int x, int f) {
    dfs2(x, -1, S1.root, 1);
    rep(x) {
        if( vis[p->to] || p->to == f ) continue;
        dfs1(p->to, x);
    }
}

typedef pair<SAM::node*, int> pr;

void dfs3(int x, int f, pr nw) {
    nw = S1.trans(nw, s[x] - 'a');
    if( nw.fi == NULL ) return ;
    S1.update(nw.fi);
    rep(x) {
        if( vis[p->to] || p->to == f ) continue;
        dfs3(p->to, x, nw);
    }
}
void dfs4(int x, int f, pr nw) {
    nw = S2.trans(nw, s[x] - 'a');
    if( nw.fi == NULL ) return ;
    S2.update(nw.fi);
    rep(x) {
        if( vis[p->to] || p->to == f ) continue;
        dfs4(p->to, x, nw);
    }
}

SAM::node *a[MAXN + 5]; int cnt;
void dfs5(int x, int f, pr nw) {
    nw = S1.trans(nw, s[x] - 'a');
    if( nw.fi == NULL ) return ;
    a[++cnt] = nw.fi;
    rep(x) {
        if( vis[p->to] || p->to == f ) continue;
        dfs5(p->to, x, nw);
    }
}
void divide(int x, int n) {
    if( n <= SQ ) dfs1(x, -1);
    else {
        vis[x] = true;
        S1.clear(), dfs3(x, -1, mp(S1.root, 0)), S1.get();
        S2.clear(), dfs4(x, -1, mp(S2.root, 0)), S2.get();
        for(int i=0;i<M;i++) ans += 1LL*S1.f[i]*S2.f[M-1-i];
        rep(x) {
            if( vis[p->to] ) continue;
            int k = get_size(p->to, -1);
            if( k <= SQ ) {
                cnt = 0, dfs5(p->to, x, S1.trans(mp(S1.root, 0), s[x] - 'a'));
//              int res = ans;
                for(int i=1;i<=cnt;i++) dfs2(p->to, -1, a[i], -1);
//              int del = 0;
//              S1.clear(), dfs3(p->to, -1, S1.trans(mp(S1.root, 0), s[x] - 'a')), S1.get();
//              S2.clear(), dfs4(p->to, -1, S2.trans(mp(S2.root, 0), s[x] - 'a')), S2.get();
//              for(int i=0;i<M;i++) del += 1LL*S1.f[i]*S2.f[M-1-i];
//              printf("%d %d\n", res - ans, del);
            }
            else {
                S1.clear(), dfs3(p->to, -1, S1.trans(mp(S1.root, 0), s[x] - 'a')), S1.get();
                S2.clear(), dfs4(p->to, -1, S2.trans(mp(S2.root, 0), s[x] - 'a')), S2.get();
                for(int i=0;i<M;i++) ans -= 1LL*S1.f[i]*S2.f[M-1-i];
            }
            divide(get_G(p->to, -1, k), k);
        }
    }
}

int main() {
    scanf("%d%d", &N, &M), SQ = 8*(int)sqrt(M);
    for(int i=1;i<N;i++) {
        int u, v; scanf("%d%d", &u, &v);
        addedge(u - 1, v - 1);
    }
    scanf("%s%s", s, S1.str);
    for(int i=0;i<M;i++) S2.str[i] = S1.str[M-i-1];
    S1.build(M), S2.build(M), divide(get_G(0, -1, get_size(0, -1)), N);
    printf("%lld\n", ans);
}

@details@

可以把分界点的大小适当调大(显然后一种算法常数更大)。

一开始 T 了还以为是常数问题,结果仔细一看发现我点分治只有第一轮找了重心,后面没有找重心就直接递归了。。。

posted @ 2020-01-21 09:32  Tiw_Air_OAO  阅读(...)  评论(...编辑  收藏