BZOJ 4503: 两个串

Description

有通配符的字符串匹配.\(n,m\leqslant 10^5\)

Solution

FFT.

\(D_k=\sum_{i+j=k}(S_i-T_j)^2T_j\)

把他化成这样的式子,这样如果两个位置相等,或者\(T_j\)为\(0\),那么就可以匹配

把通配符设成\(0\)即可

Code

#include <bits/stdc++.h>
using namespace std;

#define mpr make_pair
#define x first
#define y second

const int N = 4e5+500;
const double Pi = M_PI;

namespace Pol {
	typedef pair<double,double> cp;
	cp operator + (const cp &a,const cp &b) { return mpr(a.x+b.x,a.y+b.y); }
	cp operator - (const cp &a,const cp &b) { return mpr(a.x-b.x,a.y-b.y); }
	cp operator * (const cp &a,const cp &b) { return mpr(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x); }
	
	int pn;
	void init(int n) { for(pn=1;pn<n;pn<<=1);pn<<=1; }
	void Rev(cp a[],int n=pn) {
		for(int i=0,j=0;i<n;i++) {
			if(i>j) swap(a[i],a[j]);
			for(int k=n>>1;(j^=k)<k;k>>=1);
		}
	}
	void DFT(cp a[],int r=1,int n=pn) {
		Rev(a);
		for(int i=2;i<=n;i<<=1) {
			cp wi=mpr(cos(2*Pi/i),r*sin(2*Pi/i));
			for(int j=0;j<n;j+=i) {
				cp w=mpr(1,0);
				for(int k=j;k<j+i/2;k++) {
					cp t1=a[k],t2=w*a[k+i/2];
					a[k]=t1+t2,a[k+i/2]=t1-t2;
					w=w*wi;
				}
			}
		}if(!~r) for(int i=0;i<n;i++) a[i].x/=n;
	}
	void FFT(cp a[],cp b[],cp c[],int n=pn) {
		DFT(a,1,n),DFT(b,1,n);
		for(int i=0;i<n;i++) c[i]=a[i]*b[i];
		DFT(c,-1,n);
	}
}

inline int in(int x=0,char ch=getchar()) { while(ch>'9'||ch<'0') ch=getchar();
	while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();return x; }

using namespace Pol;
int n,m;
char s[N],t[N];
double a[N],b[N];
cp s0[N],s1[N],s2[N],t1[N],t2[N],t3[N];
cp s2t[N],st2[N],s0t3[N];

int main() {
	scanf("%s%s",s,t);
	n=strlen(s),m=strlen(t);
	reverse(t,t+m);
	init(max(n,m));
	for(int i=0;i<n;i++) a[i]=s[i]-'a'+1;
	for(int i=0;i<m;i++) b[i]=t[i]=='?'?0:t[i]-'a'+1;
	
	for(int i=0;i<pn;i++) 
		s0[i].x=1,s1[i].x=a[i],s2[i].x=a[i]*a[i],
		t1[i].x=b[i],t2[i].x=b[i]*b[i],t3[i].x=b[i]*b[i]*b[i];
	
	DFT(s0,1),DFT(s1,1),DFT(s2,1),DFT(t1,1),DFT(t2,1),DFT(t3,1);
	for(int i=0;i<pn;i++) s2t[i]=s2[i]*t1[i],st2[i]=s1[i]*t2[i],s0t3[i]=s0[i]*t3[i];
	DFT(s2t,-1),DFT(st2,-1),DFT(s0t3,-1);
	
//	for(int i=0;i<pn;i++) cout<<(s2t[i].x-2*st2[i].x+s0t3[i].x)<<" ";cout<<endl;
	int ans=0;
	for(int i=m-1;i<n;i++) if((int)(s2t[i].x-2*st2[i].x+s0t3[i].x+0.5)==0) ans++;
	printf("%d\n",ans);
	for(int i=m-1;i<n;i++) if((int)(s2t[i].x-2*st2[i].x+s0t3[i].x+0.5)==0) printf("%d ",i-m+1);
	return 0;
}

  

posted @ 2017-04-17 14:45  北北北北屿  阅读(77)  评论(0编辑  收藏  举报