【模板】快速数论变换
依旧是常数很大的板子。
在 H_Kaguya 改动之前,达到了 \(2.61\,s\) 的绝望时间。
现在好多了,\(1.10\,s\)。(内存不连续访问我会记你一辈子的)
#include <iostream>
// Most recent character consumed by get_single(). Kept at file scope so that
// any other code inspecting `ch` keeps working.
char ch;

// Reads the next decimal digit from stdin and returns its value (0..9).
//
// Every byte that is not '0'..'9' is skipped. If the stream ends before a
// digit is found, 0 is returned. The previous version stored getchar()'s
// result directly in a `char`, so on signed-char platforms EOF (-1) always
// compared below '0' and the skip loop never terminated on truncated input.
short get_single() {
    int c = getchar();  // int, so EOF stays distinguishable from real bytes
    while (c != EOF && (c < '0' || c > '9'))
        c = getchar();
    ch = static_cast<char>(c);
    if (c == EOF)
        return 0;
    // '0'..'9' are 0x30..0x39, so the low nibble is the digit value.
    return static_cast<short>(c & 15);
}
// Prime modulus used by every transform in this program.
const int mod = 998244353;
// Base from which the forward-transform roots are derived.
const int g = 3;
// Modular inverse of g (3 * 332748118 == mod + 1).
const int g_inv = 332748118;
// Capacity of the global work arrays.
const int N = 2100010;

// Binary exponentiation: returns _a raised to _n, reduced modulo _p.
//   _a  base; it is not reduced before the first multiplication
//   _n  exponent; the loop runs until it has been shifted down to zero
//   _p  modulus, defaults to `mod`
long long quick_pow(long long _a,int _n,int _p = mod) {
    long long acc = 1;
    // Walk the exponent from its lowest bit upwards, squaring the base
    // after each bit has been consumed.
    for (; _n != 0; _n >>= 1, _a = _a * _a % _p) {
        if ((_n & 1) != 0)
            acc = acc * _a % _p;
    }
    return acc;
}
// rev[i]: permutation index applied by NTT() before its butterfly passes.
int rev[N];
// Work buffers holding the two input sequences; a[] also receives the result.
int a[N], b[N];
// Current transform length.
int lim;

// Fills rev[1..lim) from the recurrence rev[i] = rev[i/2] / 2, with the top
// bit (lim/2) set when i is odd. With lim a power of two and rev[0] == 0
// (rev is a zero-initialised global) this is the bit-reversal of i.
void ntt_init() {
    const int top_bit = lim >> 1;
    for (int idx = 1; idx < lim; ++idx) {
        const int shifted = rev[idx >> 1] >> 1;
        rev[idx] = (idx & 1) ? (shifted | top_bit) : shifted;
    }
}
// In-place iterative transform of f[0..lim) modulo `mod`.
//
//   f    buffer of at least `lim` entries; values are presumably already in
//        [0, mod) - the add/subtract normalisation below relies on that
//   opt  -1 selects the inverse transform (roots taken from g_inv, result
//        scaled by lim^-1); any other value selects the forward transform
//
// Reads the global `lim` and the permutation table rev[] filled by
// ntt_init(), so ntt_init() must have been run for the current `lim`.
//
// The `register` storage-class specifier used previously was removed from the
// language in C++17 (clang rejects it outright), so it has been dropped.
void NTT(int *f,int opt) {
    const bool inverse = (opt == -1);  // same condition as the old `!~opt`

    // Apply the rev[] permutation; each pair is swapped exactly once.
    for (int i = 0; i < lim; ++i)
        if (i < rev[i])
            std :: swap(f[i],f[rev[i]]);

    // Butterfly passes: `half` is the distance between paired elements and
    // `step` (= 2 * half) is the size of one butterfly group.
    for (int half = 1, step = 2; half < lim; half <<= 1, step <<= 1) {
        const long long w_n = quick_pow(inverse ? g_inv : g, (mod - 1) / step);
        for (int base = 0; base < lim; base += step) {
            long long w = 1;
            for (int k = base, l = base + half; k < base + half; ++k, ++l, w = w * w_n % mod) {
                const int y = static_cast<int>(w * f[l] % mod);
                f[l] = f[k] - y;
                if (f[l] < 0)
                    f[l] += mod;
                f[k] += y;
                if (f[k] >= mod)
                    f[k] -= mod;
            }
        }
    }

    // The inverse transform additionally divides every entry by lim.
    if (inverse) {
        const long long inv_lim = quick_pow(lim, mod - 2);
        for (int i = 0; i < lim; ++i)
            f[i] = static_cast<int>(f[i] * inv_lim % mod);
    }
}
int n, m;
int main() {
scanf("%d %d",&n,&m);
for(int i = 0;i <= n;++i)
a[i] = get_single();
for(int i = 0;i <= m;++i)
b[i] = get_single();
n += m;
lim = 1<<31-__builtin_clz(n);
if(lim&n)
lim <<= 1;
ntt_init();
NTT(a,1);
NTT(b,1);
for(int i = 0;i < lim;++i)
a[i] = 1ll*a[i]*b[i]%mod;
NTT(a,-1);
for(int i = 0;i <= n;++i)
printf("%d ",a[i]);
return 0;
}
虽然中间采用了一些不通用的写法,但是在洛谷上已经可以稳定跑到 \(1.0\,s\) 了。
#include <iostream>
// Last character taken from stdin by get_single(); left at file scope so
// existing readers of `ch` are unaffected.
char ch;

