芝士:FFT & NTT
FFT
背景
对于两个多项式乘法\(O(n^2)\)的时间复杂度难以令人满意
所以将其优化,得到了FFT算法,其时间复杂度为优秀的\(O(n*log_n)\)
前置芝士
系数表示法
对于一个\(n+1\)项的\(n\)次多 ♂ 项式\(f(x)=\sum_{i=0}^{n} a_i*x^i\)
也可如此表示\(f(x)=\{a_0,a_1,a_2,·····,a_{n}\}\)
点值表示法
对于一个函数,很明显我们可以将其与平面直角坐标系相联系
我们回忆一下
对于一次函数,有了两个点就能确定唯一一个的一次函数
对于二次函数,有了三个点就能确定唯一一条二次函数
当然,也可以是更低次函数上的点
那么对于一个\(n\)次函数,有了\(n+1\)个点,就能确定一条唯一的\(n\)次函数
所以\(f(x)=\{\{x_1,y_1\},\{x_2,y_2\},·····,\{x_{n+1},y_{n+1}\}\}\)
复数
首先我们知道\(i=\sqrt{-1}\)
形如\(a+b*i\)的形式成为虚数
其中\(a\)被称为实部,\(b*i\)成为虚部
那么运算呢?
设\(x=a_1+b_1*i,y=a_2+b_2*i\)
\(\begin{cases}x+y=(a_1+a_2)+(b_1+b_2)*i\\x*y=(a_1*a_2-b_1*b_2)+(a_1*b_2+a_2*b_1)*i\end{cases}\)
欧拉公式
\(e^{ix}=cos x+isinx\)
证明:咕咕咕
单位复数根
实际上因为是\(\omega\),因为写着方便就直接\(w\)了
\(w^n=1\),运用欧拉公式,则有
\(w^n=1=1+i*0=cos(2k\pi)+isin(2k\pi)=e^{i2k\pi},k\in Z\)
\(w=e^\frac{i2k\pi}{n}\)
如果需要得到多项,只需要改变\(k\)即可
\(w_k=e^{\frac{i2k\pi}{n}},k\in[0,n-1]\)
接下来只需要考虑这\(n\)个\(w\)不相同即可
\(w_k=e^{\frac{i2k\pi}{n}}=cos(\frac{2k\pi}{n})+isin(\frac{2k\pi}{n})\)
在复数坐标系下,实际上这个东西就是将单位圆均分成\(n\)分,故一定不相同
对\(w\)有一些性质
\(w_{2n}^{2k}=w_n^k\),考虑\(w\)的定义就可以证明
\(w_n^k=-w_n^{k+\frac{n}{2}}\),可以通过定义,暴力变形\(cosx+isinx\)就可以证明
进入正题
首先我们要明白我们要的是什么,
\(h(x)=f(x)*g(x)\),
要求的即为\(h(x)\),
设\(f\)有\(n\)项,\(g\)有\(m\)项,那么显然\(h\)有\(n+m-1\)项
也就是指,如果我们知道\(h\)函数计算出来的\(n+m\)个坐标,那么就能确定唯一的\(h\)值
系数\(\rightarrow\)点值
考虑一些奇奇怪怪的分解
\(f(x)=\sum_{i=0}^{n-1}a_ix^i\)
设\(f_1(x)=\sum_{i=0}^{\frac{n-1}{2}}a_{2i}x_i\),\(f_2(x)=\sum_{i=0}^{\frac{n-2}{2}}a_{2i+1}x^i\)
通俗一点,就是将偶数项的系数和奇数项的系数分别构成了两个函数
那么就有\(f(x)=f_1(x^2)+xf_2(x^2)\)
然后我们对\(f_1(x)\)和\(f_2(x)\)继续进行这样的分解,直到只有常数项
这个就很像归并排序的过程,所以算\(f(x)\)的时间复杂度为\(O(nlog_n)\)!!!,
成功地将\(O(n)\)的算法优化到了\(O(nlog_n)\)
如果我们将\(x_i=w_i\),
一般的,有\(f(w_n^k)=f_1(w_n^{2k})+w_n^kf_2(w_{n}^{2k})\)
然后,这不就可以从\(f(w_n^k)\rightarrow f(w_n^{k+\frac{n}{2}})\)
也就是说针对每一层可以直接\(O(n)\)算出所有的$f(w_n^k) $
时间复杂度还是\(O(nlog_n)\)
点值\(\rightarrow\)系数
根据一大堆推导
我们可以得到结论\(a_k=\frac{1}{n}\sum_{i=0}^{n-1}b_iw^{ki}\)
其中\(b_i\)为点值表达式的值
也就是说定义函数\(A(x)=\frac{1}{n}\sum_{i=0}^{n-1}b_ix^i\),
需要求出\(x=w^k,k\in[0,n-1]\)处的值
是不是感觉这个问题就可以用已经推导出来的\(FFT\)求解?
所以时间复杂度还是\(O(nlog_n)\)
常数优化(真的是常数)
非递归
真的如标题所示,将递归的写成非递归,就是将数组进行划分就行了
蝴蝶优化
将\(\omega\)写成一个常量,同时新定义一个变量作为中介变量就行了
板子
#include<iostream>
#include<algorithm>
#include<cmath>
using namespace std;
#define maxn 10000005
#define x first
#define y second
const double pi=acos(-1.0);
int n,m;
int limit=1;
pair<double,double> a[maxn];
pair<double,double> b[maxn];
int l;
int r[maxn];
pair<double,double> operator + (pair<double,double> a,pair<double,double> b)
{
return make_pair(a.x+b.x,a.y+b.y);
}
pair<double,double> operator - (pair<double,double> a,pair<double,double> b)
{
return make_pair(a.x-b.x,a.y-b.y);
}
pair<double,double> operator * (pair<double,double> a,pair<double,double> b)
{
return make_pair(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
#undef x
#undef y
void fft(pair<double,double> *a,int t)
{
for(int i=0;i<limit;i++)
if(i<r[i])
swap(a[i],a[r[i]]);
for(int mid=1;mid<limit;mid<<=1)
{
pair<double,double> wn=make_pair(cos(pi/mid),t*sin(pi/mid));
for(int r=mid<<1,j=0;j<limit;j+=r)
{
pair<double,double> w=make_pair(1,0);
for(int k=0;k<mid;k++,w=w*wn)
{
pair<double,double> x=a[j+k],y=w*a[j+mid+k];
a[j+k]=x+y;
a[j+mid+k]=x-y;
}
}
}
}
void work()
{
fft(a,1);
fft(b,1);
for(int i=0;i<=limit;i++)
a[i]=a[i]*b[i];
fft(a,-1);
for(int i=0;i<=n+m;i++)
cout<<(int)(a[i].first/limit+0.5)<<' ';
}
void prepare()
{
while(limit<=n+m)
{
limit<<=1;
l++;
}
for(int i=0;i<limit;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
}
int main()
{
ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=0;i<=n;i++)
cin>>a[i].first;
for(int i=0;i<=m;i++)
cin>>b[i].first;
prepare();
work();
return 0;
}
NTT
背景
因为FFT的精度损失会影响答案,
导致许多Oier心态@#!@#^&$
所以NTT横空出世
前置芝士
阶
若\(gcd(i,p)=1并且p>1\)
那么对于\(a^r\equiv 1\pmod p\)的最小的\(r\),将\(r\)称作模\(p\)的阶,记作\(\delta_p(a)\)
原根
对于正整数\(p\),若有整数\(a\),满足\(\delta_p(a)=\varphi(p)\)。我们称\(a\)为\(p\)的原根
若\(p\)有原根,那么\(\varphi(\varphi(p))\)个原根
模仿
我们思考FFT的精度损失为什么会怎么大
其实就是因为\(\omega\)无法准确的表示,
所以我们考虑用原根代替\(\omega\),
同时原根满足\(\omega\)的性质,因为FFT的优化全是因为有\(\omega\)这个奇妙的东西
于是我们这样定义原根
别问我为什么这样定义
\(\omega_n=(g^{\frac{p-1}{n}})\%p\)
其中g是质数p的原根,
其他的就与FFT一样的,
那么问题来了,\(\omega_n\)是否能像\(FFT\)的\(\omega\)一样满足那些式子?
\(w_{n}^k\equiv w_{2n}^{2k}\pmod p\),这个应该很显然吧。。。。
\(w_n^k\equiv -w_n^{k+\frac{n}{2}} \pmod p \Leftrightarrow1+g^{\frac{p-1}{2}}\equiv 0\pmod p\)
然后有\(g^{\frac{p-1}{2}}\equiv g^{\frac{p-1}{2}}(g^{\frac{p-1}{2}})^{p-1} \equiv g^{\frac{p(p-1)}{2}}\pmod p\),费马小定理
之后有\(g^{\frac{p(p-1)}{2}}\equiv \prod_{i=0}^{p-1} g^i\equiv -1\pmod p\),威尔逊定理
之后就有\(1+g^{\frac{p-1}{2}}\equiv 0\pmod p\),故性质2成立
p
p一般取998244353,1004535809,469762049,这三个数的原根都是3
代码
#include<iostream>
using namespace std;
const int g=3,gi=332748118;
const int mod=998244353;
int n,m;
long long F[4000005],G[4000005];
int limit=1;
int l,r[4000005];
long long qkpow(int a,int b)
{
if(b==0)
return 1;
if(b==1)
return a;
long long t=qkpow(a,b/2);
t=t*t%mod;
if(b&1)
t=t*a%mod;
return t;
}
void prepa(int n)
{
while(limit<=n)
{
limit*=2;
l++;
}
for(int i=0;i<=limit;i++)
r[i]=((r[i>>1]>>1)|((i&1)<<(l-1)));
}
void ntt(long long *a,int ty)
{
for(int i=0;i<limit;i++)
if(i<r[i])
swap(a[i],a[r[i]]);
for(int mid=1;mid<limit;mid<<=1)
{
long long wn=qkpow(ty?g:gi,(mod-1)/(mid<<1));
//cout<<wn<<' '<<'\n';
for(int r=mid<<1,j=0;j<limit;j+=r)
{
long long w=1;
for(int k=0;k<mid;k++,w=w*wn%mod)
{
long long x=a[j+k],y=w*a[j+k+mid]%mod;
a[j+k]=(x+y)%mod;
a[j+k+mid]=((x-y)%mod+mod)%mod;
}
}
}
}
int main()
{
ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=0;i<=n;i++)
cin>>F[i];
for(int i=0;i<=m;i++)
cin>>G[i];
prepa(n+m);
ntt(F,1);ntt(G,1);
for(int i=0;i<=limit;i++)
F[i]=F[i]*G[i]%mod;
ntt(F,0);
int inv=qkpow(limit,mod-2);
for(int i=0;i<=n+m;i++)
cout<<F[i]*inv%mod<<' ';
return 0;
}
限制
NTT的限制也是很明显的
1.系数必须为整数
2.模数必须为质数
3.令\(p=2^m*k+1\),k为奇数,则多项式的长度必须\(n\le2^m\)
对于模数不为质数的情况
n次的多项式在模m下乘积,最终的系数不会超过\(n*m^2\)
所以我们做三次NTT就行了
为什么不多做几次呢?
废话,多做几次你的常数就上去了,并且你还要套CRT
所以我们找三个模数就行了(一般情况。。。)
但是可能会爆long long
所以我们需要使用一些技巧
\(\begin{cases}x\equiv a_1\pmod {m_1}\\x\equiv a_2\pmod {m_2}\\x\equiv a_3 \pmod{m_3}\end{cases}\)
我们可以在long long的范围内合并前两个
\(\begin{cases}x\equiv A\pmod{M}\\x\equiv a_3\pmod{m_3}\end{cases}\)
所以最后的答案为
\(ans=kM+A\)
且k需要满足
\(kM+A\equiv a_3\pmod{m_3}\)
因为k是在模\(m_3\)的意义下求出的,所以k必然满足
\(k\equiv(a_3-A)*M^{-1}\pmod{m_3}\)
求出k之后就可以求出ans了
代码
咕掉了