多项式乘法(FFT)学习笔记
Reference:
一、什么是 FFT
快速傅里叶变换(Fast Fourier Transform)是一种用于在 \(O(n \log n)\) 实现多项式“点值表示法”和“系数表示法”之间变换的算法。
而往往在把系数表示法转化为了点值表示法后,我们可以更方便地解决一些问题,比如说多项式乘法。
二、思想
以 P3803 【模板】多项式乘法(FFT) 为例。
1. 多项式相乘
其实很简单,这里给出一个形式化定义。
设 \(f(x)\) 的系数为向量 \((a_0, a_1, \dots, a_{n-1})\),\(g(x)\) 的系数为 \((b_0, b_1, \dots, b_{m-1})\),则 \(f(x)*g(x)\) (亦称 \(f(x), g(x)\) 的卷积)的系数 \((c_0, c_1, \dots, c_{n+m-1})\) 满足:
2. 点值表示法
如果要在 \(O(n^2)\) 的时间内解决这道题该怎么办?很简单,根据“多项式相乘的形式化表述”中的那个式子递推即可。
为了优化复杂度,前人们引入了多项式的一个新的表示法——点值表示法。具体地,每一个多项式可以看成一个函数,那么每一个取值 \(x\) 都可以对应一个点 \((x, f(x))\)。明显,对于一个 \(n-1\) 次的多项式,当我们取 \(n\) 个不同点时,即可唯一地确定一个多项式。
引入点值表示法对于复杂度的优化有什么益处呢?如果现在已经有了两个多项式的点值表示法,且每个点取的 \(x\) 是相同的,那么它们的卷积的点值表示法即为 \((x, f(x)*g(x))\)。这一计算可以在 \(O(n+m)\) 的时间内完成。(当然,由于相乘后多项式次数增高,对于 \(f(x), g(x)\) 分别计算点值表示法时,不能只各取 \(n, m\) 个点,都要至少取到 \(n+m\) 个才可以直接相乘计算它们的卷积。)
但是由于直接计算点值表示法的复杂度仍旧为 \(O(n^2)\) 的,这看上去对复杂度并没有起到任何的优化效果。于是让我们请出神奇的快速傅里叶变换,看它是如何在 \(O(n \log n)\) 的时间内实现系数表示法和点值表示法之间的转换的。
3. 一点点前置数学知识
-
三角函数
-
向量
-
复数
(前两者由于笔者在攥稿时已经在数学课上学习过了,暂且略去。)
-
虚数单位:定义常数 \(i\) 满足 \(i^2 = -1\),称为虚数单位。
-
复数:定义形如 \(a+bi(a, b \in \mathbf{R})\) 的数为复数。
-
实数和虚数:当且仅当 \(b = 0\) 时,它是实数;当且仅当 \(b \neq 0\) 时,它是虚数;当且仅当 \(a = 0\) 且 \(b \neq 0\) 时,它是纯虚数。

- 复平面:我们已知实数与数轴上的数一一对应。由于一个复数可以唯一地写作 \(a+bi(a, b \in \mathbf{R})\),所以每一个复数都可以对应平面上的一个向量——\((a, b)\)。称这个平面为复平面。\(x\) 轴是实数轴,即实轴;\(y\) 轴为纯虚数轴,即虚轴。

-
模(绝对值):定义复数所对应向量的模,为该复数的模(又称绝对值)。形式化地,可表示为 \(\sqrt{a^2+b^2}\)。
-
幅角:以 \(x\) 轴正半轴为始边,复数对应的向量所在的射线为终边,形成的角。
-
三角表示法:一个复数还可以用模长 + 幅角的形式表示。设模长为 \(r\),幅角为 \(\theta\),则它的三角表示法为:\(r(\cos\theta + i\sin\theta)\)。
-
复数的加、减:和向量的加减法则是一样的。
-
复数的乘法:
-
代数上来说,设 \(x = a+bi, y = c+di\),则有:
\[\begin{aligned} x \times y &= (a+bi)(c+di) \\ &= ac+bci+adi+bdi^2 \\ &= ac+bci+adi-bd \\ &= (ac-bd)+(bc+ad)i \end{aligned} \] -
几何上来说,它是模长相乘、幅角相加。设复数 \(x = r_1(\cos\theta_1 + i\sin\theta_1), y = r_2(\cos\theta_2 + i\sin\theta_2)\),则有:
\[\begin{aligned} x \times y &= r_1r_2(\cos\theta_1 + i\sin\theta_1)(\cos\theta_2 + i\sin\theta_2) \\ &= r_1r_2(\cos\theta_1\cos\theta_2 + i\cos\theta_1\sin\theta_2 + i\sin\theta_1\cos\theta_2 + i^2\sin\theta_1\sin\theta_2) \\ &= r_1r_2[(\cos\theta_1\cos\theta_2- \sin\theta_1\sin\theta_2) + i(\cos\theta_1\sin\theta_2 + \sin\theta_1\cos\theta_2)] \\ &= r_1r_2[\cos(\theta_1+\theta_2) + i\sin(\theta_1+\theta_2)] \end{aligned} \]
-
-
单位根:定义满足 \(x^n = 1\) 的 \(x\) 为 \(n\) 次单位根。注意,同次的单位根不止一个。这里姑且将所有 \(n\) 次单位根的集合记作 \(E_n\)。
-
本原单位根:定义 \(x\in E_n\),满足 \(E_n \subseteq \{ y | y = x^k, k\in\mathbf{N} \}\)(说人话,就是它的非负整数次幂可以生成所有的别的 \(n\) 次单位根),为 \(n\) 次本原单位根。记作 \(\omega_n\)。
-
本原单位根的几何意义:在复平面上作一个单位圆。从 \(x\) 轴开始,把单位圆 \(n\) 等分,圆上分出来的点中幅角最小的那一个就是 \(\omega_n\)。其余的点,则都可以写作 \(\omega_n^k\)(\(k\) 为 \(0\sim n-1\) 的整数),这一结论由于复数乘法的几何意义是显然的。

