题解:P4191 [CTSC2010] 性能优化

bluestein 真是这个世界上最没用的东西,让我调了 8h 才过,wa 了两版,吃爽了。

以及题解里唯一一篇用 blustein 且有代码的题解马蜂又丑陋,交上去会 mle。

题意:给出两个长为 \(n\) 多项式 \(f,g\) 和参数 \(c\),求 \(f\)\(g\)\(c\) 次循环卷积的循环卷积。

做法:

首先我们注意到一个事情,其实 dft 做的是一个循环卷积,因为里面的单位根满足 \(\omega_{n}^{i}\times\omega_{n}^j = \omega_{n}^{(i+j)\mod n}\),所以其实是循环卷积。

那么我们考虑做任意长度 dft,但是我们并不会做这个东西,我们只会做 \(2^n\) 长度dft,接下来就是 bluestein 算法做的事情。

我们思考 dft 其实在干什么,发现是在求 \(f(\omega_n^i) = \sum\limits_{j=0}^{n-1}\omega_n^{ij}\)。我们考虑转化一下这个 \(ij\),发现等于 \(\frac{(i+j)(i+j-1) - i(i-1)-j(j-1)}{2}\),并且这个式子中每一项都是被 \(2\) 整除的,这也是我们为什么不能转化成其他式子比如 \(i^2+j^2-(i-j)^2\) 的原因。

于是我们就可以把这个式子改为 \(\omega_n^{\frac{i(i-1)}2}f(\omega_n^i) = \sum\limits_{j=0}^{n-1}\omega_{n}^{\frac{-j(j-1)}{2}}\omega_n^{\frac{(i+j)(i+j-1)}{2}}\),这个式子是一个卷积的形式,你用 fft 去卷就可以了。

为了一个 dft 你还要再写一个 fft...这里因为模数是 \(n+1\) 所以你必须写 fft。


这就是这个算法的全部流程了,当我做到此时直接开始疯狂写代码,但是我要强调一坨实现上的问题。

  1. 记得多取模,记得多取模,记得多取模。

  2. 如果你 tle 了,那么可以预处理 \(\omega\),会快非常非常多。

  3. 然后你会发现你最后三个点 wa 掉了,因为在 \(n=5\times 10^5\) 的时候 fft 精度掉完了,所以需要优化,我们考虑将一个系数 \(f_i\) 表示为 \(f_i = f_{i,1}\sqrt{n+1}+f_{i,2}\),然后将 \(f_1,f_2\) 分别 fft 再乘起来最后计算贡献。

  4. 写完之后觉得非常爽终于写完了,然后交上去 mle 80pts,全部换成 double 和 int 依旧无果。这里我们就需要继续改良 fft 的空间。我们发现将 \(f,g\) 拆成 4 个多项式真是太浪费了,我们可以考虑类似 mtt,将 \(f_2\) 压到虚部,这样可以少开一个多项式。

给出删掉 1k 调试的代码:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e6 + 5;
const long double PI = acos(-1), eps = 1e-12;
long long gb, gi, mod, n, c;
struct Couple {
	double x, y;
	inline friend Couple operator+(Couple x, Couple y) {
		return Couple{x.x + y.x, x.y + y.y};
	}
	inline friend Couple operator-(Couple x, Couple y) {
		return Couple{x.x - y.x, x.y - y.y};
	}
	inline friend Couple operator*(Couple x, Couple y) {
		return Couple{x.x * y.x - x.y * y.y, x.y * y.x + x.x * y.y};
	}
};
vector<int> rev;
vector<Couple> w;
void prepare(int len) {
	w.resize(len + 1);
	rev.resize(len);
	for (int i = 1; i < len; i <<= 1)
		for (int j = 0; j < i; j++)
			w[len / i * j] = Couple{cos(PI / i * j), sin(PI / i * j)};
	for (int i = 0; i < len; i++) {
		rev[i] = rev[i >> 1] >> 1;
		if(i & 1)
			rev[i] = rev[i] | (len >> 1);
	}
}
struct Poly1 {
	vector<Couple> a;
	inline int size() {
		return a.size();
	}
	inline void resize(int N) {
		a.resize(N);
	}
	inline Couple& operator[](int x) {
		return a[x];
	}
	inline Couple get_Cp(int p, int f) {
		return Couple{w[p].x, f * w[p].y};
	}
	void FFT(int f) {
		int n = size();
		for (int i = 0; i < n; i++)
			if(i < rev[i])
				swap(a[i], a[rev[i]]);
		for (int h = 2; h <= n; h <<= 1) {
			for (int i = 0; i < n; i += h) {
				for (int j = i; j < i + h / 2; j++) {
					Couple a0 = a[j], a1 = a[j + h / 2] * get_Cp(n / (h >> 1) * (j - i), f);
					a[j] = (a0 + a1), a[j + h / 2] = (a0 - a1);
				}
			}
		}
		if(f == -1) {
			for (int i = 0; i < n; i++)
				a[i].x /= n, a[i].y /= n;
		}
	}
	void print() {
		for (int i = 0; i < a.size(); i++)
			cout << a[i].x << " ";
		cout << endl;
	}
} ;
Poly1 fb, g1, g2;
Poly1 operator*(Poly1 &f, Poly1 &g) {
		int len = 1, t = f.size() + g.size() - 1;
		while(len < t)
			len <<= 1;
		prepare(len), f.resize(len), g.resize(len);
		fb.resize(0), g1.resize(0), g2.resize(0);
		fb.resize(len), g1.resize(len), g2.resize(len);
		int mo = sqrt(mod);
		for (int i = 0; i < len; i++)
			fb[i].x = (int)(f[i].x) / mo, fb[i].y = (int)f[i].x % mo, 
			g1[i].x = (int)(g[i].x) / mo, g2[i].x = (int)g[i].x % mo;
		fb.FFT(1), g1.FFT(1), g2.FFT(1);
		for (int i = 0; i < len; i++) {
			g1[i] = g1[i] * fb[i];
			g2[i] = g2[i] * fb[i];	
		}
		g1.FFT(-1), g2.FFT(-1);
		int MO = 1ll * mo * mo % mod;
		for (int i = 0; i < len; i++) 
			f[i].x = (1ll * (long long)round(g1[i].x) % mod * MO % mod + 1ll * (long long)round(g2[i].x) % mod * mo % mod 
			+ (long long)round(g2[i].y) % mod + 1ll * (long long)(round)(g1[i].y) % mod * mo % mod), 
			f[i].x = (long long)(round(f[i].x)) % mod;
		f.resize(t);
		return f;
	}
