多项式

前言

kls:主播你好,请问算法竞赛中还有多项式技术吗。

内容

多项式乘法

对于两个多项式:

  1. \(A(x) = a_0 + a_1x + a_2x^2+\cdots+a_nx^n\)
  2. \(B(x) = b_0 + b_1x + b_2x^2+\cdots+b_mx^m\)

直接暴力做乘法是 \(O(nm)\),有没有更优的做法?

我们知道由 \(n + 1\) 个点可以确定一个 \(n\) 次多项式,那么如果我们能将这两个多项式转变为若干个点再相乘,那么相乘这个过程就是 \(O(n)\) 了!

上述也叫系数表示法转换为点值表示法,如何快速的转化呢?。这个叫 DFT。

\(O(n \log n)\) 解决上述乘法的算法就是大名鼎鼎 FFT(快速傅里叶变换)。

FFT

复数

定义:\(z = x + iy\),其中 \(i^2 = -1\)\(x, y \in \mathbb{R}\)

运算:

  1. 加法,\(z_0 + z_1 = (x_0 + x_1) + i(y_0 + y_1)\)
  2. 减法,\(z_0 + z_1 = (x_0 - x_1) + i(y_0 - y_1)\)
  3. 乘法,\(z_0 z_1 = (x_0x_1 - y_0y_1) + i(x_0y_1 + x_1y_0)\)

单位圆

定义:复平面上圆心在原点且半径为 \(1\) 的圆。