// Returns the value (0..9) of the next decimal digit on stdin.
//
// Bytes outside '0'..'9' are skipped. When stdin is exhausted before any
// digit appears the function returns 0. Reading into an `int` is what makes
// this possible: the earlier code kept getchar()'s result in a `char`, where
// EOF became a value below '0' on signed-char platforms and the loop spun
// forever.
short get_single() {
    int cur = getchar();
    for (;;) {
        if (cur == EOF || (cur >= '0' && cur <= '9'))
            break;
        cur = getchar();
    }
    ch = static_cast<char>(cur);
    // Low four bits of an ASCII digit are its numeric value.
    return cur == EOF ? static_cast<short>(0) : static_cast<short>(cur & 15);
}
// Prime modulus for all arithmetic in this program.
const int mod = 998244353;
// Capacity of the global work arrays.
const int N = 2100011;

// Binary exponentiation on ints: returns _a to the power _n modulo _p
// (_p defaults to `mod`). Products are widened to 64 bits before reduction.
int quick_pow(int _a,int _n,int _p = mod) {
    int result = 1;
    for (; _n != 0; _n >>= 1) {
        if ((_n & 1) != 0)
            result = static_cast<long long>(result) * _a % _p;
        _a = static_cast<long long>(_a) * _a % _p;
    }
    return result;
}
// Permutation table consumed by NTT() / INTT().
int rev[N];
// Input sequences; a[] is overwritten with the product.
int a[N], b[N];
// Transform length currently in use.
int lim;