int qpow(int x, int k, int p) {
	int res = 1;
	while(k) {
		if(k & 1)
			res = 1ll * res * x % p;
		x = 1ll * x * x % p, k >>= 1;
	}
	return res;
}
int pw[maxn / 6], pi[maxn / 6], fac[maxn / 5000], tot;
int getroot(int n) {
	int x = n - 1;
	for (int i = 2; i * i <= x; i++) {
		if(x % i == 0) {
			fac[++tot] = i;
			while(x % i == 0)
				x /= i;
		}
	}
	if(x > 1)
		fac[++tot] = x;
	for (int i = 2; i < n; i++) {
		bool f = 1;
		for (int j = 1; j <= tot; j++) {
			if(qpow(i, (n - 1) / fac[j], n) == 1)
				f = 0;
			if(!f)
				break;
		}
		if(f)
			return i;
	}
}
void prework() {
	gb = getroot(mod);
	pw[0] = 1;
	for (int i = 1; i <= n; i++)
		pw[i] = 1ll * pw[i - 1] * gb % mod;
	gi = qpow(gb, n - 1, mod);
	pi[0] = 1;
	for (int i = 1; i <= n; i++)
		pi[i] = 1ll * pi[i - 1] * gi % mod;
}
Poly1 f, g;
int tx = 1;
int cal(int x) {
	return 1ll * x * (x - 1) / 2 % n;
}
struct Poly2 {
	vector<int> a;
	int size() {
		return a.size();
	}
	void resize(int N) {
		a.resize(N);
	}
	int& operator[](int x) {
		return a[x];
	}
	void DFT(int flag) {
		int n = a.size();
		f.resize(0), g.resize(0);
		f.resize(n), g.resize(2 * n);
		if(flag == 1) {
			for (int i = 0; i < n; i++)
				f[i].x = 1ll * a[i] * pi[cal(i)] % mod;
			for (int i = 0; i < 2 * n; i++)
				g[i].x = pw[cal(i)];
		}
		else {
			for (int i = 0; i < n; i++)
				f[i].x = 1ll * a[i] * pw[cal(i)] % mod;
			for (int i = 0; i < 2 * n; i++)
				g[i].x = pi[cal(i)];
		}
		reverse(f.a.begin(), f.a.end());
		f = f * g;
		for (int i = n - 1; i < 2 * n - 1; i++) {
			if(flag == 1)
				a[i - n + 1] = (long long)round(f[i].x) % mod * pi[cal(i - n + 1)] % mod;
			else
				a[i - n + 1] = (long long)round(f[i].x) % mod * pw[cal(i - n + 1)] % mod;
		}
		if(flag == -1) {
			int inv = qpow(n, mod - 2, mod);
			for (int i = 0; i < n; i++)
				a[i] = 1ll * a[i] * inv % mod;
		}
	}
} ft, gt;
signed main() {
//	freopen("test.in", "r", stdin);
//	freopen("std.out", "w", stdout);
	ios::sync_with_stdio(false);
	cin >> n >> c;	
	ft.resize(n), gt.resize(n);
	mod = n + 1; prework();
	for (int i = 0; i < n; i++)
		cin >> ft[i], ft[i] %= mod;
	for (int i = 0; i < n; i++)
		cin >> gt[i], gt[i] %= mod;
	ft.DFT(1), tx = 0, gt.DFT(1);
	for (int i = 0; i < n; i++)
		ft[i] = 1ll * ft[i] * qpow(gt[i], c, mod) % mod;
	ft.DFT(-1);
	for (int i = 0; i < n; i++)
		cout << ft[i] << '\n';
	return 0;
}
posted @ 2025-07-27 17:15  LUlululu1616  阅读(20)  评论(0)    收藏  举报