【模板】快速数论变换

依旧是常数很大的板子。

H_Kaguya 改动之前,达到了 \(2.61\,s\) 的绝望时间

现在好多了,\(1.10\,s\)。(内存不连续访问我会记你一辈子的)

#include <iostream>
// Last character consumed by get_single() (0 once stdin is exhausted).
char ch;

/**
 * Reads the next decimal digit from stdin.
 *
 * Every character that is not '0'..'9' (spaces, line breaks, ...) is skipped.
 *
 * @return the value of the digit (0..9), or 0 when stdin ends before a
 *         digit is found.
 */
short get_single() {
	// getchar() returns an int so that EOF can be told apart from real
	// characters; storing it directly in a char made the skip loop spin
	// forever once the input was exhausted.
	int c = getchar();
	while(c != EOF && (c < '0' || c > '9'))
		c = getchar();
	if(c == EOF) {
		ch = 0;
		return 0;
	}
	ch = static_cast<char>(c);
	// The low four bits of an ASCII digit are its numeric value.
	return ch&15;
}
// Prime modulus for all arithmetic: mod - 1 = 119 * 2^23.
const int mod = 998244353;
// Base of the forward-transform roots used by NTT().
const int g = 3;
// Modular inverse of g (3 * 332748118 = mod + 1), used for the inverse transform.
const int g_inv = 332748118;
// Capacity of the global work arrays.
const int N = 2100010;

/**
 * Computes _a raised to the power _n modulo _p by binary exponentiation.
 *
 * @param _a base; any value is accepted, it is first reduced into [0, _p)
 * @param _n exponent, must be non-negative
 * @param _p modulus, must be positive (defaults to mod)
 * @return _a^_n mod _p, always in [0, _p)
 */
long long quick_pow(long long _a,int _n,int _p = mod) {
	// Reduce the base up front: a negative base used to produce negative
	// results and a base of 2^32 or more overflowed in the squaring below.
	_a %= _p;
	if(_a < 0)
		_a += _p;
	// 1 % _p so that the degenerate modulus 1 yields 0 even for _n == 0.
	long long _res = 1%_p;
	while(_n) {
		if(_n&1)
			_res = _res*_a%_p;
		_a = _a*_a%_p;
		_n >>= 1;
	}
	return _res;
}
// Bit-reversal permutation table; entries below lim are filled by ntt_init().
int rev[N];
// Coefficient buffers of the two input polynomials; a also receives the product.
int a[N], b[N];
// Transform length, a power of two chosen in main().
int lim;
/**
 * Builds the bit-reversal permutation for a transform of length lim.
 *
 * rev[idx] is derived from rev[idx / 2]: shift that value right by one and,
 * when idx is odd, set the highest bit (lim / 2). rev[0] is left at its
 * zero-initialized value.
 */
void ntt_init() {
	const int top_bit = lim >> 1;
	for(int idx = 1; idx < lim; ++idx) {
		const int shifted = rev[idx >> 1] >> 1;
		rev[idx] = (idx & 1) ? (shifted | top_bit) : shifted;
	}
}
/**
 * In-place number-theoretic transform of f[0..lim) modulo mod.
 *
 * Requires lim to be a power of two and rev[] to have been filled by
 * ntt_init() for that lim.
 *
 * @param f   coefficient array holding at least lim values in [0, mod)
 * @param opt -1 runs the inverse transform (roots built from g_inv, result
 *            scaled by lim^-1); any other value runs the forward transform
 */