// Builds rev[1..lim): each entry is the entry for i/2 shifted right by one,
// with lim/2 OR-ed in for odd i. rev[0] is never written and stays at its
// zero-initialised value.
void ntt_init() {
    const int msb = lim >> 1;
    int idx = 1;
    while (idx < lim) {
        int value = rev[idx >> 1] >> 1;
        if ((idx & 1) != 0)
            value |= msb;
        rev[idx] = value;
        ++idx;
    }
}
// Inverse transform of the global a[0..lim), in place.
// Uses the rev[] table for the initial permutation, derives each pass's root
// from 332748118 (the inverse of 3 modulo `mod`), and finally multiplies
// every entry by lim^-1.
void INTT() {
    for (int pos = 0; pos < lim; ++pos) {
        const int mirror = rev[pos];
        if (pos < mirror)
            std :: swap(a[pos],a[mirror]);
    }

    // `exponent` starts at (mod-1)/2 and is halved once per pass.
    int exponent = (mod - 1) >> 1;
    for (int half = 1; half < lim; half <<= 1, exponent >>= 1) {
        const int span = half << 1;
        const long long root = quick_pow(332748118, exponent);
        for (int base = 0; base < lim; base += span) {
            long long twiddle = 1;
            const int stop = base + half;
            for (int lo = base, hi = stop; lo < stop; ++lo, ++hi) {
                const int t = static_cast<int>(twiddle * a[hi] % mod);
                const int u = a[lo];
                const int diff = u - t;
                a[hi] = diff < 0 ? diff + mod : diff;
                const int sum = u + t;
                a[lo] = sum >= mod ? sum - mod : sum;
                twiddle = twiddle * root % mod;
            }
        }
    }

    const long long inv_len = quick_pow(lim, mod - 2);
    for (int pos = 0; pos < lim; ++pos)
        a[pos] = a[pos] * inv_len % mod;
}
void NTT() {
for(int i = 0;i < lim;++i)
if(i < rev[i]) {
std :: swap(a[i],a[rev[i]]);
std :: swap(b[i],b[rev[i]]);
}
long long w_n, w;
int i, j, k, step, upd, l, y, pc;
for(i = 1, step = 2, pc = mod-1>>1;i < lim;pc >>= 1, i <<= 1, step <<= 1) {
w_n = quick_pow(3,pc);
for(j = 0, upd = i;j < lim;j += step, upd += step) {
for(w = 1, k = j, l = i+j;k < upd;++k, ++l, w = w*w_n%mod) {
y = w*a[l]%mod;
a[l] = a[k]-y;
if(a[l] < 0)
a[l] += mod;
a[k] += y;
if(a[k] >= mod)
a[k] -= mod;
y = w*b[l]%mod;
b[l] = b[k]-y;
if(b[l] < 0)
b[l] += mod;
b[k] += y;
if(b[k] >= mod)
b[k] -= mod;
}
}
}
}
int n, m;
int main() {
#ifndef ONLINE_JUDGE
freopen("P3803_8.in","r",stdin);
freopen("test.out","w",stdout);
#endif
scanf("%d %d",&n,&m);
for(int i = 0;i <= n;++i)
a[i] = get_single();
for(int i = 0;i <= m;++i)
b[i] = get_single();
n += m;
lim = 1<<31-__builtin_clz(n);
if(lim&n)
lim <<= 1;
ntt_init();
NTT();
for(int i = 0;i < lim;++i)
a[i] = (long long)a[i]*b[i]%mod;
INTT();
for(int i = 0;i <= n;++i)
printf("%d ",a[i]);
return 0;
}
怎么做到干过 FFT 的?
“你是怎么做到的?”
“请帮帮我!”
#include <iostream>
#include <algorithm>
// Stores the value (0..9) of the next decimal digit on stdin into x.
//
// Bytes that are not '0'..'9' are skipped. If stdin ends before a digit is
// found, x is set to 0.
//
// The previous version passed a plain `char` to isdigit(), which is undefined
// for negative values (EOF squeezed into a char, or any byte >= 0x80 on
// signed-char platforms) and therefore never terminated on truncated input.
// It also used isdigit() without including <cctype>. The explicit range
// check below needs neither.
void get_single(int &x) {
    int c = getchar();  // int, so EOF stays distinguishable from real bytes
    while (c != EOF && (c < '0' || c > '9'))
        c = getchar();
    // For '0'..'9' (0x30..0x39) the low nibble is the digit value.
    x = (c == EOF) ? 0 : (c & 15);
}
// Prime modulus for all arithmetic in this program.
constexpr int mod = 998244353;
// Capacity of the global arrays below.
constexpr int N = 2100010;

// n, m : sizes used by init() / mul() / main()
// len  : transform length chosen by mul()
// pwr  : exponent computed by init() that bounds the filled part of w[]
int n, m, len, pwr;
// a, b : the two sequences being multiplied (a receives the result)
// w    : twiddle-factor table filled by init()
int a[N], b[N], w[N];

