【fft】洛谷 P4173 残缺的字符串
这份代码O2优化才能过
#include <bits/stdc++.h>
using namespace std;
#define eps 1e-6
#define MS 300009
#define LL long long
// 洛谷 P3803 【模板】多项式乘法(FFT)
// MS = 1e6
// time: max( 800ms )
// ========================================================================================
typedef complex<double> comp;
const double PI = acos(-1);
const int N = (1<<20)+10; // 长度为原长度向上的2^n, 再乘 2
int lim, r[N];
comp a[N], b[N];
// 清空
void clear(){
memset(a,0,sizeof a);
memset(b,0,sizeof b);
}
// 如果n,m长度不变, 且多次求, 跑一遍 get_lim_r 就行
void get_lim_r(int n, int m){
int l = 0;
for(lim = 1; lim <= n + m; lim <<= 1) ++ l;
for(int i = 0; i < lim; i ++)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
}
//
void fft(comp * a, int type) {
for(int i = 0; i < lim; i ++)
if(i < r[i]) swap(a[i], a[r[i]]);
for(int i = 1; i < lim; i <<= 1) {
comp x(cos(PI / i), type * sin(PI / i));
for(int j = 0; j < lim; j += (i << 1)) {
comp y(1, 0);
for(int k = 0; k < i; k ++, y *= x) {
comp p = a[j + k], q = y * a[j + k + i];
a[j + k] = p + q; a[j + k + i] = p - q;
}
}
}
}
// 多项式 c1 [0,n]; 多项式 c2 [0,m]; 返回结果 c
void get(double *c1, int n, double *c2, int m, double *c){
// 如果n,m长度不变, 且多次求, 跑一遍 get_lim_r 就行
// get_lim_r(n, m);
for(int i = 0; i <= n; i ++) a[i] = c1[i];
for(int i = 0; i <= m; i ++) b[i] = c2[i];
fft(a, 1), fft(b, 1);
for(int i = 0; i <= lim; i ++) a[i] *= b[i];
fft(a, -1);
for(int i = 0; i <= n + m; i ++) c[i] = (LL)(0.5 + a[i].real() / lim);
}
// ========================================================================================
// P(x) = ∑S(i)^3*B(j) + ∑S(i)*B(j)^3 - 2*∑S(i)^2*B(j)^2
int n,m;
char s1[MS], s2[MS];
double p1[MS], p2[MS], p3[MS];
double A[MS], B[MS], S[MS];
double f[3][MS<<1];
int ac[MS], tot;
int main() {
// ios::sync_with_stdio(false);
scanf("%d %d",&n,&m);
scanf("%s",s1); getchar();
scanf("%s",s2); getchar();
for(int i=0;i<n;i++) p1[i] = s1[i]=='*'? 0:s1[i]-'a'+1;
for(int i=0;i<m;i++) p2[i] = s2[i]=='*'? 0:s2[i]-'a'+1;
for(int i=0;i<n;i++) p3[i] = p1[n-i-1];
get_lim_r(n-1, m-1);
for(int i=0;i<n;i++) S[i] = p3[i]*p3[i]*p3[i];
for(int i=0;i<m;i++) B[i] = p2[i];
clear(); get(S, n-1, B, m-1, f[0]);
for(int i=0;i<n;i++) S[i] = p3[i];
for(int i=0;i<m;i++) B[i] = p2[i]*p2[i]*p2[i];
clear(); get(S, n-1, B, m-1, f[1]);
for(int i=0;i<n;i++) S[i] = p3[i]*p3[i];
for(int i=0;i<m;i++) B[i] = p2[i]*p2[i];
clear(); get(S, n-1, B, m-1, f[2]);
for(int i=n-1;i<m;i++){
double P = f[0][i] + f[1][i] - f[2][i]*2;
if(abs(P) < eps) ac[++tot] = i-n+2;
}
printf("%d\n",tot);
for(int i=1;i<=tot;i++) printf("%d ",ac[i]);
return 0;
}

浙公网安备 33010602011771号