void NTT(int *f,int opt) {
	// Move every element to its bit-reversed position; the i < rev[i] guard
	// makes sure each pair is swapped only once.
	for(int i = 0;i < lim;++i)
		if(i < rev[i])
			std :: swap(f[i],f[rev[i]]);
	const bool inverse = (opt == -1);
	// The `register` specifier was dropped here: it was removed in C++17.
	for(int i = 1, step = 2;i < lim;i <<= 1, step <<= 1) {
		// Root whose order equals the current block length `step`.
		const long long w_n = quick_pow(inverse ? g_inv : g,(mod-1)/step);
		for(int j = 0, upd = i;j < lim;j += step, upd += step) {
			long long w = 1;
			for(int k = j, l = i+j;k < upd;++k, ++l, w = w*w_n%mod) {
				// Butterfly: (f[k], f[l]) <- (f[k] + w*f[l], f[k] - w*f[l]).
				int y = w*f[l]%mod;
				f[l] = f[k]-y;
				if(f[l] < 0)
					f[l] += mod;
				f[k] += y;
				if(f[k] >= mod)
					f[k] -= mod;
			}
		}
	}
	if(inverse) {
		// Divide by lim via the modular inverse lim^(mod-2).
		long long inv_lim = quick_pow(lim,mod-2);
		for(int i = 0;i < lim;++i)
			f[i] = f[i]*inv_lim%mod;
	}
}
int n, m;
int main() {
	scanf("%d %d",&n,&m);
	for(int i = 0;i <= n;++i) 
		a[i] = get_single();
	for(int i = 0;i <= m;++i) 
		b[i] = get_single();
	n += m;
	lim = 1<<31-__builtin_clz(n);
	if(lim&n) 
		lim <<= 1;
	ntt_init();
	NTT(a,1);
	NTT(b,1);
	for(int i = 0;i < lim;++i) 
		a[i] = 1ll*a[i]*b[i]%mod;
	NTT(a,-1);
	for(int i = 0;i <= n;++i) 
		printf("%d ",a[i]);
	return 0;
}

虽然中间采用了一些不通用的写法(直接针对全局数组 a、b 写死),但是在洛谷上已经可以稳定在 \(1.0\,s\) 了。

#include <iostream>
// Last character consumed by get_single() (0 once stdin is exhausted).
char ch;

/**
 * Reads the next decimal digit from stdin, skipping every character that is
 * not '0'..'9'.
 *
 * @return the value of the digit (0..9), or 0 when stdin ends before a
 *         digit is found.
 */
short get_single() {
	// Keep the int returned by getchar(): once narrowed to char, EOF can no
	// longer be detected and the skip loop below never terminated at end of
	// input.
	int c = getchar();
	while(c != EOF && (c < '0' || c > '9'))
		c = getchar();
	if(c == EOF) {
		ch = 0;
		return 0;
	}
	ch = static_cast<char>(c);
	// The low four bits of an ASCII digit are its numeric value.
	return ch&15;
}
// Prime modulus for all arithmetic: mod - 1 = 119 * 2^23.
const int mod = 998244353;
// Capacity of the global work arrays.
const int N = 2100011;

/**
 * Computes _a raised to the power _n modulo _p by binary exponentiation.
 *
 * @param _a base; any int is accepted, it is first reduced into [0, _p)
 * @param _n exponent, must be non-negative
 * @param _p modulus, must be positive (defaults to mod)
 * @return _a^_n mod _p, always in [0, _p)
 */
int quick_pow(int _a,int _n,int _p = mod) {
	// Normalize the base so that a negative argument no longer leaks a
	// negative remainder into the result.
	_a %= _p;
	if(_a < 0)
		_a += _p;
	// 1 % _p so that the degenerate modulus 1 yields 0 even for _n == 0.
	int _res = 1%_p;
	while(_n) {
		// Products are formed in 64 bits; both factors are below _p.
		if(_n&1)
			_res = (long long)_res*_a%_p;
		_a = (long long)_a*_a%_p;
		_n >>= 1;
	}
	return _res;
}
// Bit-reversal permutation table; entries below lim are filled by ntt_init().
int rev[N];
// Coefficient buffers transformed together by NTT(); a also receives the product.
int a[N], b[N];
// Transform length, a power of two chosen in main().
int lim;
/**
 * Fills rev[1..lim) with the bit-reversal permutation for length lim.
 *
 * Each entry reuses the already computed entry of idx / 2, shifted right by
 * one, and gains the top bit (lim / 2) when idx is odd. rev[0] keeps its
 * zero-initialized value.
 */
