[LOJ2065] [SDOI2016]模式字符串

题目链接

洛谷:https://www.luogu.org/problemnew/show/P4075

LOJ:https://loj.ac/problem/2065

Solution

这种题看起来就很点分治啊...

我们可以发现,我们需要一个支持询问字符串相等,并且支持在一个串前面加一个串的数据结构,显然我们用哈希就行了。

那么我们直接开桶然后拿哈希维护,总复杂度\(O(Tn\log n)\)

#include<bits/stdc++.h>
using namespace std;

template<typename T> void read(T &x) {
    x=0;T ff=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') ff=-ff;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=ff;
}

void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

#define lf double
#define ll long long 

#define pii pair<int,int >
#define vec vector<int >

#define pb push_back
#define mp make_pair
#define fr first
#define sc second

#define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++) 

const int maxn = 1e6+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int bs = 29;
const int mod = 1e6+3;

ll ans;
char s[maxn];
int sz[maxn],rs[maxn],rp[maxn],rt,f[maxn],SZ,mxd,rrs[maxn],rrp[maxn];
int n,m,r[maxn],tag[maxn],head[maxn],tot,len,suf[maxn],pre[maxn],pw[maxn],vis[maxn];
struct edge{int to,nxt;}e[maxn<<1];

void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {add(u,v),add(v,u);}

void prepare() {
    for(len=m;len<n;len+=m)
        for(int i=len+1;i<=len+m;i++) tag[i]=tag[i-len];
    for(int i=1;i<=len;i++) pre[i]=(pre[i-1]*bs+tag[i])%mod;
    suf[len+1]=0;for(int i=len;i;i--) suf[i]=(suf[i+1]*bs+tag[i])%mod;
    reverse(suf+1,suf+len+1);
    pw[0]=1;for(int i=1;i<=len;i++) pw[i]=pw[i-1]*bs%mod;
}

void get_rt(int x,int fa) {
    sz[x]=1,f[x]=0;
    for(int i=head[x];i;i=e[i].nxt)
        if(e[i].to!=fa&&!vis[e[i].to]) get_rt(e[i].to,x),sz[x]+=sz[e[i].to],f[x]=max(f[x],sz[e[i].to]);
    f[x]=max(f[x],SZ-sz[x]);if(f[rt]>f[x]) rt=x;
}

void calc(int x,int fa,int dep,int hs,const int &c) {
    hs=(hs+1ll*r[x]*pw[dep])%mod;dep++;mxd=max(mxd,dep);
    if(hs==pre[dep]) {
        rrp[dep%m]++;
        if(c==tag[dep%m+1]) ans+=rs[m-dep%m-1];
    }
    if(hs==suf[dep]) {
        rrs[dep%m]++;
        if(c==tag[m-dep%m]) ans+=rp[m-dep%m-1];
    }
    for(int i=head[x];i;i=e[i].nxt) if(e[i].to!=fa&&!vis[e[i].to]) calc(e[i].to,x,dep,hs,c);
}

void dfs(int x) {
    f[rt=0]=maxn;get_rt(x,0);x=rt;vis[x]=1;mxd=0;
    rp[0]=rs[0]=1;int mmxd=0;
    for(int i=head[x];i;i=e[i].nxt)
        if(!vis[e[i].to]) {
            mxd=0;calc(e[i].to,x,0,0,r[x]);mmxd=max(mmxd,mxd);
            for(int j=0;j<=min(mxd,m);j++) rp[j]+=rrp[j],rs[j]+=rrs[j],rrp[j]=rrs[j]=0;
        }
    for(int i=0;i<=mmxd;i++) rs[i]=rp[i]=0;
    for(int i=head[x];i;i=e[i].nxt)
        if(!vis[e[i].to]) SZ=sz[e[i].to],dfs(e[i].to);
}

void solve() {
    read(n),read(m);scanf("%s",s+1);ans=0;
    for(int i=1;i<=n;i++) r[i]=s[i]-'A'+1;
    for(int x,y,i=1;i<n;i++) read(x),read(y),ins(x,y);
    scanf("%s",s+1);
    for(int i=1;i<=m;i++) tag[i]=s[i]-'A'+1;
    prepare();SZ=n;dfs(1);write(ans);
}

#define clr(x) memset(x,0,(n+3)*4)

void clear() {clr(head),clr(vis);tot=0;}

int main() {
    int st=clock();
    int t;read(t);while(t--) solve(),clear();
    cerr << (double)(clock()-st)/1e3 << endl;
    return 0;
}
posted @ 2019-05-07 21:08  Hyscere  阅读(222)  评论(0编辑  收藏  举报