// Binary exponentiation: returns a^n modulo p (p defaults to `mod`).
// The parameters intentionally shadow the globals of the same name.
int quick_pow(int a,int n,int p = mod) {
    int acc = 1;
    for (; n != 0; n >>= 1) {
        if ((n & 1) != 0)
            acc = static_cast<long long>(acc) * a % p;
        a = static_cast<long long>(a) * a % p;
    }
    return acc;
}
// Fills the twiddle table w[0 .. 2^pwr].
//
// pwr is first raised until 2^pwr >= n + m, then set to min(pwr - 1, 21).
// w[2^pwr] is seeded with 31^(2^(21 - pwr)); every lower power-of-two slot is
// the square of the slot above it, and each remaining index is the product of
// the entry for the index with its lowest set bit cleared and the entry for
// that lowest bit alone.
// NOTE(review): 31 is presumably a 2^23-th root of unity modulo 998244353 -
// confirm before changing the constant.
void init() {
    for (; (1 << pwr) < n + m; ++pwr) {
    }
    pwr = std :: min(pwr - 1, 21);

    w[0] = 1;
    w[1 << pwr] = quick_pow(31, 1 << (21 - pwr));

    // Power-of-two slots, from the top one down to w[1].
    for (int bit = pwr; bit != 0; --bit) {
        const int upper = w[1 << bit];
        w[1 << (bit - 1)] = 1ll * upper * upper % mod;
    }

    // Remaining slots from two entries that are already known.
    const int table_size = 1 << pwr;
    for (int idx = 1; idx < table_size; ++idx) {
        const int low_bit = idx & -idx;
        const int rest = idx & (idx - 1);
        w[idx] = 1ll * w[rest] * w[low_bit] % mod;
    }
}
void NTT(int *ary,int len) {
for(int mid = len>>1;mid;mid >>= 1)
for(int i = 0, k = 0;i < len;i += mid<<1, ++k)
for(int j = 0;j < mid;++j) {
int x = 1ll*ary[i+j+mid]*w[k]%mod;
ary[i+j+mid] = (ary[i+j] < x ? ary[i+j]-x+mod : ary[i+j]-x);
ary[i+j] = (ary[i+j]+x >= mod ? ary[i+j]+x-mod : ary[i+j]+x);
}
}
// Inverse transform of ary[0..len), in place.
//
//   ary  buffer of at least `len` entries
//   len  transform length
//
// The passes mirror NTT(): the butterfly distance grows from 1 to len/2 and
// the k-th group's difference term is multiplied by w[k]. The same w[] table
// as the forward transform is used (not inverse roots); after scaling by
// len^-1, reversing positions 1..len-1 compensates for that.
//
// Fix: the final reverse previously operated on the global `a` rather than
// on `ary`, so the function was only correct when called as INTT(a, ...).
void INTT(int *ary,int len) {
    for (int mid = 1; mid < len; mid <<= 1)
        for (int i = 0, k = 0; i < len; i += mid << 1, ++k)
            for (int j = 0; j < mid; ++j) {
                const int lo = ary[i + j];
                const int x = ary[i + j + mid];
                ary[i + j + mid] = 1ll * (lo < x ? lo - x + mod : lo - x) * w[k] % mod;
                ary[i + j] = (lo + x >= mod ? lo + x - mod : lo + x);
            }

    const int inv = quick_pow(len, mod - 2);
    for (int i = 0; i < len; ++i)
        ary[i] = 1ll * ary[i] * inv % mod;

    std :: reverse(ary + 1, ary + len);
}
// Multiplies the sequences in a[] and b[] modulo `mod`, leaving the product
// in a[]. `len` becomes the smallest power of two that is >= n + m - 1; both
// arrays are transformed, multiplied pointwise, and a[] is transformed back.
void mul() {
    for (len = 1; len < n + m - 1; len <<= 1) {
    }
    init();
    NTT(a, len);
    NTT(b, len);
    for (int idx = 0; idx < len; ++idx) {
        const long long product = 1ll * a[idx] * b[idx];
        a[idx] = product % mod;
    }
    INTT(a, len);
}
// Reads two counts, increments both, reads that many digits into a[] and
// b[], calls mul(), and prints the first n + m - 1 entries of a[] separated
// by spaces.
int main() {
    scanf("%d%d",&n,&m);
    ++n;
    ++m;

    int pos = 0;
    while (pos < n) {
        get_single(a[pos]);
        ++pos;
    }
    pos = 0;
    while (pos < m) {
        get_single(b[pos]);
        ++pos;
    }

    mul();

    for (int out = 0; out < n + m - 1; ++out)
        printf("%d ",a[out]);
    return 0;
}