# 【BZOJ4259】残缺的字符串-FFT

$\left(a-b\right)^{2}ab$

$\sum \left(A_{i}-B_{i}\right)^{2}A_{i}B_{i}$

#include <bits/stdc++.h>
using namespace std;
// Shorthand for a 64-bit signed integer.
typedef long long ll;

// Threshold used by main: an accumulated score with |score| < eps counts as zero.
const double eps = 1.0;
// Pi, obtained as the arc cosine of -1.
const double pi = acos(-1.0);

// m and n are the two integers main reads first; they bound the loops over A and B.
int m;
int n;
// Bit-reversal permutation table: filled by main, read by FFT.
int r[1200010];
// x is the power-of-two transform length chosen by main, bit counts its doublings.
int bit;
int x;

// Raw input strings ('*' marks a wildcard position).
char A[300010];
char B[300010];

// f[i][mode] / g[i][mode]: per-character weights derived from A / B for each of
// the three convolution passes (filled by calc_f / calc_g).
double f[300010][3];
double g[300010][3];
// Coefficient applied to each pass; only indices 0..2 are written by calc_val.
double val[300010];
// Running sum of the weighted convolution results, indexed by coefficient.
double ans[1200010] = {0};
// Minimal complex number: x holds the real part, y the imaginary part.
// a and b are the two global work buffers that solve() fills and transforms.
struct Complex
{
    double x, y;
} a[1200010], b[1200010];

// Component-wise sum of two complex numbers.
Complex operator + (Complex lhs, Complex rhs)
{
    Complex sum = {lhs.x + rhs.x, lhs.y + rhs.y};
    return sum;
}

// Component-wise difference of two complex numbers.
Complex operator - (Complex lhs, Complex rhs)
{
    Complex diff = {lhs.x - rhs.x, lhs.y - rhs.y};
    return diff;
}

// Complex product: (x1 + i*y1) * (x2 + i*y2).
Complex operator * (Complex lhs, Complex rhs)
{
    Complex prod = {lhs.x * rhs.x - lhs.y * rhs.y, lhs.x * rhs.y + lhs.y * rhs.x};
    return prod;
}

// In-place iterative radix-2 FFT over a[0..n-1].
//
//   a    - buffer to transform; overwritten with the result.
//   n    - transform length. The global table r must already hold the
//          bit-reversal permutation for this length (main builds it), which
//          presumes n is a power of two.
//   type - 1 for the forward transform, -1 for the inverse. The inverse also
//          divides every element by n.
void FFT(Complex *a,int n,int type)
{
    // Reorder the input into bit-reversed order; the i < r[i] test makes sure
    // each pair is swapped exactly once.
    for (int i = 0; i < n; i++)
        if (i < r[i]) swap(a[i], a[r[i]]);

    // Twiddle factors for the current stage. Each one is computed directly
    // from cos/sin rather than by repeatedly multiplying by a base root: the
    // running product accumulates rounding error over up to n/2 steps, which
    // matters because callers compare the results against a small threshold.
    std::vector<Complex> roots(n > 1 ? n / 2 : 0);

    for (int mid = 1; mid < n; mid <<= 1)
    {
        for (int k = 0; k < mid; k++)
        {
            const double angle = pi * k / mid;
            roots[k].x = cos(angle);
            roots[k].y = type * sin(angle);
        }

        // Butterfly every block of 2*mid elements.
        for (int l = 0; l < n; l += (mid << 1))
        {
            for (int k = 0; k < mid; k++)
            {
                Complex even = a[l + k];
                Complex odd = roots[k] * a[l + mid + k];
                a[l + k] = even + odd;
                a[l + mid + k] = even - odd;
            }
        }
    }

    if (type == -1)
    {
        // Normalise both components so the inverse is a true inverse, not
        // just correct in its real part.
        for (int i = 0; i < n; i++)
        {
            a[i].x /= (double)n;
            a[i].y /= (double)n;
        }
    }
}

