ANJHZ的博客

2021.08.14 fft(快速傅里叶变换)

主要就是把平时用的系数表示法转换成点值表示法。点值表示法的好处就在于两个多项式相乘的时候可以O(n)运算。

取一般的点把系数表示法转换为点值表示法是O(n2)的,我们选取单位根可以避免这个问题。

先把f(x)补成次数为2的次幂-1的多项式,再根据次数奇偶分成两个多项式:

(这里f2(x)写错了,应该是xf2(x))

g都是n/2-1次的多项式,代入的单位根也变成了以n/2为底的,因此问题变成了原来的一半。我们把f(x)的系数重排,使得前半部分为g1的系数,后半部分为g2的系数,递归下去处理即可。时间复杂度为O(nlogn)。

考虑如何把点值表示法转换为系数表示法(IFFT):

因此逆处理的时候带入w(n,-k),最后再把系数全部除以n即可。

于是我们有如下版本的递归fft:

#include <bits/stdc++.h>
using namespace std;
const int N=4e6+11;
const double pi=acos(-1);
int n,m;
struct cp
{
    double x,y;
    cp operator +(cp q){return (cp){x+q.x,y+q.y};}
    cp operator -(cp q){return (cp){x-q.x,y-q.y};}
    cp operator *(cp q){return (cp){x*q.x-y*q.y,x*q.y+y*q.x};}
}a[N],b[N],a1[N],a2[N],c[N],tmp[N]; 
void fft(cp* t,int len,int rev)
{ 
    if(len==1) return;
    int i,tot1=0,tot2=0;
    for(i=0;i<len;i+=2) a1[tot1++]=t[i];
    for(i=1;i<len;i+=2) a2[tot2++]=t[i];
    for(i=0;i<(len>>1);i++) t[i]=a1[i],t[i+(len>>1)]=a2[i];
    fft(t,len>>1,rev);fft(t+(len>>1),len>>1,rev);
    cp wn=(cp){cos(2*pi/len),rev*sin(2*pi/len)},yh,w=(cp){1,0};
    for(i=0;i<(len>>1);i++,w=(w*wn))
    {
        yh=(w*t[i+(len>>1)]);
        tmp[i]=(t[i]+yh);
        tmp[i+(len>>1)]=(t[i]-yh);
    }
    for(i=0;i<len;i++)t[i]=tmp[i];
}
int main()
{
    int i,len;
    scanf("%d%d",&n,&m);
    n++;m++;
    for(i=0;i<n;i++) scanf("%lf",&a[i].x);
    for(i=0;i<m;i++) scanf("%lf",&b[i].x);
    len=1;
    while(len<n+m-1) len<<=1;
    fft(a,len,1);fft(b,len,1);
    for(i=0;i<len;i++) c[i]=(a[i]*b[i]);
    fft(c,len,-1);
    for(i=0;i<n+m-1;i++) printf("%d ",(int)(c[i].x/len+0.5));
    printf("\n");
    return 0;
}

 之前提到多项式的系数会根据奇偶被分到左右两边,我们事实上可以计算多次递归分奇偶后每一项的最终位置。若某一项在二进制下第i位为0,那么在第i次递归它会被放到左边,因此其最终位置在二进制下的第l-i位就是0(n=2^(l+1))。若第i位为1则第i次递归会被放到右边,最终位置的第l-i位就是1。于是有下面的非递归版fft。

#include <bits/stdc++.h>
using namespace std;
const int N=4e6+11;
const double pi=acos(-1);
int n,m,rev[N],l;

struct cp
{
    double x,y;
    cp operator +(cp q){return (cp){x+q.x,y+q.y};}
    cp operator -(cp q){return (cp){x-q.x,y-q.y};}
    cp operator *(cp q){return (cp){x*q.x-y*q.y,x*q.y+y*q.x};}
}a[N],b[N],c[N],tmp[N]; 
void swap(cp &p,cp &q){cp tt=p;p=q;q=tt;}
void fft(cp* t,int len,int f)
{ 
    if(len==1) return;
    int i,mlen,R,j,k;
    for(i=0;i<len;i++) if(rev[i]<i) swap(t[i],t[rev[i]]);
    cp wn,w,x,y;
    for(mlen=1;mlen<len;mlen<<=1)
    {
        wn=(cp){cos(pi/mlen),f*sin(pi/mlen)};//2被约分了 
        for(j=0,R=(mlen<<1);j<len;j+=R)
        {
            w=(cp){1,0};
            for(k=0;k<mlen;k++,w=(w*wn))
            {
                x=t[j+k];y=(w*t[j+k+mlen]);
                t[j+k]=(x+y);
                t[j+k+mlen]=(x-y);
            }
        } 
    }
    if(f==-1) for(i=0;i<len;i++) t[i].x/=len;
}
int main()
{
    int i,len;
    scanf("%d%d",&n,&m);
    n++;m++;
    for(i=0;i<n;i++) scanf("%lf",&a[i].x);
    for(i=0;i<m;i++) scanf("%lf",&b[i].x);
    len=1;
    while(len<n+m-1) len<<=1,l++;
    for(i=1;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<l-1);//l-1是因为l位二进制数的最高位是2^(l-1) 
    fft(a,len,1);fft(b,len,1);
    for(i=0;i<len;i++) c[i]=(a[i]*b[i]);
    fft(c,len,-1);
    for(i=0;i<n+m-1;i++) printf("%d ",(int)(c[i].x+0.5));
    printf("\n");
    return 0;
}

 

posted on 2021-08-15 15:26  ANJHZ  阅读(152)  评论(0)    收藏  举报

导航