(引自 @FlashHu 的博客)
- \(\omega_n^k\) 的值:根据复数乘法的几何意义,明显可以得到:
【注意下面两条性质,它们是快速傅里叶变换的关键。】
-
性质一:\(\omega_{an}^{ak} = \omega_n^k\)
证明:\(\omega_{an}^{ak} = \cos\frac{ak*2\pi}{an} + i\sin\frac{ak*2\pi}{an} = \omega_n^k\)
-
性质二:\(\omega_n^{k+\frac{n}{2}} = -\omega_n^k\)
证明:\(\omega_n^{k+\frac{n}{2}} = \omega_n^k * \omega_n^{\frac{n}{2}} = \omega_n^k * \omega_2^1 = -\omega_n^k\)
4. 快速傅里叶变换(FFT)
简单概括一下 FFT 的思想:将所有的 \(\omega_n^k\)(\(k\) 为 \(0\sim n-1\) 的整数)代入多项式,并且利用上面的两条性质,进行分治优化。
【注意:以下推导默认 \(n\) 为偶数】
先设当前有一个多项式:
对它的下标进行奇偶性分类:
给右边的式子提个公因数,将左右化得更为相似:
设:
则:
这样,每次计算 \(f(x)\) 时,我们就可以递归分别计算两个项数更少的 \(f_1(x)\) 和 \(f_2(x)\),再加起来得到答案了。
看上去还是没用?让我们尝试随便代入一个值 \(\omega_n^k (k < \frac{n}{2})\):
再代入一个值 \(\omega_n^{k+\frac{n}{2}} (k < \frac{n}{2})\):
可以发现,前一个式子和后一个式子几乎是一样的。因此,当我们用递归计算 \(f_1\)、\(f_2\) 的方法计算出 \(f(\omega_n^k)\) 时,我们完全可以顺便求出 \(f(\omega_n^{k+\frac{n}{2}})\)。如此,计算范围就缩小了一半。
而在递归分别计算 \(f_1\) 和 \(f_2\) 的过程中,也同样可以按照上面的方法,缩小一半范围计算。
当然,运用上述方法的必要条件是项数为偶数。为了保证每一层递归时项数都是偶数,我们要将项数设置为一个大于等于原项数的 \(2\) 正整数次幂,并补上一些系数为 \(0\) 的项。
一直递归,层数最多为 \(\log n\)。总复杂度 \(O(n \log n)\)。
5. 快速傅里叶逆变换(IFFT)
快速傅里叶变换让我们成功地将系数表示法转化为了点值表示法。但是在我们运用点值表示法快速地得到了新多项式时,还需要把点值表示法转化回系数表示法。这时就需要用到快速傅里叶逆变换。
快速傅里叶逆变换的思路非常奇怪。它的思路是:将当前得到的点值表示,当作一个多项式的系数,再对这个新多项式做一遍快速傅里叶变换求点值表示。可以证明,这个二次快速傅里叶变换所得到的的结果和原本的系数表示法之间存在某种关系。
这里为了帮助我自己理解,手推(抄)了一遍 dalao@自为风月马前卒 给出的快速傅里叶逆变换的证明。
\((y_0, y_1, \dots, y_{n-1})\) 为多项式 \((a_0, a_1, \dots, a_{n-1})\) 在 \(x\) 取 \((\omega^0_n, \omega^1_n, \dots, \omega^{n-1}_n)\) 时的点值表示(亦称傅里叶变换)。形式化地,它满足:
设有一向量 \((c_0, c_1, \dots, c_{n-1})\) 为 \((y_0, y_1, \dots, y_{n-1})\) 在 \(x\) 取 \((\omega^0_n, \omega^{-1}_n, \dots, \omega^{-(n-1)}_n)\) 时的点值表示。形式化地,它满足:
然后开始推导。
观察上式中第二个 \(\sum\) 的内容,它是一个关于 \(\omega_n^{j-k}\) 的等比数列。当 \(j-k \ne 0\) 时,根据等比数列求和公式,有:
而当 \(j-k = 0\) 时,有 \(\omega_n^{j-k} = 1\),那么求和的结果即为 \(n\)。
因此有:
即:
在学了 skc 的网课后,发现这还有另外一种从矩阵角度的理解方法:


三、实现
1. 递归版
我们来浅浅总结一下整体的流程:
-
先求出一个 \(N\),满足为第一个大于 \(n+m\) 的 \(2\) 的正整数次幂,令它为新项数。并为两个多项式补上一些系数为 \(0\) 的项。
-
对多项式 \(f(x), g(x)\) 分别跑一次 FFT,得到二者的点值表示法。
-
将二者的点值表示法逐个相乘,得到二者卷积的点值表示法。
-
对卷积的点值表示法跑一遍 IFFT,得到它的系数表示法。
FFT 的流程:
-
当前递归到的是一个项数为 \(n\) 的多项式。
-
将该多项式的系数按照奇偶性拆为两个多项式,分别递归计算值。
-
通过计算到的两个多项式的答案,合并得到当前的答案。
IFFT 的流程:
-
将原本的单位根 \(\omega_n\) 改为 \(\omega_n^{-1}\),对点值表示法计算二次 FFT。一般直接使用性质 \(\omega_n^{-k} = \omega_n^{-k} * \omega_n^n = \omega_n^{n-k}\) 来计算 \(\omega_n^{-k}\)。
-
将二次 FFT 后得到的值全都除以 \(n\),即得到系数表示法。
直接使用递归法由于 FFT 的常数过大,在洛谷上无法通过。
2. 递推版
回忆一下上文的递归式:
如果按照递归写法,它需要开很多数组 \(f, f1, f2, \dots\),而接下来介绍的递推法能够在只开一个数组的情况下解决问题。
观察下图:

可以发现两条巧妙的性质:
-
对于当前规模为 \(n\) 的递归层,\(f[k]\) 和 \(f[k+\frac{n}{2}]\) 两个位置在下一层递归中恰好对应 \(f_1[k]\) 和 \(f_2[k]\)。
-
最后结束时的序列为原序列的二进制翻转。
性质 1. 使得我们每一次从下往上递推时只需要调用下面的 \(f[k], f[k+\frac{n}{2}]\) 即可计算出 \(f[k], f[k+\frac{n}{2}]\)。这样就只需要开一个数组了。
性质 2. 使得我们可以轻松地得到递推的起始状态。
3. c++ STL complex 类
实现的时候我们可以选择自己封装一个复数类,也可以选择使用 c++ STL 自带的 complex 类。
- 定义
基本格式为 complex<T> x;,T 为一种浮点数类型。
初始化变量的值有很多种方式。下面列举了几种:
complex<double> a(3, 4);
complex<double> b = 3.0 + 4i;
complex<double> c = {3, 4};
complex<double> d(3);//只初始化实部
- 实部和虚部
使用 real(x), imag(x) 可以分别取出 x 的实部和虚部,但无法赋值。
使用 x.real(), x.imag() 也可以分别取出 x 的实部和虚部。并且,可以通过 x.real(a), x.imag(b) 的方式分别将 x 的实部和虚部赋值为 a 和 b。
- 运算
complex 类内置了复数加减乘除,以及与实数的加减乘除运算。我们只需要用运算符号正常运算即可。
- 输出
complex 类内置了输入输出流。直接写 cout<<x<<endl;,会输出形如 (real,imag) 的结果。
4. code
点击查看代码
#include<bits/stdc++.h>
#define cp complex<double>
using namespace std;
const int MAXN = (1<<21)+5;
const double Pi = acos(-1);//为确保精度,需要写成这样。
int n, m, N = 1, rev[MAXN];
cp f[MAXN], g[MAXN], h[MAXN];
inline void Prework(){//预处理二进制翻转。
while(N <= n+m) N <<= 1;
for(int i = log2(N)-1, t = 0; i >= 0; i--){
int x = t;
for(int j = 0; j <= x; j++) rev[++t] = (1<<i)|rev[j];
}
return;
}
inline void DFT(cp a[], int v){//v 表示当前是 w_n^k 还是 w_n^{-k}
for(int i = 0; i < N; i++) if(rev[i] < i) swap(a[rev[i]], a[i]);
for(int i = 2; i <= N; i <<= 1)//枚举当前层中,每一组的大小 i
for(int k = 0; k < i/2; k++){//枚举组内编号 k
//从逻辑上应该先枚举 j 再枚举 k,但是我为了只计算一次不同的 w_n^k,就调换了一下顺序。
cp w(cos(2*Pi/i*k), sin(2*Pi/i*k*v));
for(int j = 0; j < N; j += i){//枚举每个组的开头下标 j
cp x = a[j+k], y = a[i/2+j+k];
a[j+k] = x+w*y;
a[i/2+j+k] = x-w*y;
}
}
return;
}
inline void IDFT(cp a[]){
DFT(a, -1);
for(int i = 0; i < N; i++) a[i] /= N;
return;
}
int main(){
scanf("%d%d", &n, &m);
for(int i = 0; i <= n; i++){
int x; scanf("%d", &x);
f[i].real(x);
}
for(int i = 0; i <= m; i++){
int x; scanf("%d", &x);
g[i].real(x);
}
Prework();
DFT(f, 1), DFT(g, 1);
for(int i = 0; i < N; i++) h[i] = f[i]*g[i];
IDFT(h);
for(int i = 0; i <= n+m; i++) printf("%d ", int(round(real(h[i]))));
return 0;
}
四、NTT
NTT 和 FTT 不同的地方在于,它被用于解决模意义下的多项式乘法。
其实本质上的思路和 FFT 只能说是一模一样。只不过 NTT 中的“单位根”,是模意义下的单位根而已。
如何求出模意义下的单位根呢?首先需要看一下 原根与阶。设在该模数 \(p\) 下有原根 \(g\),可以证明,任意满足 \(x \equiv g^t \pmod p\) 的 \(x\),一定满足 \(t\text{ord}(x) = \text{ord}(g)\)。故存在一个单位根 \(\omega_n \equiv g^{\frac{\text{ord}(g)}{n}} = g^{\frac{\varphi(p)}{n}}\)(当然,前提条件一定是 \(n|\varphi(p)\))。
NTT 的题目一般用 \(p = 998244353\) 做模数,因为它是一个质数,所以有 \(\varphi(p) = p-1\),且满足 \(p-1 = 2^{23}\times7\times17\),所以在 \(n \le 2^{23}\) 的情况下都存在 \(n\) 次单位根。而这个模数的最小原根为 \(3\),可以直接用。
点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN = (1<<21)+5;
const ll Mod = 998244353, Rt = 3;
int n, m, N = 1, rev[MAXN];
ll f[MAXN], g[MAXN], h[MAXN];
inline ll Quick_Power(ll x, ll p){
if(!p) return 1;
ll tmp = Quick_Power(x, p>>1);
if(p&1) return tmp*tmp%Mod*x%Mod;
else return tmp*tmp%Mod;
}
inline ll Inv(ll x){
return Quick_Power(x, Mod-2);
}
inline void Prework(){
while(N <= n+m) N <<= 1;
for(int i = log2(N)-1, t = 0; i >= 0; i--){
int x = t;
for(int j = 0; j <= x; j++) rev[++t] = (1<<i)|rev[j];
}
return;
}
inline void DFT(ll a[], ll v){
for(int i = 0; i < N; i++) if(rev[i] < i) swap(a[rev[i]], a[i]);
for(int i = 2; i <= N; i <<= 1)
for(int k = 0; k < i/2; k++){
ll w = Quick_Power(Rt, (Mod-1)/i*(i+k*v));
for(int j = 0; j < N; j += i){
ll x = a[j+k], y = a[i/2+j+k];
a[j+k] = (x+w*y)%Mod;
a[i/2+j+k] = (x-w*y%Mod+Mod)%Mod;
}
}
return;
}
inline void IDFT(ll a[]){
DFT(a, -1);
ll inv = Inv(N);
for(int i = 0; i < N; i++) (a[i] *= inv) %= Mod;
return;
}
int main(){
scanf("%d%d", &n, &m);
for(int i = 0; i <= n; i++) scanf("%lld", f+i);
for(int i = 0; i <= m; i++) scanf("%lld", g+i);
Prework();
DFT(f, 1), DFT(g, 1);
for(int i = 0; i < N; i++) h[i] = f[i]*g[i]%Mod;
IDFT(h);
for(int i = 0; i <= n+m; i++) printf("%lld ", h[i]);
return 0;
}
五、其他多项式运算
在有了多项式乘法以后,我们能够以其为跳板,实现更多多项式运算,包括多项式求逆、多项式 \(\ln\)、多项式 \(\exp\)、多项式开方……
你初次看到这里可能会比较困惑,因为对于我刚刚在上面列举的那些运算,似乎并不是所有的多项式都能存在一个“整”的运算结果。然而可以证明的是,除了无法执行上述运算的多项式(如常数项为 \(0\) 的多项式无法求逆),我们都能得到一个无限次数的多项式,满足为其执行运算的结果。
在 OI 中,我们求的是这个无限次数的多项式模 \(x^n\) 意义下的结果(即保留其前 \(n\) 项)。
1. 牛顿迭代法
牛顿迭代法是我们接下来要介绍的运算的基础。在学习之前,你需要先了解一些基础知识:泰勒展开。
形式化地,牛顿迭代法要解决的问题是:
给定一个函数 \(G(x)\),求出一个多项式 \(F(x)\),满足 \(G(F(x)) \equiv 0 \pmod{x^n}\)。
并且,设多项式 \(F^*\) 满足 \(G(F^*) = 0\)(即 \(F^*\) 为精确解),我们希望有 \(F^* \equiv F \pmod{x^n}\)。
牛顿迭代法的主要思路就是倍增:先算出模 \(x^{n/2}\) 意义下的结果,再借此得到模 \(x^n\) 意义下的结果。
设模 \(x^{n/2}\) 意义下的结果为 \(F_0\),当前的结果为 \(F\)。现在我们需要找到一个 \(F, F_0\) 之间的递归式。
然后是非常有意思的一步:将 \(G(F)\) 在 \(F_0\) 处泰勒展开。于是我们能得到:
接着,根据定义,有 \(F^* \equiv F \pmod{x^n}\) 和 \(F^* \equiv F_0 \pmod{x^{n/2}}\),故有:\(F \equiv F_0 \pmod{x^{n/2}}\)。
所以 \(F - F_0\) 的最低非零次项显然 \(\ge x^{n/2}\),故对于所有 \(k \ge 2\),\((F - F_0)^k\) 的最低非零次项 \(\ge x^n\),也就是说它们在模 \(x^n\) 意义下全部为 0。于是我们有:
转化为递推式的形式:
2. 多项式求逆
让我们来运用一下上面的递推式吧。
题目要求 \(F(x)A(x) \equiv 1 \pmod{x^n}\),我们不妨设 \(G(F(x)) = A(x)F(x) - 1\)(\(A(x)\) 看作常数)。
于是有递归式:
由于 \(G(F_0) \equiv 0 \pmod{x^{n/2}}\),所以这里的 \(A^{-1}\) 只需要在模 \(x^n / x^{n/2} = x^{n/2}\) 意义下成立。又因为 \(F_0 \equiv A^{-1} \pmod{x^{n/2}}\),所以有:
于是我们就可以倍增求出多项式 \(A\) 的逆了!
复杂度分析:有式子 \(T(n) = T(n/2) + O(n \log n)\),根据主定理,复杂度依旧为 \(O(n \log n)\)。
实现注意事项:当我们计算递归式时,我们会有次数为 \(n\) 的 \(A\)、次数为 \(n/2\) 的 \(F_0\),于是最终能得到次数为 \(2n\) 的乘积 \(F\)。然而我们并不能保留 \(F\) 的后 \(n\) 次项,即必须对 \(x^n\) 取模。(其实就算不取模,根据我们上面的推导过程,其答案并不会有误。问题在这样的话,下一轮递归时,\(F_0\) 就不是 \(n/2\) 的次数而是 \(n\) 次的,得到的 \(F\) 也就是 \(3n\) 次的,使得复杂度错误。)
点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN = (1<<18)+5;
const ll Mod = 998244353, Rt = 3;
int rev[MAXN];
ll a[MAXN], b[MAXN], b1[MAXN], a1[MAXN];
inline ll Quick_Power(ll x, ll p){
if(!p) return 1;
ll tmp = Quick_Power(x, p>>1);
if(p&1) return tmp*tmp%Mod*x%Mod;
else return tmp*tmp%Mod;
}
inline void DFT(ll f[], int n, int sgn){
for(int i = log2(n)-1, j = 0; i >= 0; i--){
int t = j;
for(int k = 0; k <= t; k++) rev[++j] = rev[k]|(1<<i);
}
for(int i = 0; i < n; i++) if(rev[i] < i) swap(f[rev[i]], f[i]);
for(int i = 1; i < n; i <<= 1)
for(int k = 0; k < i; k++){
ll w = Quick_Power(Rt, (Mod-1)/i/2*(i*2+k*sgn));
for(int j = 0; j < n; j += i*2){
ll x = f[j+k], y = f[i+j+k];
f[j+k] = (x+w*y)%Mod;
f[i+j+k] = (x-w*y%Mod+Mod)%Mod;
}
}
if(sgn == -1){
ll inv = Quick_Power(n, Mod-2);
for(int i = 0; i < n; i++) (f[i] *= inv) %= Mod;
}
return;
}
int main(){
int n; scanf("%d", &n);
for(int i = 0; i < n; i++) scanf("%lld", a+i);
b[0] = Quick_Power(a[0], Mod-2);
for(int N = 2; N/2 < n; N <<= 1){
for(int i = 0; i < N; i++) a1[i] = a[i], b1[i] = b[i];
DFT(a1, N*2, 1), DFT(b1, N*2, 1);
for(int i = 0; i < N*2; i++) b1[i] = b1[i]*(2-a1[i]*b1[i]%Mod+Mod)%Mod;
DFT(b1, N*2, -1);
for(int i = 0; i < N; i++) b[i] = b1[i];
}
for(int i = 0; i < n; i++) printf("%lld ", b[i]);
return 0;
}
/*
B = B0 * (2 - A*B0)
*/
3. 多项式 \(\ln\)
这个问题的推导过程实际并不需要用到牛顿迭代法。
问题形如:
考虑对 \(F(x)\) 求导:
于是我们只需要对 \(A(x)\) 求导,算一个多项式逆,乘起来,在积分回去即可。
复杂度也为 \(O(n \log n)\)。
点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN = (1<<18)+5;
const ll Mod = 998244353, Rt = 3;
int rev[MAXN];
ll a[MAXN], b[MAXN], b1[MAXN], a1[MAXN], c[MAXN];
inline ll Quick_Power(ll x, ll p){
if(!p) return 1;
ll tmp = Quick_Power(x, p>>1);
if(p&1) return tmp*tmp%Mod*x%Mod;
else return tmp*tmp%Mod;
}
inline void DFT(ll f[], int n, int sgn){
for(int i = log2(n)-1, j = 0; i >= 0; i--){
int t = j;
for(int k = 0; k <= t; k++) rev[++j] = rev[k]|(1<<i);
}
for(int i = 0; i < n; i++) if(rev[i] < i) swap(f[rev[i]], f[i]);
for(int i = 1; i < n; i <<= 1)
for(int k = 0; k < i; k++){
ll w = Quick_Power(Rt, (Mod-1)/i/2*(i*2+k*sgn));
for(int j = 0; j < n; j += i*2){
ll x = f[j+k], y = f[i+j+k];
f[j+k] = (x+w*y)%Mod;
f[i+j+k] = (x-w*y%Mod+Mod)%Mod;
}
}
if(sgn == -1){
ll inv = Quick_Power(n, Mod-2);
for(int i = 0; i < n; i++) (f[i] *= inv) %= Mod;
}
return;
}
inline void Inverse(ll a[], ll b[], int n){
b[0] = Quick_Power(a[0], Mod-2);
int N;
for(N = 2; N/2 < n; N <<= 1){
for(int i = 0; i < N; i++) a1[i] = a[i], b1[i] = b[i];
DFT(a1, N*2, 1), DFT(b1, N*2, 1);
for(int i = 0; i < N*2; i++) b1[i] = b1[i]*(2-a1[i]*b1[i]%Mod+Mod)%Mod;
DFT(b1, N*2, -1);
for(int i = 0; i < N; i++) b[i] = b1[i];
}
for(int i = n; i < N; i++) b[i] = 0;
return;
}
inline void Integrate(ll a[], int n){
for(int i = n; i >= 1; i--) a[i] = a[i-1]*Quick_Power(i, Mod-2)%Mod;
a[0] = 0;
return;
}
inline void Diff(ll a[], int n){
for(int i = 0; i < n; i++) a[i] = a[i+1]*(i+1)%Mod;
return;
}
int main(){
int n; scanf("%d", &n);
for(int i = 0; i < n; i++) scanf("%lld", a+i);
Inverse(a, b, n);
Diff(a, n);
int N = 1;
while(N < n*2) N <<= 1;
DFT(a, N, 1), DFT(b, N, 1);
for(int i = 0; i < N; i++) c[i] = a[i]*b[i]%Mod;
DFT(c, N, -1);
Integrate(c, n);
for(int i = 0; i < n; i++) printf("%lld ", c[i]);
return 0;
}
/*
积分 A'/A
*/
4. 多项式 \(\exp\)
问题形如:
我们先将其转化为 \(\ln F(x) \equiv A(x)\)。于是有 \(G(F(x)) = \ln F(x) - A(x)\)。
根据牛顿迭代法,有:
复杂度依旧为 \(O(n \log n)\)。
点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll Mod = 998244353, Rt = 3;
inline ll Quick_Power(ll x, ll p){
if(!p) return 1;
ll tmp = Quick_Power(x, p>>1);
if(p&1) return tmp*tmp%Mod*x%Mod;
else return tmp*tmp%Mod;
}
inline void DFT(vector<ll> &f, int n, int sgn){
f.resize(n);
vector<int> rev(n);
for(int i = log2(n)-1, j = 0; i >= 0; i--){
int t = j;
for(int k = 0; k <= t; k++) rev[++j] = rev[k]|(1<<i);
}
for(int i = 0; i < n; i++) if(rev[i] < i) swap(f[rev[i]], f[i]);
for(int i = 1; i < n; i <<= 1){
ll wn = Quick_Power(Rt, (Mod-1)/i/2);
for(int j = 0; j < n; j += i*2){
ll wk = 1;
for(int k = 0; k < i; k++){
ll x = f[j+k], y = f[i+j+k];
f[j+k] = (x+wk*y)%Mod;
f[i+j+k] = (x-wk*y%Mod+Mod)%Mod;
(wk *= wn) %= Mod;
}
}
}
if(sgn == -1){
reverse(f.begin()+1, f.begin()+n);//这是一个神秘的结论,我先咕了。
ll inv = Quick_Power(n, Mod-2);
for(int i = 0; i < n; i++) (f[i] *= inv) %= Mod;
}
return;
}
inline vector<ll> Inverse(vector<ll> a){
vector<ll> b(1);
b[0] = Quick_Power(a[0], Mod-2);
for(int N = 2; N <= a.size(); N <<= 1){
vector<ll> a1(a.begin(), a.begin()+N);
DFT(a1, N*2, 1), DFT(b, N*2, 1);
for(int i = 0; i < N*2; i++) b[i] = b[i]*(2-a1[i]*b[i]%Mod+Mod)%Mod;
DFT(b, N*2, -1);
b.resize(N);
}
return b;
}
inline void Integrate(vector<ll> &a){
for(int i = a.size()-1; i >= 1; i--) a[i] = a[i-1]*Quick_Power(i, Mod-2)%Mod;
a[0] = 0;
return;
}
inline void Diff(vector<ll> &a){
for(int i = 0; i+1 < a.size(); i++) a[i] = a[i+1]*(i+1)%Mod;
a[a.size()-1] = 0;
return;
}
inline vector<ll> Ln(vector<ll> a){
int n = a.size();
vector<ll> b = Inverse(a);
Diff(a);
DFT(a, n*2, 1), DFT(b, n*2, 1);
for(int i = 0; i < n*2; i++) b[i] = a[i]*b[i]%Mod;
DFT(b, n*2, -1);
b.resize(n), Integrate(b);
return b;
}
inline vector<ll> Exp(vector<ll> a){
vector<ll> b(1, 1);
for(int N = 2; N <= a.size(); N <<= 1){
vector<ll> a1(a.begin(), a.begin()+N), b1;
b.resize(N), b1 = Ln(b);
DFT(a1, N*2, 1), DFT(b1, N*2, 1), DFT(b, N*2, 1);
for(int i = 0; i < N*2; i++) b[i] = b[i]*(-b1[i]+a1[i]+1+Mod)%Mod;
DFT(b, N*2, -1);
b.resize(N);
}
return b;
}
int main(){
int n; scanf("%d", &n);
vector<ll> a(n);
for(int i = 0; i < n; i++) scanf("%lld", a.begin()+i);
int N = 1;
while(N < n) N <<= 1;
a.resize(N);
vector<ll> b = Exp(a);
// vector<ll> b = Ln(a);
for(int i = 0; i < n; i++) printf("%lld ", b[i]);
return 0;
}
/*
F(x) = F0(x) * (1 - ln F0(x) + A(x))
*/
六、其它卷积运算
1. 概述
我们回忆一下最经典的多项式卷积运算形式:
将 \(+\) 换成新运算 \(\cdot\),我们可以得到一些新的卷积:
直接枚举计算卷积的复杂度都是 \(O(n^2)\) 的。
接下来要介绍的几种算法的思路本质上都类似于 FFT:将原多项式 \(a, b, c\) 都转化为另一个数组 \(a', b', c'\),使得有 \(c'_i = a'_i b'_i\)(即转化为计算点积),并且该转化存在逆运算(允许我们将 \(c'\) 转化回 \(c\))。
2. \(\gcd / \operatorname{lcm}\) 卷积
即
考虑如何转化为点积。考虑一个名为狄利克雷前缀和的运算,其定义为:
以 \(\operatorname{lcm}\) 卷积为例(\(\gcd\) 卷积类似,不过需要换成狄利克雷后缀和),于是有:
狄利克雷前缀和可以通过枚举因数在 \(O(n \log n)\) 的时间复杂度内实现。狄利克雷差分的表达式不好直接写出,但是其代码就是狄利克雷前缀和的逆过程,十分好写。
如果要做到更快的复杂度,我们可以考虑对每一个质数依次做前缀和 / 差分,这样就相当于一个 \(p\) 维的高维前缀和。复杂度 \(O(n \log \log n)\)。
3. \(\operatorname{and} / \operatorname{or} / \operatorname{xor}\) 卷积(FWT / FMT)
容易发现,\(\operatorname{and} / \operatorname{or}\) 卷积与上面的 \(\gcd / \operatorname{lcm}\) 卷积的结构很像——本质上都是对元素在多个维度上取 \(\min / \max\)。所以我们其实可以直接通过做二进制数上的高维前/后缀和解决 \(\operatorname{and} / \operatorname{or}\) 的情况,复杂度 \(O(n \log n)\)。然而对于 \(\operatorname{xor}\) 卷积,这种想法并不能 fit in。
考虑一种新想法:设 \(w(i, j)\) 为 \(a_j\) 对 \(a'_i\) 的贡献系数,则有:
所以上式成立的一个充分不必要条件是 \(w(i, j) w(i, k) = w(i, j \operatorname{xor} k)\)。考虑能否根据此构造合适的 \(w\)。
并不容易发现的是:定义运算 \(x \circ y\) 表示 \(\operatorname{popcount}(x \operatorname{and} y) \bmod 2\),则 \((x \circ y) \operatorname{xor} (x \circ z) = x \circ (y \operatorname{xor} z)\)。
为了能更方便地逐位处理,我们不直接使用 \(\circ\) 运算定义 \(w\),而是令 \(w(i, j) = (-1)^{\operatorname{popcount}(x \operatorname{and} y)}\)。至于为什么不选 \(0/1\) 为底——因为这样 \(w\) 就不满足可逆性了。可能现在你还不太能理解“逐位处理”,但是往下看以后你就会明白了。
然后就是如何实现 \(a \rightarrow a'\) 的问题。与 FFT 类似,我们考虑二分处理问题,每次考虑将当前数组拆分为两个较小的数组。不一样的是,这里不是奇偶性分组,而是直接在中间断开分成两个。即:
观察发现,二分出来的两部分数仅有最高位不同,我们设计的 \(w\) 也可以拆为每一位的贡献之积。设 \(i_0\) 为 \(i\) 最高位的 \(0/1\) 值,\(i'\) 为 \(i\) 去掉最高位后的值,则有:
于是,设 \(a_0, a_1\) 为下一层递归中最高位分别为 \(0/1\) 的两个数组,则有:
显然已经可以很轻松地递归处理了,不过实际应用中我们一般会写成递推倍增形式。
其逆运算是什么?推导后不难发现:
(这个式子本质是 \(w(0/1, 0/1)\) 这个矩阵的逆。)
有趣的是:对于 \(\operatorname{and}, \operatorname{or}\) 运算,我们也可以通过设计贡献函数 \(w\) 的方式来思考。\(\operatorname{and}\) 的贡献函数为 \(w(i, j) = [i \operatorname{and} j = i]\),\(\operatorname{or}\) 的贡献函数为 \(w(i, j) = [i \operatorname{or} j = i]\)。接着可以用类似的倍增方式求解。
板题:P4717 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)
点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN = (1<<17)+5;
const ll Mod = 998244353, inv = (Mod+1)/2;
int n, m;
ll a0[MAXN], b0[MAXN], a[MAXN], b[MAXN], c[MAXN];
inline void OR(ll a[], ll tp){
for(int i = 1; i < m; i <<= 1)
for(int j = 0; j < m; j++)
if(j&i) (a[j] += tp*a[j^i]) %= Mod;
return;
}
inline void AND(ll a[], ll tp){
for(int i = 1; i < m; i <<= 1)
for(int j = 0; j < m; j++)
if(j&i) (a[j^i] += tp*a[j]) %= Mod;
return;
}
inline void XOR(ll a[], ll tp){
for(int i = 1; i < m; i <<= 1)
for(int j = 0; j < m; j += i<<1)
for(int k = 0; k < i; k++){
ll x = a[j+k], y = a[i+j+k];
a[j+k] = (x+y)*tp%Mod, a[i+j+k] = (x-y+Mod)*tp%Mod;
}
return;
}
int main(){
scanf("%d", &n), m = 1<<n;
for(int i = 0; i < m; i++) scanf("%lld", a0+i);
for(int i = 0; i < m; i++) scanf("%lld", b0+i);
memcpy(a, a0, sizeof(a0)), memcpy(b, b0, sizeof(b0));
OR(a, 1), OR(b, 1);
for(int i = 0; i < m; i++) c[i] = a[i]*b[i]%Mod;
OR(c, Mod-1);
for(int i = 0; i < m; i++) printf("%lld ", c[i]); puts("");
memcpy(a, a0, sizeof(a0)), memcpy(b, b0, sizeof(b0));
AND(a, 1), AND(b, 1);
for(int i = 0; i < m; i++) c[i] = a[i]*b[i]%Mod;
AND(c, Mod-1);
for(int i = 0; i < m; i++) printf("%lld ", c[i]); puts("");
memcpy(a, a0, sizeof(a0)), memcpy(b, b0, sizeof(b0));
XOR(a, 1), XOR(b, 1);
for(int i = 0; i < m; i++) c[i] = a[i]*b[i]%Mod;
XOR(c, inv);
for(int i = 0; i < m; i++) printf("%lld ", c[i]); puts("");
return 0;
}
4. 子集卷积
即:
显然 \(i \operatorname{or} j = k\) 这个条件可以直接使用 FWT 解决。问题在 \(i \operatorname{and} j = 0\) 怎么办。
\(i \operatorname{and} j = 0\) 其实可以等价为 \(\operatorname{popcount}(i) + \operatorname{popcount}(j) = \operatorname{popcount}(i \operatorname{or} j)\)。故我们可以为卷积数组增加一维表示 \(\operatorname{popcount}\),即设 \(A_{i, j} = \begin{cases} a_j & \operatorname{popcount}(j) = i \\ 0 & \operatorname{popcount}(j) \ne i \end{cases}\)。总时间复杂度为 \(O(n \log^2 n)\)。
点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN = 25, MAXM = (1<<20)+5;
const ll Mod = 1e9+9;
int n, m, ppc[MAXM];
ll a[MAXN][MAXM], b[MAXN][MAXM], c[MAXN][MAXM];
inline void FWT(ll a[], ll tp){
for(int i = 1; i < m; i <<= 1)
for(int j = 0; j < m; j++)
if(j&i) (a[j] += tp*a[j^i]) %= Mod;
return;
}
int main(){
scanf("%d", &n), m = 1<<n;
for(int i = 0; i < m; i++) ppc[i] = __builtin_popcount(i);
for(int i = 0; i < m; i++) scanf("%lld", &a[ppc[i]][i]);
for(int i = 0; i < m; i++) scanf("%lld", &b[ppc[i]][i]);
for(int i = 0; i <= n; i++) FWT(a[i], 1), FWT(b[i], 1);
for(int i = 0; i <= n; i++){
for(int j = 0; j <= i; j++)
for(int k = 0; k < m; k++) (c[i][k] += a[j][k]*b[i-j][k]) %= Mod;
FWT(c[i], Mod-1);
}
for(int i = 0; i < m; i++) printf("%lld ", c[ppc[i]][i]);
return 0;
}

浙公网安备 33010602011771号