void ntt_init() {
	const int high = lim >> 1;
	int idx = 1;
	while(idx < lim) {
		int mirrored = rev[idx >> 1] >> 1;
		if(idx & 1)
			mirrored |= high;
		rev[idx] = mirrored;
		++idx;
	}
}
/**
 * Inverse transform of the global array a[0..lim), in place.
 *
 * Uses 332748118 (the modular inverse of 3) as root base and finishes by
 * multiplying every entry with lim^(mod-2), i.e. dividing by lim.
 * Requires rev[] to be prepared by ntt_init() for the current lim.
 */
void INTT() {
	// Bit-reversal reordering; each pair is swapped once.
	for(int pos = 0; pos < lim; ++pos) {
		const int mate = rev[pos];
		if(pos < mate)
			std :: swap(a[pos], a[mate]);
	}
	// exponent equals (mod-1)/span for the current block length span.
	int exponent = (mod - 1) >> 1;
	for(int half = 1; half < lim; half <<= 1, exponent >>= 1) {
		const int span = half << 1;
		const long long root = quick_pow(332748118, exponent);
		for(int base = 0; base < lim; base += span) {
			long long twiddle = 1;
			for(int off = 0; off < half; ++off) {
				int &lo = a[base + off];
				int &hi = a[base + off + half];
				// Butterfly: (lo, hi) <- (lo + t, lo - t) with t = twiddle * hi.
				const int t = twiddle * hi % mod;
				hi = lo - t;
				if(hi < 0)
					hi += mod;
				lo += t;
				if(lo >= mod)
					lo -= mod;
				twiddle = twiddle * root % mod;
			}
		}
	}
	// Scale by the modular inverse of lim.
	const long long scale = quick_pow(lim, mod - 2);
	for(int pos = 0; pos < lim; ++pos)
		a[pos] = a[pos] * scale % mod;
}
/**
 * Forward transform of both global arrays a[0..lim) and b[0..lim), in place.
 *
 * The two arrays share the permutation pass and every twiddle factor, so
 * the roots are computed only once for both.
 * Requires rev[] to be prepared by ntt_init() for the current lim.
 */
void NTT() {
	// Bit-reversal reordering applied to a and b in the same pass.
	for(int pos = 0; pos < lim; ++pos) {
		const int mate = rev[pos];
		if(pos < mate) {
			std :: swap(a[pos], a[mate]);
			std :: swap(b[pos], b[mate]);
		}
	}
	// exponent equals (mod-1)/span for the current block length span.
	int exponent = (mod - 1) >> 1;
	for(int half = 1; half < lim; half <<= 1, exponent >>= 1) {
		const int span = half << 1;
		const long long root = quick_pow(3, exponent);
		for(int base = 0; base < lim; base += span) {
			long long twiddle = 1;
			for(int off = 0; off < half; ++off) {
				const int lo = base + off;
				const int hi = lo + half;
				// Butterfly on a.
				int t = twiddle * a[hi] % mod;
				a[hi] = a[lo] - t;
				if(a[hi] < 0)
					a[hi] += mod;
				a[lo] += t;
				if(a[lo] >= mod)
					a[lo] -= mod;
				// Same butterfly on b with the same twiddle.
				t = twiddle * b[hi] % mod;
				b[hi] = b[lo] - t;
				if(b[hi] < 0)
					b[hi] += mod;
				b[lo] += t;
				if(b[lo] >= mod)
					b[lo] -= mod;
				twiddle = twiddle * root % mod;
			}
		}
	}
}
int n, m;
int main() {
	#ifndef ONLINE_JUDGE
	freopen("P3803_8.in","r",stdin);
	freopen("test.out","w",stdout);
	#endif 
	scanf("%d %d",&n,&m);
	for(int i = 0;i <= n;++i) 
		a[i] = get_single();
	for(int i = 0;i <= m;++i) 
		b[i] = get_single();
	n += m;
	lim = 1<<31-__builtin_clz(n);
	if(lim&n) 
		lim <<= 1;
	ntt_init();
	NTT();
	for(int i = 0;i < lim;++i) 
		a[i] = (long long)a[i]*b[i]%mod;
	INTT();
	for(int i = 0;i <= n;++i) 
		printf("%d ",a[i]);
	return 0;
}

