BZOJ4598 [Sdoi2016]模式字符串 【点分治 + hash】

题目

给出n个结点的树结构T,其中每一个结点上有一个字符,这里我们所说的字符只考虑大写字母A到Z,再给出长度为m
的模式串s,其中每一位仍然是A到z的大写字母。Alice希望知道,有多少对结点<u,v>满足T上从u到V的最短路径
形成的字符串可以由模式串S重复若干次得到?这里结点对<u,v>是有序的,也就是说<u,v>和<v,u>需要被区分.
所谓模式串的重复,是将若干个模式串S依次相接(不能重叠).例如当S=PLUS的时候,重复两次会得到PLUSPLUS,
重复三次会得到PLUSPLUSPLUS,同时要注恿,重复必须是整数次的。例如当S=XYXY时,因为必须重复整数次,所以X
YXYXY不能看作是S重复若干次得到的。

输入格式

每一个数据有多组测试,
第一行输入一个整数C,表示总的测试个数。
对于每一组测试来说:
第一行输入两个整数,分别表示树T的结点个数n与模式长度m。结点被依次编号为1到n,
之后一行,依次给出了n个大写字母(以一个长度为n的字符串的形式给出),依次对应树上每一个结点上的字符(
第i个字符对应了第i个结点).
之后n-1行,每行有两个整数u和v表示树上的一条无向边,之后一行给定一个长度为m的由大写字母组成的字符串,
为模式串S。
1<=C<=10,3<=N<=10000003<=M<=1000000

输出格式

给出C行,对应C组测试。每一行输出一个整数,表示有多少对节点<u,v>满足从u到v的路径形成的字符串恰好是模
式串的若干次重复.

输入样例

1

11 4

IODSSDSOIOI

1 2

2 3

3 4

1 5

5 6

6 7

3 8

8 9

6 10

10 11

SDOI

输出样例

5

提示

数据文件太过巨大,仅提供前三组数据测试.

题解

BZOJ数据较小,卡过了
但洛谷似乎T得不行

我们预处理出字符串前i个和后i个的hash值【这里\(i<=n\)处理的字符串由原字符串复制多次形成】
然后点分
对于每棵子树,进行遍历,记录当前到根的hash值,如果匹配上了前缀或者后缀,查找f[i]或者g[i]表示长度对m取模后为i的到根路径为原字符串前缀或后缀的路径数,更新答案

常熟略大,,弱弱卡过

#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#define LL long long int
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define ULL unsigned long long int
#define cls(s) memset(s,0,sizeof(s))
#define BUG(s,n) for (int i = 1; i <= (n); i++) cout<<s[i]<<' '; puts("");
using namespace std;
const int maxn = 1000005,maxm = 2000005,INF = 1000000000;
inline int read(){
    int out = 0,flag = 1; char c = getchar();
    while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
    while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
    return out * flag;
}
ULL Hl[maxn],Hr[maxn];
char s[maxn],val[maxn];
int n,m;
int h[maxn],ne = 2;
int F[maxn],Siz[maxn],fa[maxn],vis[maxn],sum,rt;
LL ans;
struct EDGE{int to,nxt;}ed[maxm];
void build(int u,int v){
    ed[ne] = (EDGE){v,h[u]}; h[u] = ne++;
    ed[ne] = (EDGE){u,h[v]}; h[v] = ne++;
}
void init(){
    for (int i = 1; i <= n; i++) vis[i] = h[i] = fa[i] = 0;
    ne = 2; ans = 0;
}
void getrt(int u){
    Siz[u] = 1; F[u] = 0;
    Redge(u) if (!vis[to = ed[k].to] && to != fa[u]){
        fa[to] = u; getrt(to);
        Siz[u] += Siz[to];
        F[u] = max(F[u],Siz[to]);
    }
    F[u] = max(F[u],sum - Siz[u]);
    if (F[u] < F[rt]) rt = u;
}
int pre[maxn],post[maxn],dep[maxn];
ULL V[maxn],P[maxn];
void DFS(int u){
    Siz[u] = 1;
    Redge(u) if (!vis[to = ed[k].to] && to != fa[u]){
        fa[to] = u; DFS(to);
        Siz[u] += Siz[to];
    }
}
void dfs1(int u){
    V[u] = V[fa[u]] * 107 + val[u];
    int d = (dep[u] - 1) % m + 1;
    if (V[u] == Hl[dep[u]] && s[d % m + 1] == val[rt]){
        //printf("find at %d\n",u);
        ans += post[((m - d - 1) % m + m) % m];
    }
    if (V[u] == Hr[dep[u]] && s[m - d % m] == val[rt]){
        //printf("rfind at %d\n",u);
        ans += pre[((m - d - 1) % m + m) % m];
    }
    Redge(u) if (!vis[to = ed[k].to] && to != fa[u]){
        fa[to] = u; dep[to] = dep[u] + 1;
        dfs1(to);
    }
}
void dfs2(int u){
    int d = dep[u] % m;
    if (V[u] == Hr[dep[u]]) post[d]++;
    if (V[u] == Hl[dep[u]]) pre[d]++;
    Redge(u) if (!vis[to = ed[k].to] && to != fa[u]){
        fa[to] = u; dep[to] = dep[u] + 1;
        dfs2(to);
    }
}
void solve(int u){
    vis[u] = true;
    fa[u] = 0; DFS(u);
    if (Siz[u] < m) return;
    for (int i = min(Siz[u],m); i >= 0; i--) pre[i] = post[i] = 0;
    pre[0] = post[0] = 1;
    V[u] = 0;
    Redge(u) if (!vis[to = ed[k].to]){
        dep[to] = 1; fa[to] = u; dfs1(to);
        dep[to] = 1; fa[to] = u; dfs2(to);
    }
    Redge(u) if (!vis[to = ed[k].to]){
        sum = Siz[to]; F[rt = 0] = INF;
        getrt(to); solve(rt);
    }
}
int main(){
    P[0] = 1;
    for (int i = 1; i <= 1000000; i++) P[i] = P[i - 1] * 107;
    int T = read();
    while (T--){
        init();
        n = read(); m = read();
        scanf("%s",s + 1);
        for (int i = 1; i <= n; i++) val[i] = s[i];
        for (int i = 1; i < n; i++) build(read(),read());
        scanf("%s",s + 1);
        for (int i = 1; i <= n; i++)
            Hl[i] = Hl[i - 1] + P[i - 1] * s[(i - 1) % m + 1];
        for (int i = 1; i <= n; i++)
            Hr[i] = Hr[i - 1] + P[i - 1] * s[m - (i - 1) % m];
        F[rt = 0] = INF; sum = n;
        getrt(1); solve(rt);
        printf("%lld\n",ans);
    }
    return 0;
}

posted @ 2018-03-06 13:17  Mychael  阅读(442)  评论(0编辑  收藏  举报