【BZOJ 4503】4503: 两个串 (FFT)

4503: 两个串

Time Limit: 20 Sec  Memory Limit: 256 MB
Submit: 497  Solved: 226

Description

兔子们在玩两个串的游戏。给定两个字符串S和T,兔子们想知道T在S中出现了几次,
分别在哪些位置出现。注意T中可能有“?”字符,这个字符可以匹配任何字符。

Input

两行两个字符串,分别代表S和T

Output

第一行一个正整数k,表示T在S中出现了几次
接下来k行正整数,分别代表T每次在S中出现的开始位置。按照从小到大的顺序输出,S下标从0开始。

Sample Input

bbabaababaaaaabaaaaaaaabaaabbbabaaabbabaabbbbabbbbbbabbaabbbababababbbbbbaaabaaabbbbbaabbbaabbbbabab
a?aba?abba

Sample Output

0

HINT

S 长度不超过 10^5, T 长度不会超过 S。 S 中只包含小写字母, T中只包含小写字母和“?”


Source

 

 

【分析】

  这题做了我好久啊。。幸好1A,哭死了。。

  首先,是不可以用KMP的,【我一开始还觉得可以】

  因为你根据第二个串求next可能会以'?'为媒介,认为'a'与'b'相同什么的。

  小山羊说,大多数含'?'的字符串的题目都不可以用KMP,EXKMP,AC自动机,后缀数组等等字符串匹配方法,就会出现上面说的问题。不过我觉得如果这题的'?'是在第一个串上的话,应该是可以用的。

  所以怎么办呢?FFT大法。

  很容易想到一个方法就是类似‘万径人踪灭’那里的。做26次FFT,每次把一个字母标为1,其他标为0,看一看匹配长度是否为m,但是会超时?【并没有实测过

  不过这题只是问你能否完全匹配,并没有问你匹配多少个字符之类的。

  所以可以构造这么一个函数。设$a[i]=s[i]=='?'?0:a[i]-'a'+1$。

  则$ans[i+j]=(a[i]-b[j])^{2}*b[j]$

  当且仅当全部匹配,ans为0,后面乘a[j]是搞'?'的。

  这个也很显然吧。

  但是!!!注意这个式子啊,不能把a,b化成点值表示就直接乘,否则意义就会变成

  $ans[i+i+j]+=a[i]^{2}*a[j]; ans[j+j+j]+=a[j]^{3}; ans[i+j+j]+=-2*a[i]*a[j]^{2}$意义完全不对。

  我还是理解了好久才发现这个错啊,以后不能再错这个了!!!

  所以你要分几次搞,令$A[i]=a[i]^2$,$B[i]=b[i]^2$,$C[i]=1$,$D[i]=b[i]^{3}$

  则$ans=A*b+C*D-2*a*B$

  FFT加速即可【为什么我跑得那么慢并且代码那么丑?

 

 

 1 #include<cstdio>
 2 #include<cstdlib>
 3 #include<cstring>
 4 #include<iostream>
 5 #include<algorithm>
 6 #include<cmath>
 7 using namespace std;
 8 #define Maxn 100010*4
 9 const double pi=acos(-1);
10 
11 char s1[Maxn],s2[Maxn];
12 int aa[Maxn],bb[Maxn];
13 
14 struct P
15 {
16     double x,y;
17     P() {x=y=0;}
18     P(double x,double y):x(x),y(y){}
19     friend P operator + (P x,P y) {return P(x.x+y.x,x.y+y.y);}
20     friend P operator - (P x,P y) {return P(x.x-y.x,x.y-y.y);}
21     friend P operator * (P x,P y) {return P(x.x*y.x-x.y*y.y,x.x*y.y+x.y*y.x);}
22     friend P operator * (P x,int y) {return P(x.x*y,x.y*y);}
23 }a[Maxn],b[Maxn];
24 
25 int nn,R[Maxn],op[Maxn],ans[Maxn];
26 void fft(P *s,int f)
27 {
28     for(int i=0;i<nn;i++) if(i<R[i]) swap(s[i],s[R[i]]);
29     for(int i=1;i<nn;i<<=1)
30     {
31         P wn(cos(pi/i),f*sin(pi/i));
32         for(int j=0;j<nn;j+=i<<1)
33         {
34             P w(1,0);
35             for(int k=0;k<i;k++,w=w*wn)
36             {
37                 P x=s[j+k],y=w*s[j+k+i];
38                 s[j+k]=x+y;s[j+k+i]=x-y;
39             }
40         }
41     }
42     if(f==-1)
43     {
44         for(int i=0;i<=nn;i++) s[i].x=s[i].x/nn;
45     }
46 }
47 
48 int main()
49 {
50     scanf("%s%s",s1,s2);
51     int n=strlen(s1),m=strlen(s2);
52     n--;m--;
53     for(int i=0;i<=n;i++) aa[i]=s1[i]=='?'?0:(s1[i]-'a'+1);
54     for(int i=0;i<=m;i++) bb[m-i]=s2[i]=='?'?0:(s2[i]-'a'+1);
55     int ll=0;nn=1;
56     while(nn<=n+m) ll++,nn<<=1;
57     for(int i=0;i<nn;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(ll-1));
58     
59     for(int i=0;i<=n;i++) a[i].x=aa[i]*aa[i];
60     for(int i=0;i<=m;i++) b[i].x=bb[i];
61     fft(a,1);fft(b,1);
62     for(int i=0;i<=nn;i++) a[i]=a[i]*b[i];
63     fft(a,-1);
64     memset(ans,0,sizeof(ans));
65     for(int i=0;i<=n+m;i++) ans[i]+=(int)(a[i].x+0.5);
66     // for(int i=0;i<=n+m;i++) printf("%d ",ans[i]);
67     
68     for(int i=0;i<=nn;i++) a[i].x=a[i].y=b[i].x=b[i].y=0;
69     for(int i=0;i<=n;i++) a[i].x=1;
70     for(int i=0;i<=m;i++) b[i].x=bb[i]*bb[i]*bb[i];
71     fft(a,1);fft(b,1);
72     for(int i=0;i<=nn;i++) a[i]=a[i]*b[i];
73     fft(a,-1);
74     for(int i=0;i<=n+m;i++) ans[i]+=(int)(a[i].x+0.5);
75     // for(int i=0;i<=n+m;i++) printf("%d ",ans[i]);
76     
77     for(int i=0;i<=nn;i++) a[i].x=a[i].y=b[i].x=b[i].y=0;
78     for(int i=0;i<=n;i++) a[i].x=2*aa[i];
79     for(int i=0;i<=m;i++) b[i].x=bb[i]*bb[i];
80     fft(a,1);fft(b,1);
81     for(int i=0;i<=nn;i++) a[i]=a[i]*b[i];
82     fft(a,-1);
83     for(int i=0;i<=n+m;i++) ans[i]-=(int)(a[i].x+0.5);
84     // for(int i=0;i<=n+m;i++) printf("%d ",ans[i]);printf("\n");
85     
86     
87     op[0]=0;
88     for(int i=m;i<=n;i++) if(!ans[i])
89     {
90         op[++op[0]]=i-m;
91     }
92     printf("%d\n",op[0]);
93     for(int i=1;i<=op[0];i++) printf("%d\n",op[i]);
94     return 0;
95 }
View Code

 

 

2017-04-14 2017-04-14

posted @ 2017-04-14 09:46  konjak魔芋  阅读(244)  评论(0编辑  收藏  举报