多项式优化技术基础
本文的推导需要一定的高中数学基础.
ref:https://www.cnblogs.com/pks-t/p/9251147.html
卷积
卷积是一种通过两个函数 \(f,g\) 生成第三个函数的数学算子.
在离散意义下若满足:
则称 \(C(x)\) 为 \(A(x)\) 与 \(B(x)\) 的卷积.
当 \(A(x)\) 和 \(B(x)\) 为 \(n\) 次的多项式,不难发现 \(C(x)\) 为 \(2n\) 次多项式,即有 \(2n+1\) 项.
如果你实在不理解,实际上卷积就是多项式乘法换了个名字.
点值表示法
我们常见的多项式由系数表示,也就是:
大家都知道 \(n+1\) 个点可以唯一确定一个 \(n\) 次函数,由此引出以 \(n+1\) 个点值来表示 \(n\) 次多项式的方法.
具体地,我们构建 \(i\rightarrow f_i\) 的映射,那么有:
点值优化卷积
假设我们有两个关于 \(x\) 的 \(n\) 次多项式 \(A(x),B(x)\),我们利用点值来尝试对 \(A(x),B(x)\) 进行卷积运算.
从函数的角度来理解,两函数相乘,得到的结果的点值就是两函数对应点值之积. 可以把这个当作基本事实,必要性显然. 于是我们就可以进一步用点值来优化多项式卷积.
由于 \(A(x)\cdot B(x)\) 结果是 \(2n\) 次,有 \(2n+1\) 项,所以我们分别取 \(A(x),B(x)\) 的 \(2n+1\) 个点值,并对
做乘法可得
复杂度由暴力卷积 \(O(n^2)\) 优化为了 \(O(n)\),极大的提升了效率,但是由系数表达式和点值表达式的相互转换暴力仍然是 \(O(n^2)\) 的.
这种将系数转换为点值求卷积的算法叫做傅里叶变换,主要包括两部分 DFT(离散傅里叶变换)和 IDFT(离散傅里叶逆变换). 其中 DFT 对应的是系数表达式转点值表达式的过程,IDFT 对应的是点值表达式转系数表达式的过程.
单位根
现在我们的优化聚焦于系数表达式与点值表达式的相互转换. 而利用单位根的性质,可以将 DFT 和 IDFT 都优化到 \(O(n\log n)\).
下面我们来看看如何优化.
定义
如果读者有高中数学基础那么这部分应该是容易理解的. 定义 \(n\) 次单位根是满足
的 \(n\) 个复数解,这 \(n\) 个复数均匀分布在复平面的单位圆上.
考虑欧拉公式,即三角函数的指数表示法:
定义主次单位根为 \(i=1,x=2\pi\) 时的单位根,也就是:
容易得知,其它所有单位根都可以用主次单位根的整数次幂表出.
单位根的基本运算这里就不再赘述. 下面我们来看几个引理.
Lemma 1 消去引理
\(\forall n,k\in \text N\wedge d\in\text N^+,\omega^{dk}_{dn}=\omega_{n}^k\)
直接带入欧拉公式易证,分子分母同时消去了 \(d\).
Lemma 2 折半引理
\(\forall x\in N^+\wedge n=2x,\{\omega_n^{2i}|i\in \text N,i< n\}=\{\omega_{n\over2}^i|i\in \text N,i<{n\over 2}\}\)
由引理 1 易证. 从复平面的角度还可以理解为 \(n\) 是偶数时共线的两个单位根的平方是相等的.
这是单位根优化 DFT,IDFT 最核心的一个引理. 它指向了折半/分治/迭代 等思想.
Lemma 3 求和引理
\(\forall n\in N^+\wedge n\nmid k,\sum\limits_{i=0}^{n-1}\left(\omega_n^k\right)^i=0\)
由等比数列求和公式可得:
带入 \(x=\omega_n^k\),即:
由于保证了 \(n\nmid k\) 所以分母不为 \(0\),即原式为 \(0\).
特别的,当 \(n\mid k\) 时原式值为 \(n\).
快速傅里叶变化
利用单位根的性质对 DFT/IDFT 进行优化的算法称为 FFT(快速傅里叶变换). 其可将时间复杂度优化至 \(O(n\log n)\),主要用到的性质就是上文证明的折半引理.
FFT 优化 DFT
考虑将多项式
按下标奇偶拆成两个次数减半的多项式 \(A^{[0]}(x),A^{[1]}(x)\):
就可以得到下式:
那我们求 \(A(x)\) 的 \(n\) 个点值,就转换成分别求 \(A^{[0]}(x^2),A^{[1]}(x^2)\) 的 \({n\over 2}\) 个点值,再根据上式进行合并即可.
这样直接求仍然是 \(O(n)\) 的,但是这个 \(x\rightarrow x^2\) 的转化启发我们用折半引理的结论对数据规模进行简化.
于是考虑求
处的点值,每次平方数据规模都可以减半,就可以做到只求 \(O(\log n)\) 次而不是 \(O(n)\) 次. 具体的:
其中下式可以进行化简:
即有:
所以每次数据计算数据规模都可以翻倍,只用求 \(O(\log n)\) 次点值,DFT 的时间复杂度就来到了 \(O(n\log n)\).
FFT 优化 IDFT
IDFT 是 DFT 的逆变换.
现在我们已经将系数表达式转换为点值表达式,并且计算出卷积的点值,考虑怎么快速通过点值还原出系数表达式. 由于单位根具有周期性,不妨带入 \(-\omega_n^k\) 试着计算一下. 具体的,设对 \(A(x)\) 进行 DFT 之后得到的点值表示为
构造多项式
带入点值
计算一番
根据求和引理,仅当 \(j-k=0\) 时右式值为 \(n\),其余项全为 \(0\). 所以最后得到:
我们只需要仿照 DFT,带入相反的点值,得到的结果再除以 \(n\) 就可以在同样 \(O(n\log n)\) 的时间复杂度把点值表达式还原成系数表达式了.
蝴蝶变换
上述算法可以使用递归分治简单实现,但是由于涉及三角函数、虚数、递归等等常数很大. 考虑寻找下标变换的规律:
由于每次每组的都按照组内序号的奇偶性再分组,可以观察到:第 \(i\) 轮分组从低到高第 \(i\) 位为 \(0\) 的会被分到左边的组,为 \(1\) 的会被分到右边的组.
可以发现,若有 \(2^m\) 个数进行递归分治,下标为 \(2^{m-1}\) 的最后总是会排到下标为 \(1\) 的位置,下标为 \(2^{m-2}\) 的最后总是会排到 \(2\) 的位置. 一个事实是每个下标最终的位置是最初下标的二进制位对称翻转得到的,也就是 \(m\) 位二进制下标经过变换会得到
所以我们可以预处理变换数组,会大大减小代码的常数.
代码实现
有一个常数优化是将第二个多项式存入第一个多项式的虚部,求第一个多项式的平方,由于
所以答案就是虚部取出来除以二,注意要四舍五入.
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 3e6 + 10;
const double pi = acos(-1.0);
int n, m;
int lim = 1, lg, rev[maxn];
struct cmplx{
double x, y;
cmplx (double a = 0, double b = 0) {x = a, y = b;}
cmplx operator + (cmplx C) {return cmplx(x + C.x, y + C.y);}
cmplx operator - (cmplx C) {return cmplx(x - C.x, y - C.y);}
cmplx operator * (cmplx C) {return cmplx(x * C.x - y * C.y, x * C.y + y * C.x);}
}A[maxn];
inline void FFT(cmplx *P, double flag) {
for(int i = 0; i < lim; i++) if(i < rev[i]) swap(P[i], P[rev[i]]);
for(int i = 1; i < lim; i <<= 1) {
cmplx wi(cos(pi / i), flag * sin(pi / i));
for(int j = 0; j < lim; j += (i << 1)) {
cmplx wn(1.0, 0.0);
for(int k = 0; k < i; k++, wn = wn * wi) {
cmplx w1 = P[j + k], w2 = P[j + k + i];
P[j + k] = w1 + wn * w2;
P[j + k + i] = w1 - wn * w2;
}
}
}
if(flag == -1) {
for(int i = 0; i <= lim; i++) P[i].x /= lim, P[i].y /= lim;
} return;
}
int main() {
ios :: sync_with_stdio(false); cin.tie(0); cout.tie(0);
cin >> n >> m;
for(int i = 0; i <= n; i++) cin >> A[i].x;
for(int i = 0; i <= m; i++) cin >> A[i].y;
while(lim <= max(n, m) * 2) lim <<= 1, lg++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
FFT(A, 1);
for(int i = 0; i <= lim; i++) A[i] = A[i] * A[i];
FFT(A, -1);
for(int i = 0; i <= n + m; i++) cout << (int)(A[i].y / 2 + 0.5) << " ";
return 0;
}
附件:多项式全家桶
进队之前不要学,直接拿来用就可以了.
ref:https://www.cnblogs.com/EmilyDavid/p/18710389
#include<bits/stdc++.h>
using namespace std;
using i64=long long;
constexpr int mo=998244353,maxn=1<<21;
inline int add(int x,int y){x+=y;return x<mo?x:x-mo;}
inline void upd(int &x,int y){x=add(x,y);return;}
int ksm(int x,int y){int rs=1;for(;y;y>>=1,x=(i64)x*x%mo) if(y&1) rs=(i64)rs*x%mo;return rs;}
int inv[maxn];
void initinv(){
inv[1]=1;for(int i=2;i<maxn;++i) inv[i]=(i64)(mo-mo/i)*inv[mo%i]%mo;
}
int rev[maxn];
inline int getmi(int n){int len=1;while(len<n) len<<=1; return len;}
int initrev(int n){
int len=getmi(n);
for(int i=1;i<len;++i){
rev[i]=rev[i>>1]>>1;
if(i&1) rev[i]|=len>>1;
}
return len;
}
struct Poly{
vector<int> v;
int L;
Poly(){vector<int>().swap(v);L=0;}
int& operator[](int x){
return v[x];
}
void slice(int len){ v.resize(len);L=len;}
void NTT(int len,bool typ){
for(int i=1;i<len;++i) if(i<rev[i]) swap(v[i],v[rev[i]]);
for(int h=2;h<=len;h<<=1){
int w=ksm(3,(mo-1)/h);
for(int i=0;i<len;i+=h){
int wn=1;
for(int j=i;j<i+h/2;++j){
int x=v[j],y=(i64)wn*v[j+h/2]%mo;
v[j]=add(x,y);v[j+h/2]=add(x,mo-y);
wn=(i64)wn*w%mo;
}
}
}
if(!typ){
reverse(v.begin()+1,v.end());
for(int i=0;i<len;++i) v[i]=(i64)v[i]*inv[len]%mo;
}
}
Poly operator*(const Poly &F){
Poly rs,f,g;f=*this;g=F;
int len=initrev(L+F.L-1);
f.slice(len);g.slice(len);rs.slice(len);
f.NTT(len,1);g.NTT(len,1);
for(int i=0;i<len;++i) rs.v[i]=(i64)f.v[i]*g.v[i]%mo;
rs.NTT(len,0);
return rs;
}
Poly Inv(){
Poly g,g0,f;
int Lim=getmi(L);
g.slice(1);g[0]=ksm(v[0],mo-2);
for(int len=2;(len>>1)<Lim;len<<=1){
f=*this;f.slice(len);
len<<=1;
f.slice(len);g.slice(len);g0=g;
initrev(len);
g0.NTT(len,1);f.NTT(len,1);
for(int i=0;i<len;++i) g[i]=(i64)g0[i]*add(2,mo-(i64)f[i]*g0[i]%mo)%mo;
g.NTT(len,0);
len>>=1;
g.slice(len);
}
g.slice(L);
return g;
}
Poly Der(){
Poly rs;rs.slice(L-1);
for(int i=1;i<=L-1;++i) rs[i-1]=(i64)v[i]*i%mo;
return rs;
}
Poly Int(){
Poly rs;rs.slice(L+1);rs[0]=0;
for(int i=1;i<=L;++i) rs[i]=(i64)v[i-1]*inv[i]%mo;
return rs;
}
Poly Ln(){
Poly rs=(this->Der())*(this->Inv());
rs=rs.Int();rs.slice(L);
return rs;
}
Poly Exp(){
Poly f,g,g0,gln;
int Lim=getmi(L);
g.slice(1);g[0]=1;
for(int len=2;(len>>1)<Lim;len<<=1){
f=*this;f.slice(len);
g.slice(len);gln=g.Ln();gln.slice(len);
len<<=1;
f.slice(len);gln.slice(len);g.slice(len);g0=g;
initrev(len);
g0.NTT(len,1);f.NTT(len,1);gln.NTT(len,1);
for(int i=0;i<len;++i) g[i]=(i64)g0[i]*add(1,add(mo-gln[i],f[i]))%mo;
g.NTT(len,0);
len>>=1;
g.slice(len);
}
g.slice(L);
return g;
}
}F,ans;

浙公网安备 33010602011771号