// Fills f[i][0..2] for every position i < m of A with the character code
// raised to the 3rd, 2nd and 1st power respectively. A '*' in A contributes 0
// in all three modes.
void calc_f()
{
    for (int i = 0; i < m; ++i)
    {
        if (A[i] == '*')
        {
            f[i][0] = 0.0;
            f[i][1] = 0.0;
            f[i][2] = 0.0;
            continue;
        }
        const int c = A[i];
        f[i][0] = (double)(c * c * c);
        f[i][1] = (double)(c * c);
        f[i][2] = (double)(c);
    }
}

// Fills g[i][0..2] for every position i < n of B with the character code
// raised to the 1st, 2nd and 3rd power respectively (the mirror of calc_f).
// A '*' in B contributes 0 in all three modes.
void calc_g()
{
    for (int i = 0; i < n; ++i)
    {
        if (B[i] == '*')
        {
            g[i][0] = 0.0;
            g[i][1] = 0.0;
            g[i][2] = 0.0;
            continue;
        }
        const int c = B[i];
        g[i][0] = (double)(c);
        g[i][1] = (double)(c * c);
        g[i][2] = (double)(c * c * c);
    }
}

// Sets the coefficient of each convolution pass: val[0] = 1, val[1] = -2,
// val[2] = 1. These are the coefficients of the expansion
//   (a-b)^2 * a * b = a^3*b - 2*a^2*b^2 + a*b^3,
// matching the powers stored by calc_f / calc_g for modes 0, 1 and 2.
//
// Returns 0.0. The value carries no information; it exists only because the
// function is declared to return double, and flowing off the end of a
// value-returning function is undefined behaviour.
double calc_val()
{
    val[0] = val[2] = 1.0;
    val[1] = -2.0;
    return 0.0;
}

void solve(int mode)
{
memset(a,0,sizeof(a)),memset(b,0,sizeof(b));
for(int i=0;i<m;i++)
a[i].x=f[m-1-i][mode];
for(int i=0;i<n;i++)
b[i].x=g[i][mode];
FFT(a,x,1),FFT(b,x,1);
for(int i=0;i<x;i++)
a[i]=a[i]*b[i];
FFT(a,x,-1);
for(int i=0;i<x;i++)
ans[i]+=val[mode]*a[i].x;
}

// Reads m, n and the two strings A and B, then reports every 1-based start
// position i+1 (0 <= i <= n-m) whose accumulated score ans[i+m-1] is below eps
// in absolute value. Output: the number of such positions on the first line,
// followed by the positions, each followed by a single space.
int main()
{
    scanf("%d%d",&m,&n);
    // Field widths match the buffer sizes (300010 bytes including the NUL) so
    // an over-long token cannot write past the arrays.
    scanf("%300009s",A);
    scanf("%300009s",B);

    // Smallest power of two x >= n+m, with bit = log2(x).
    bit=0,x=1;
    while(x<n+m) x<<=1,bit++;

    // Bit-reversal table for indices 0..x-1. The loop must stop before x:
    // when x == 1 (bit == 0) an iteration would shift by bit-1 == -1, which
    // is undefined, and index x is never read by FFT anyway.
    r[0]=0;
    for(int i=1;i<x;i++)
        r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));

    calc_f(),calc_g(),calc_val();
    solve(0);
    solve(1);
    solve(2);

    // First pass counts the matches, second pass prints them, because the
    // count has to appear before the list.
    int cnt=0;
    for(int i=0;i<=n-m;i++)
        if (fabs(ans[i+m-1])<eps) cnt++;
    printf("%d\n",cnt);
    for(int i=0;i<=n-m;i++)
        if (fabs(ans[i+m-1])<eps) printf("%d ",i+1);

    return 0;
}
posted @ 2018-04-01 10:13  Maxwei_wzj  阅读(77)  评论(0编辑  收藏  举报