FFT学习笔记

很早就想学FFT那套理论,但抱着能咕一天是一天的态度咕到了今天

 

fft是干什么的?

求两个多项式卷积的,比如$g=a*b$($g_x=\sum{a_i*b_{x-i}}$)

显然暴力乘是$O(n^2)$的,然而我们可以把他优化到$O(n\;log\;n)$

一般来将,多项式是用每一项的系数表示的,而还可以用点值来表示,比如一个多项式$a$有n项,我们可以让变量x取n个不同的值,然后用n个得出来的值表示这个多项式

由于是点值,两个多项式相乘时只要把对应的点值相乘即可,这是$O(n)$的

显然我们容易可以把两个表示法互相转换,比如$n^2$暴力和$n^3$高斯消元

这似乎比暴力还要不优越,所以我们要优化转化表示法的复杂度

我们发现我们可以随便选数x,只要x各不相同就行,然而当x在大部分取值时,都要暴力算$x^i$的值,这很不优越

所以我们的数学知识告诉我们数不只有实数,还有虚数啊懒得介绍虚数,我去拖一点东西过来

虚数大概就是可以表示在一个复平面上的东西,k次单位根$w_k$就是其k次方是1的东西$w_k$的i次方都在一个以原点为圆心,1为半径的圆上

由于复数运算的种种性质(幅角相加,长度相乘),这些东西是绕原点顺时针的,然后我们把i次单位根代入多项式来求值,这就是可以优化的了

 然后我们就可以用分治来优化啦

把多项式分成奇数项和偶数项两部分然后分治,就像这样

愉快地盗了张图来 原文

由于要按照奇偶性分治,分治后的顺序会和原来不同,为了让实现更方便,可以把原下标的二进制翻转后当新下标(不会证)

 

那么我们就可以把系数表示变成点值表示了,这就是DFT

然后怎么把点值还原成系数呢,据说只要把点值除第一项的、部分翻转,然后跑DFT,再把结果除以n就好了(还是不会证)

 

大概就这样吧我感觉讲的海星

例题有这个(模板)多项式乘法(FFT)

AC代码

#include<bits/stdc++.h>
using namespace std;
#define int long long
int n,m,i,j,len;
int w[5000000],f[5000000],g[5000000];
int rev[5000000],bit,ha=998244353;
inline void getrev(int n){for(i=0;i<(1<<n);i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));}
int add(int x){return (x>=ha)?x-ha:x;}
int jian(int x){return (x<0)?x+ha:x;}
inline int qpow(int a,int b)
{
    int ans=1;
    while(b){if(b%2)ans=ans*a%ha;a=a*a%ha,b/=2;};
    return ans;
}
void fft(int *a,int n,int x)
{
    if(x)reverse(a+1,a+n);
    for(int i=0;i<n;i++)if(rev[i]>i)swap(a[i],a[rev[i]]);
    for(int i=1;i<n;i<<=1)
    {
        w[0]=1; w[1]=qpow(3,(ha-1)/(i<<1));
        for(int j=2;j<i;j++)w[j]=w[j-1]*w[1]%ha;
        for(int j=0;j<n;j+=(i<<1))
            for(int k=j;k<j+i;k++)
            {
                int x=a[k],y=a[k+i]*w[k-j]%ha;
                a[k]=add(x+y),a[k+i]=jian(x-y);
            }
    }
    int ni=qpow(n,ha-2);
    if(x)for(int i=0;i<n;i++)a[i]=a[i]*ni%ha;
}
signed main()
{
    scanf("%lld%lld",&n,&m);
    for(i=0;i<=n;i++)scanf("%lld",&f[i]);
    for(i=0;i<=m;i++)scanf("%lld",&g[i]);
    len=1;
    for(;len<m+n+1;len=len<<1)bit++;getrev(bit);
    fft(f,len,0);fft(g,len,0);
    for(i=0;i<len;i++)f[i]=f[i]*g[i]%ha;
    fft(f,len,1);
    for(i=0;i<=n+m;i++)printf("%lld ",f[i]);
    return 0;
}
View Code

 

fft还有模意义下的版本NTT

只要把单位根变成模数的原根就行了,一般都是3,如998244353(1e9+7不是NTT模数)

如果要求没有原根的多项式乘法,可以用CRT把有原根的答案结合起来

再来道例题

AC代码

#include<bits/stdc++.h>
using namespace std;
#define int long long
int n,m,i,j,len;
int w[2100000],f[2100000],g[2100000];
int rev[2100000],bit,ha=998244353;
inline void getrev(int n){for(i=0;i<(1<<n);i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));}
int add(int x){return (x>=ha)?x-ha:x;}
int jian(int x){return (x<0)?x+ha:x;}
inline int qpow(int a,int b)
{
    int ans=1;
    while(b){if(b%2)ans=ans*a%ha;a=a*a%ha,b/=2;};
    return ans;
}
void fft(int *a,int n,int x)
{
    if(x)reverse(a+1,a+n);
    for(int i=0;i<n;i++)if(rev[i]>i)swap(a[i],a[rev[i]]);
    for(int i=1;i<n;i<<=1)
    {
        w[0]=1; w[1]=qpow(3,(ha-1)/(i<<1));
        for(int j=2;j<i;j++)w[j]=w[j-1]*w[1]%ha;
        for(int j=0;j<n;j+=(i<<1))
            for(int k=j;k<j+i;k++)
            {
                int x=a[k],y=a[k+i]*w[k-j]%ha;
                a[k]=add(x+y),a[k+i]=jian(x-y);
            }
    }
    int ni=qpow(n,ha-2);
    if(x)for(int i=0;i<n;i++)a[i]=a[i]*ni%ha;
}
signed main()
{
    scanf("%lld%lld",&n,&m);
    for(i=0;i<=n;i++)scanf("%lld",&f[i]);
    for(i=0;i<=m;i++)scanf("%lld",&g[i]);
    len=1;
    for(;len<m+n+1;len=len<<1)bit++;getrev(bit);
    fft(f,len,0);fft(g,len,0);
    for(i=0;i<len;i++)f[i]=f[i]*g[i]%ha;
    fft(f,len,1);
    for(i=0;i<=n+m;i++)printf("%lld ",f[i]);
    return 0;
}
View Code

 

posted @ 2018-08-06 15:14  橙子用户  阅读(274)  评论(0编辑  收藏  举报