圆上的点:\(z = \cos\theta + i\sin\theta\)。(斜边即为半径为 \(1\)

我们把圆 \(n\) 等分,那么从 \(x\) 轴逆时针第 \(k\) 个点为 \(z = \cos\frac{2\pi k}{n} + i\sin\frac{2\pi k}{n}\),其中 \(k \in [0, n)\),记作 \(\omega_{n}^{k}\)

FFT 要求 \(n = 2^{b \in \mathbb{N}}\)

对于我们的 \(\omega_{n}^{k}\),有如下四条性质,保证了他能顺利进行快速转换:(都可以借助几何理解或代入去算)

  1. 指数性,\(\omega_{n}^{k} \omega_{n}^{m} = \omega_{n}^{k + m}\),也可以进行其他指数上的运算。
  2. 周期性,\(\omega_{n}^{k} = \omega_{n}^{k + n}\)
  3. 对称性,\(\omega_{n}^{k + \frac{n}{2}} = -\omega_{n}^{k}\)
  4. 折半性,\(\omega_{n}^{2k} = \omega_{\frac{n}{2}}^{k}\),这保证了我们时间复杂度正确。

DFT

现在考虑如何从系数转换为点值。

比如这个多项式 \(A(x) = a_0 + a_1x + a_2x^2+\cdots+a_{n-1}x^{n-1}\)

我们把奇偶位置的项分离,即 \(A(x) = (a_0 + a_2x^2 + \cdots + a_{n-2}x^{n-2}) + x(a_1 + a_3x^2 + \cdots + a_{n-1}x^{n-2})x\)

\(A_1(x) = (a_0 + a_2x + \cdots + a_{n-2}x^{\frac{n}{2} - 1})\)

\(A_2(x) = (a_1 + a_3x + \cdots + a_{n-1}x^{\frac{n}{2} - 1})\),那么 \(A(x) = A_1(x^2) + A_2(x^2)\color{red}x\)

\(\omega_{n}^{k}(k < \frac{n}{2})\) 代入得

\(A(\omega_{n}^{k}) = A_1(\omega_{n}^{2k}) + A_2(\omega_{n}^{2k})\omega_{n}^{k} = A_1(\omega_{\frac{n}{2}}{k}) + A_2(\omega_{\frac{n}{2}}{k}) \omega_{n}^{k}\)

\(\omega_{n}^{k + \frac{n}{2}} = \omega_{n}^{-k}\) 代入得

\(A(\omega_{n}^{k + \frac{n}{2}}) = A_1(\omega_{n}^{2k}) + A_2(\omega_{n}^{2k})\omega_{n}^{-k} = A_1(\omega_{\frac{n}{2}}{k}) - A_2(\omega_{\frac{n}{2}}{k})\omega_{n}^{k}\)

递归可做了,时间复杂度 \(O(n \log n)\),然后我们就从系数表示转换成了点值表示。

点击查看 DFT 代码
void DFT(complex<db> A[], int n) {
	if (n == 1) return ;
	complex<db> A1[n / 2], A2[n / 2];
	for (int i = 0; i < n / 2; ++i)
		A1[i] = A[i * 2], A2[i] = A[i * 2 + 1];
	DFT(A1, n / 2), DFT(A2, n / 2);
	complex<db> w1({cos(2 * PI / n), sin(2 * PI / n)});
	complex<db> wk({1, 0});
	for (int i = 0; i < n / 2; ++i) {
		A[i] = A1[i] + A2[i] * wk;
		A[i + n / 2] = A1[i] - A2[i] * wk; 
		wk = wk * w1;
	}	
}

IDFT

我们现在已经将点值乘在一起了,但是题目肯定是想知道最终的系数,所以还要进行逆变换。

还是对于一个多项式 \(A(x) = a_0 + a_1x + a_2x^2+\cdots+a_{n-1}x^{n-1}\),我们依次代入 \(\omega_{n}^{k}\),可以得到 \(n\) 个点 \(y_k = \sum\limits_{i=0}^{n-1} a_i (\omega_{n}^{k})^i\)

我们以 \(y_i\) 再构造一个多项式 \(B(x) = y_0 + y_1x + y_2x^2+\cdots+y_{n-1}x^{n-1}\)

然后我们再依次代入 \(\omega_{n}^{-k}\),即倒数,又得到 \(n\) 个点 \(z_i\)

那么 \(z_k = \sum\limits_{i=0}^{n-1} y_i (\omega_{n}^{-k})^i = \sum\limits_{i=0}^{n-1} \sum\limits_{j=0}^{n-1} a_j (\omega_{n}^{i})^j (\omega_{n}^{-k})^i = \sum\limits_{j=0}^{n-1}a_j\sum\limits_{i=0}^{n-1} (\omega_{n}^{j-k})^i\)

内层的和式为等比数列,所以易得当 \(j = k\) 时它等于 \(n\),否则等于 \(0\)

所以 \(z_k = na_k \to a_k = \frac{a_k}{n}\),欸?是不是就逆 dft 完了?

相当于重新代入一个倒数,倒数显然就是 \(k\) 取反,所以只需要在 dft 上微调就行。

然后你把他们结合起来,是不是 FFT?当然,为了方便的递归折半,我们要将多项式补成 \(2^n\) 次,高次不足补 \(0\)

点击查看 DFT/IDFT 代码
void dft(complex<db> A[], int n, int op) { // 1 为 dft,-1 为 idft
	if (n == 1) return ;
	complex<db> A1[n / 2], A2[n / 2];
	for (int i = 0; i < n / 2; ++i)
		A1[i] = A[i * 2], A2[i] = A[i * 2 + 1];
	dft(A1, n / 2, op), dft(A2, n / 2, op);
	complex<db> w1({cos(2 * PI / n), sin(2 * PI / n) * op});
	complex<db> wk({1, 0});
	for (int i = 0; i < n / 2; ++i) {
		A[i] = A1[i] + A2[i] * wk;
		A[i + n / 2] = A1[i] - A2[i] * wk; 
		wk = wk * w1;
	}	
}
点击查看 P3803【模板】多项式乘法 代码
const db PI = acos(-1);

int n, m;
complex<db> A[MAXN], B[MAXN];

void dft(complex<db> A[], int n, int op) {
	if (n == 1) return ;
	complex<db> A1[n / 2], A2[n / 2];
	for (int i = 0; i < n / 2; ++i)
		A1[i] = A[i * 2], A2[i] = A[i * 2 + 1];
	dft(A1, n / 2, op), dft(A2, n / 2, op);
	complex<db> w1({cos(2 * PI / n), sin(2 * PI / n) * op});
	complex<db> wk({1, 0});
	for (int i = 0; i < n / 2; ++i) {
		A[i] = A1[i] + A2[i] * wk;
		A[i + n / 2] = A1[i] - A2[i] * wk; 
		wk = wk * w1;
	}	
}

signed main() {
	ios::sync_with_stdio(0);
	cin.tie(0);
	cin >> n >> m;
	for (int i = 0; i <= n; ++i) cin >> A[i];
	for (int i = 0; i <= m; ++i) cin >> B[i];
	for (m = n + m, n = 1; n <= m; n <<= 1); 
	dft(A, n, 1), dft(B, n, 1);
	for (int i = 0; i < n; ++i) A[i] = A[i] * B[i];
	dft(A, n, -1);
	for (int i = 0; i <= m; ++i)
		cout << (int)((A[i].real()) / n + 0.5) << " ";
	return 0;
}

蝴蝶变换

递归还是有点慢了,考虑迭代。

我们每次都会奇偶分离,位置也会发生改变,但是注意到一个项最初的位置与最终的位置的二进制刚好翻转,然后就可以改成迭代了。

\(r_i\) 表示第 \(i\) 项最终在 \(r_i\) 位置,有转移式 \(r_i = \frac{r_{\frac{i}{2}}}{2} + (i \& 1) (\frac{n}{2})\),可以举例验证一下就懂为什么了,记得把除以 \(2\) 看成右移一位。

然后我们先把每个项放到最终位置,再按照刚刚自底而上合并就可以了。

点击查看迭代型 FFT 代码
#define comp complex<db>

void change(comp *A, int n) {
	for (int i = 0; i < n; ++i) 
		if (i < r[i]) swap(A[i], A[r[i]]);
}

void dft(comp *A, int n, int op) {
	change(A, n);
	for (int m = 2; m <= n; m <<= 1) {
		comp w1({cos(2 * PI / m), sin(2 * PI / m) * op});
		for (int i = 0; i < n; i += m) {
			comp wk({1, 0});
			for (int j = 0; j < m / 2; ++j) {
				comp x = A[i + j], y = A[i + j + m / 2] * wk;
				A[i + j] = x + y, A[i + j + m / 2] = x - y;
				wk *= w1;
			}
		}
	}	
}

void fft(comp *A, comp *B, int n, int m) {
	for (m = n + m, n = 1; n <= m; n <<= 1);
	for (int i = 0; i < n; ++i) 
		r[i] = (r[i >> 1] >> 1) + (i & 1) * (n / 2);
	dft(A, n, 1), dft(B, n, 1);
	for (int i = 0; i < n; ++i) A[i] = A[i] * B[i];
	dft(A, n, -1);
	for (int i = 0; i <= m; ++i)
		cout << (int)(A[i].real() / n + 0.5) << " ";	
}

NTT

复数掉精度,还很难在模意义下搞,有没有替代品?

有的,有叫原根的东西,在这里不作介绍,因为 NTT 不用搞懂它。

它同样满足直接所说的四个性质,假设在模 \(m\) 意义下,原根为 \(g\),只是变成了 \(g_{n}^{k} = g^{\frac{(p-1)k}{n}}\)

然后就没区别了,但是此时模数必须要能表示为 \(r2^k + 1\),比如说 \(998244353 = 119 \times 2^{23} + 1\),他的原根为 \(3\),直接记就行。

点击查看 NTT 代码

\(g\) 是原根,\(gi\) 是原根的逆元。

void NTT(LL *A, int n, int op) {
	for (int i = 0; i < n; ++i) 
		if (i < R[i]) swap(A[i], A[R[i]]);
	for (int m = 2; m <= n; m <<= 1) {
		LL g1 = qpow(op == 1 ? g : gi, (P - 1) / m);
		for (int i = 0; i < n; i += m) {
			LL gk = 1;
			for (int j = 0; j < m / 2; ++j, gk = gk * g1 % P) {
				LL x = A[i + j], y = gk * A[i + j + m / 2] % P;
				A[i + j] = (x + y) % P, A[i + j + m / 2] = (x - y + P) % P;
			}
		}
	}	
    if (op != 1) {
		int inv = ksm(n, mod - 2);
		for (int i = 0; i < n; ++i) (A[i] *= inv) %= mod;
	} 
}	

拉格朗日插值

lagrange 插值,简称拉插,可以解决给定 \(n + 1\) 个点,求出其对应的 \(n\) 次多项式的 \(f(k)\) 值的问题。(瞎口胡的,意思理解了就行)

首先,我们知道 \(n + 1\) 个点,相当于知道 \(n + 1\) 个方程,可以高斯消元 \(O(n^3)\) 求这个问题。

但是高斯消元不仅慢还难写,拉插提供了一个公式 \(f(x) = \sum\limits_{i=1}^{n} y_i \prod\limits_{i \ne j} \frac{x - x_j}{x_i - x_j}\)

如何理解呢?首先,我们要知道 \(n + 1\) 个点能唯一确定一个 \(n\) 次多项式,就像两点确定一条直线,三点确定一条抛物线。

好,然后你就去看那个式子然后你就能看懂了。

比如我代 \(x_k\) 进去,当 \(i \ne k\) 时,内部和式一定为 \(0\),当 \(i = k\) 时,内部和式一定为 \(1\)

好没了。

点击查看拉插代码
ll lagrange(ll *X, ll *Y, ll x) {
	ll ans = 0;
	for (int i = 1; i <= n; ++i) {
		ll a = y[i], b = 1;
		for (int j = 1; j <= n; ++j) {
			if (i ^ j) {
				(a *= x - X[j]) %= mod;
				(b *= X[i] - X[j]) %= mod;
			}
		}
		b = ksm(b, mod - 2);
		(ans += a * b % mod) %= mod;
	}		
	return (ans + mod) % mod;
}

求乘法逆、开方等

以乘法逆为例,给定多项式 \(F(x)\),如何求另一个多项式 \(G(x)\) 满足 \(F(x)G(x) \equiv 1 \pmod{x^n}\),即前面 \(n\) 项相乘为 \(1\)

牛顿迭代

在数字中的牛顿迭代用于求解形如 \(f(x) = 0\) 的方程的解。

其有迭代公式 \(f(x_1) = x_0 - \frac{f(x_0)}{f'(x_0)}\)

具体的,我们要先猜一个大概的解,然后通过不停迭代逼近精确解,而这个迭代是指数级逼近的。

求逆元

以下暂且先忽略模意义,应用到多项式中,我们需要表示出 \(f(x) = 0\) 这种形式,因为 \(F(x)G(x) = 1\),因此我们可以定义 \(F(G) = F(x)G(x) - 1 = 0\),其中 \(F\) 已知。

然后直接用迭代公式,\(G_1 = G_0 - \frac{F(G_0)}{F'(G_0)}\)

里面的 \(F'(G)\) 即对 \(F(G)\) 求导,我们现在没必要具体知道它的一切,只需要知道对于多项式求导而言:

  1. 常数项变为 \(0\)
  2. 对每个项进行求导,即 \(a_ix^i = i a_i x^{i-1}\)
  3. 每个项导数相加就为多项式的导数。

所以对 \(F(G) = F(x)G(x) - 1\) 进行求导,常数项变为 \(0\)\(G(x)\) 变为 \(G(x)^0 = 1\),所以 \(F'(G) = F\)

以逆元为例,我们继续推这个式子:\(G_1 = G_0 - \frac{F(G_0)}{F'(G_0)} = G_0 - (F \cdot G_0 - 1)\cdot F^{-1} = G_0 - (F \cdot G_0 - 1)\cdot G_0 = 2G_0 - F \cdot G_0^2\)

直接上牛顿迭代套上 NTT。

那我们一开始猜什么值呢?直接用 \(F\) 的常数项的逆元,可以说,数上用牛顿迭代是小数点更精确,多项式上则是项越来越多。

点击查看 P4238 【模板】多项式乘法逆 代码
void DFT(ll *A, int n, int op) {
	for (int i = 0; i < n; ++i) 
		if (i < r[i]) swap(A[i], A[r[i]]);
	for (int m = 2; m <= n; m <<= 1) {
		ll g1 = ksm(op == 1 ? g : gi, (mod - 1) / m);
		for (int i = 0; i < n; i += m) {
			ll gk = 1;
			for (int j = 0; j < m / 2; ++j, gk = gk * g1 % mod) {
				ll x = A[i + j], y = gk * A[i + j + m / 2] % mod;
				A[i + j] = (x + y) % mod, A[i + j + m / 2] = (x - y + mod) % mod;
			}
		}
	}	
	if (op != 1) {
		int inv = ksm(n, mod - 2);
		for (int i = 0; i < n; ++i) (A[i] *= inv) %= mod;
	} 
}	

void INV(ll *F, ll *G, int n) { // 将 F 的逆赋给 G  
	int lim = 1;
	while (lim < n) lim <<= 1;
	static ll _F[MAXN];
	G[0] = ksm(F[0], mod - 2);
	for (int p = 2; p <= lim; p <<= 1) { // p 为即将扩展到的长度 
		int len = p << 1; 
		for (int i = 0; i < len; ++i) 
			r[i] = (r[i >> 1] >> 1) + (i & 1) * (len >> 1);
		for (int i = 0; i < len; ++i)
			_F[i] = i < p ? F[i] : 0; 
		DFT(G, len, 1), DFT(_F, len, 1);
		for (int i = 0; i < len; ++i) 
			G[i] = ((2 - _F[i] * G[i] % mod + mod) % mod) * G[i] % mod; // G1 = 2G0 - F*G0*G0
		DFT(G, len, 0);
		for (int i = p; i < len; ++i) G[i] = 0;
	}
}

求开方

再比如说求开方,因为 \(F(x) = G(x)^2\),所以设 \(F(G) = F(x) - G(x)^2 = 0\),直接套模板就行。

多项式除法

\(n\) 次多项式 \(F(x)\) 除以 \(m\) 次多项式 \(G(x)\) 要求找到两个多项式 \(Q(x), R(x)\) 满足:\((n \ge m)\)

  1. \(G(x)Q(x)+R(x)=F(x)\)
  2. \(Q(x)\)\(n - m\) 次多项式,\(R(x)\) 的次数小于 \(m\)

\(f_R(x)\) 表示 \(f(x)\) 系数翻转后的多项式,即常数项系数与最高项系数交换以此类推,\(deg(f)\) 表示 \(f(x)\) 最高次幂。

显然有 \(f_R(x) = x^{deg(f)} f(\frac{1}{x})\)

所以代入之前的定义中可得 \(F_R(x) = G_R(x)Q_R(x) + x^{n - deg(R)}R_R(x)\),因为 \(deg(R) < m\),所以 \(n - deg(R) > n - m \to n - deg(R) \ge n - m + 1\)

所以有 \(F_R(x) = G_R(x)Q_R(x) \pmod{x^{n-m + 1}}\),然后就可以求出模 \(x^{n-m+1}\) 意义下的 \(Q_R(x)\),注意到 \(deg(Q) = n - m\),所以说这个 \(Q_R(x)\) 就是等于它自己。

然后再代入定义中求出 \(R(x)\)

点击查看多项式除法代码
void ntt(ll *A, ll *B, ll *C, int n, int m, int k) {
	int len = n + m, N = 1;
	while (N <= len) N <<= 1;
	static ll tA[MAXN], tB[MAXN];
	for (int i = 0; i < N; ++i) {
		tA[i] = i <= n ? A[i] : 0;
		tB[i] = i <= m ? B[i] : 0; 
	}
	dft(tA, N, 1), dft(tB, N, 1);
	for (int i = 0; i < N; ++i) 
		tA[i] = (tA[i] * tB[i]) % mod;
	dft(tA, N, 0);
	memcpy(C, tA, (k + 1) * sizeof(ll));	
}

void inv(ll *F, ll *G, int n) {
	int lim = 1; 
	while (lim <= n) lim <<= 1;
	static ll _F[MAXN];
	G[0] = ksm(F[0], mod - 2);
	for (int p = 2; p <= lim; p <<= 1) {
		int len = p << 1;
		for (int i = 0; i < len; ++i)
			_F[i] = (i < p ? F[i] : 0);
		dft(_F, len, 1), dft(G, len, 1);
		for (int i = 0; i < len; ++i)
			G[i] = ((2 * G[i] % mod - _F[i] * G[i] % mod * G[i] % mod) + mod) % mod;
		dft(G, len, 0);
		for (int i = p; i < len; ++i) G[i] = 0;
	}	
}

void rev(ll *A, int n) {
	for (int i = 0; i <= n / 2; ++i)	
		swap(A[i], A[n - i]);
}

void div(ll *F, ll *G, int n, int m) {
	static ll invG[MAXN], Q[MAXN], R[MAXN], GQ[MAXN]; 
	rev(F, n), rev(G, m);
	inv(G, invG, n - m);
	ntt(F, invG, Q, n, n - m, n - m);
	rev(Q, n - m);
	for (int i = 0; i <= n - m; ++i) cout << Q[i] << " ";
	rev(F, n), rev(G, m);
	ntt(Q, G, GQ, n - m, m, m - 1);
	for (int i = 0; i < m; ++i) 
		R[i] = (F[i] - GQ[i] + mod) % mod;
	cout << '\n';
	for (int i = 0; i < m; ++i) cout << R[i] << " ";
}

分治 FFT

还是很好理解的,其实就是在一个递推式里,递归每层计算左半对右半的贡献。

P4721【模板】FFT 分治为例,我们定义 \(p(l, r)\) 表示正在求 \(l \sim r\) 区间里的 \(f\) 值。

对于一个区间 \([l,r]\) 我们先递归处理他的左区间 \([l, mid = \frac{l + r}{2}]\),然后考虑前半部分对后半部分的贡献:

记对 \(f_x\) 的贡献为 \(w_x\),那么 \(w_x = \sum\limits_{j=l}^{mid} f_{j}g_{x-j}\),因为此时右半部分的 \(f\) 还没求等于 \(0\),不妨把上界改为 \(x\)

所以 \(w_x = \sum\limits_{j=l}^{x-1} f_{j}g_{x-j} = \sum\limits_{j=0}^{x - l - 1} f_{j + l} x_{x - j - l}\)

\(a_i = f_{i+l}\)\(b_i = f_{i+1}\),所以 \(w_x = \sum\limits_{j=0}^{x - l - 1} a_j b_{x - j - l - 1}\),好了,现在可以卷积了。

卷完之后加在 \(f_x\) 里再递归处理右区间,可以证明时间复杂度是 \(O(n \log^2 n)\) 的。

点击查看分治 FFT 代码
void cdq(ll *f, ll *g, int l, int r) {
	if (l == r) return ;
	int mid = (l + r) >> 1, lim, len;
	cdq(f, g, l, mid);
	for (lim = r - l + 1, len = 1; len <= lim; len <<= 1);
	memset(A, 0, len * sizeof(ll));
	memset(B, 0, len * sizeof(ll));
	for (int i = l; i <= mid; ++i) A[i - l] = f[i];
	for (int i = 1; i <= r - l; ++i) B[i - 1] = g[i];
	ntt(A, B, len);
	for (int i = mid + 1; i <= r; ++i) f[i] = (f[i] + A[i - l - 1]) % mod;
	cdq(f, g, mid + 1, r);
}

【持续更新中,还在学习】

posted @ 2026-02-04 19:26  Statax  阅读(20)  评论(3)    收藏  举报