怎么做到干过 FFT 的?

“你是怎么做到的?”

“请帮帮我!”

#include <iostream>
#include <algorithm>
/**
 * Reads the next decimal digit from stdin into x, skipping every character
 * that is not '0'..'9'.
 *
 * @param x receives the digit value (0..9), or 0 when stdin ends before a
 *          digit is found
 */
void get_single(int &x) {
	// Keep the int returned by getchar() so EOF stays detectable; the former
	// isdigit(char) loop never terminated at end of input and passed
	// possibly negative char values to isdigit.
	int c = getchar();
	while(c != EOF && (c < '0' || c > '9'))
		c = getchar();
	// The low four bits of an ASCII digit are its numeric value.
	x = (c == EOF) ? 0 : (c&15);
}
// Prime modulus for all arithmetic: mod - 1 = 119 * 2^23.
constexpr int mod = 998244353;
// Capacity of the global work arrays.
constexpr int N = 2100010;
// n, m : coefficient counts of the two polynomials (main() increments the degrees it reads).
// len  : transform length chosen in mul().
// pwr  : exponent of the twiddle-table size, set in init().
// a, b : coefficient buffers; a also receives the product.
// w    : twiddle-factor table filled by init().
int n, m, len, pwr, a[N], b[N], w[N];

/**
 * Computes a raised to the power n modulo p by binary exponentiation.
 * The parameters a and n shadow the globals of the same name.
 *
 * @param a base; any int is accepted, it is first reduced into [0, p)
 * @param n exponent, must be non-negative
 * @param p modulus, must be positive (defaults to mod)
 * @return a^n mod p, always in [0, p)
 */
int quick_pow(int a,int n,int p = mod) {
	// Normalize the base so that a negative argument no longer leaks a
	// negative remainder into the result.
	a %= p;
	if(a < 0)
		a += p;
	// 1 % p so that the degenerate modulus 1 yields 0 even for n == 0.
	int res = 1%p;
	while(n) {
		if(n&1) res = 1ll*res*a%p;
		a = 1ll*a*a%p;
		n >>= 1;
	}
	return res;
}
/**
 * Chooses pwr from n + m and fills the twiddle table w[0 .. 2^pwr].
 *
 * Steps, in order:
 *  1. raise pwr until 2^pwr >= n + m, then use pwr - 1 capped at 21;
 *  2. seed w[2^pwr] with 31^(2^(21 - pwr));
 *  3. derive each lower power-of-two slot by squaring the slot above it;
 *  4. fill every other index as w[idx without its lowest set bit] * w[lowest set bit].
 *
 * NOTE(review): 31 is presumably a primitive 2^23-th root of unity modulo
 * 998244353, which would make w a bit-reversed-order root table - confirm.
 */
void init() {
	while((1 << pwr) < n + m)
		++pwr;
	pwr = std :: min(pwr - 1, 21);
	const int top = 1 << pwr;
	w[0] = 1;
	w[top] = quick_pow(31, 1 << (21 - pwr));
	// Walk down the power-of-two slots: each is the square of the one above.
	for(int bit = top; bit > 1; bit >>= 1)
		w[bit >> 1] = 1ll * w[bit] * w[bit] % mod;
	// Power-of-two slots are rewritten with their own value (w[0] == 1).
	for(int idx = 1; idx < top; ++idx) {
		const int low = idx & -idx;
		w[idx] = 1ll * w[idx ^ low] * w[low] % mod;
	}
}
void NTT(int *ary,int len) {
	for(int mid = len>>1;mid;mid >>= 1) 
		for(int i = 0, k = 0;i < len;i += mid<<1, ++k) 
			for(int j = 0;j < mid;++j) {
				int x = 1ll*ary[i+j+mid]*w[k]%mod;
				ary[i+j+mid] = (ary[i+j] < x ? ary[i+j]-x+mod : ary[i+j]-x);
				ary[i+j] = (ary[i+j]+x >= mod ? ary[i+j]+x-mod : ary[i+j]+x);
			}
}
/**
 * Inverse transform of ary[0..len), in place.
 *
 * Runs the stages of NTT() in the opposite order (half-width 1 up to
 * len / 2) with the same table w, scales every entry by len^(mod-2) and
 * finally reverses ary[1..len).
 * NOTE(review): the final reversal presumably compensates for reusing the
 * forward twiddles instead of their inverses - confirm.
 *
 * @param ary array produced by NTT(), values in [0, mod)
 * @param len transform length (shadows the global len), at least 1
 */
void INTT(int *ary,int len) {
	for(int mid = 1;mid < len;mid <<= 1)
		for(int i = 0, k = 0;i < len;i += mid<<1, ++k)
			for(int j = 0;j < mid;++j) {
				int x = ary[i+j+mid];
				// The twiddle is applied to the difference, after the butterfly.
				ary[i+j+mid] = 1ll*(ary[i+j] < x ? ary[i+j]-x+mod : ary[i+j]-x)*w[k]%mod;
				ary[i+j] = (ary[i+j]+x >= mod ? ary[i+j]+x-mod : ary[i+j]+x);
			}
	// Divide by len via its modular inverse.
	int inv = quick_pow(len,mod-2);
	for(int i = 0;i < len;++i)
		ary[i] = 1ll*ary[i]*inv%mod;
	// Reverse the array that was actually transformed. This used to reverse
	// the global a, which is wrong for any other argument.
	if(len > 1)
		std :: reverse(ary+1,ary+len);
}
/**
 * Multiplies the polynomials stored in a and b; the product is left in a.
 *
 * len becomes the smallest power of two that is at least n + m - 1, the
 * twiddle table is prepared with init(), both inputs are transformed,
 * multiplied point-wise and transformed back.
 */
void mul() {
	for(len = 1; len < n + m - 1; len <<= 1) {
	}
	init();
	NTT(a, len);
	NTT(b, len);
	for(int pos = 0; pos < len; ++pos) {
		const long long product = 1ll * a[pos] * b[pos];
		a[pos] = product % mod;
	}
	INTT(a, len);
}
/**
 * Reads the degrees of two polynomials followed by their single-digit
 * coefficients, multiplies them with mul() and prints the coefficients of
 * the product separated by spaces.
 *
 * @return 0 on success, 1 when the header is malformed or the product would
 *         need a transform longer than 2^21 points
 */
int main() {
	// init() caps the twiddle table at pwr == 21 and the arrays hold N
	// entries, so the product length n + m + 1 must not exceed 2^21.
	if(scanf("%d%d",&n,&m) != 2 || n < 0 || m < 0 || 1ll*n+m+1 > (1<<21))
		return 1;
	// From here on n and m are coefficient counts, not degrees.
	++n, ++m;
	for(int i = 0;i < n;++i)
		get_single(a[i]);
	for(int i = 0;i < m;++i)
		get_single(b[i]);
	mul();
	for(int i = 0;i < n+m-1;++i)
		printf("%d ",a[i]);
	return 0;
}
posted @ 2023-02-20 09:49  bikuhiku  阅读(79)  评论(13)  编辑  收藏  举报