分治FFT学习笔记

引入:

我们先来看一道例题:

给定序列\(\{g_1,\dots g_{n-1}\}\),已知\(f_0=1\)\(f_n=\sum_{i=1}^{n} f_{n-i}\times g_i\),求序列\(\{f_i\}\),对\(998244353\)取模

考虑最朴素的做法,\(O(n^2)\),在\(n\)比较小的情况下是可以的

考虑\(f_i\)的计算显然可以化成卷积的形式,但\(f_i\)的计算依赖于之前的\(f\),直接\(NTT\),退化成\(O(n^2 \log n)\)

那么在\(n\)比较大的情况下,我们要怎么来做这个东西呢?

正题:

在引入,我们得到一个\(O(n^2 \log n)\)的做法,接下来,考虑如何优化这个做法

我们考虑分治,现在我们要计算\([l,r]\)\(f\),把它分成\([l,mid]\)\([mid+1,r]\),现在假设我们已经算出了\([l,mid]\)

考虑计算\([l,mid]\)\(f\)\([mid+1,r]\)\(f\)的贡献,对于一个\(mid< x \le r\),设\(w[x]\)\([l,mid]\)\(x\)的贡献

\[w[x]=\sum_{i=l}^{mid} f_i \times g_{mid-i}\\ w[x]=\sum_{i=l}^{x} f_i \times g_{x-i}\\ \]

我们可以直接补到\(x\),因为大于\(mid\)的部分\(f\)\(0\),可以发现,\(w[x]\)的计算显然可以写成卷积的形式

在这里,我们令\(a[i]=f_{i+l}\),令\(b[i]=g_{i+1}\),那么,\(w[x]\)可以写成这样

\[w[x]=\sum_{i=0}^{x-l-1} a[i]\times b[x-l-1-i] \]

则我们可以一次\(NTT\)直接算出这一部分的贡献,然后继续分治即可,时间复杂度\(O(n \log ^2 n)\)

Code:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=998244353;
const int N=2e5+11;
int n,g[N],f[N],A[N],B[N],p[N];
int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-f;ch=getchar();}
    while(isdigit(ch)){x=x*10+ch-48;ch=getchar();}
    return x*f;
}
int qpow(int x,int y){
    int re=1;
    while(y>0){
        if(y&1) re=re*x%mod;
        y>>=1;x=x*x%mod;
    }return re;
}
void NTT(int *a,int flag,int len){
    for(int i=0;i<len;i++)
        if(i<p[i]) swap(a[i],a[p[i]]);
    for(int l=2;l<=len;l<<=1){
        int wn=qpow(3,(mod-1)/l);
        if(flag==-1) wn=qpow(wn,mod-2);
        for(int st=0;st<len;st+=l){
            int w=1;
            for(int u=st;u<st+(l>>1);u++,w=w*wn%mod){
                int x=a[u],y=w*a[u+(l>>1)]%mod;
                a[u]=(x+y)%mod;a[u+(l>>1)]=(x+mod-y)%mod;
            }
        }
    }
}
void Transform(int len){
    NTT(A,1,len);NTT(B,1,len);
    for(int i=0;i<=len;i++) A[i]=A[i]*B[i]%mod;
    NTT(A,-1,len);int inv=qpow(len,mod-2);
    for(int i=0;i<len;i++) A[i]=A[i]*inv%mod;
}
void DivideT(int l,int r){
    if(l==r) return ;
    int mid=l+r>>1;
    DivideT(l,mid);
    int len=1,tim=0,sz=r-l-1;
    while(len<=sz) len<<=1,++tim;
    for(int i=0;i<len;i++)
        p[i]=(p[i>>1]>>1)|((i&1)<<(tim-1));
    for(int i=0;i<len;i++) A[i]=B[i]=0;
    for(int i=0;i<=mid-l;i++) A[i]=f[i+l];
    for(int i=0;i<=r-l-1;i++) B[i]=g[i+1];
    Transform(len);
    for(int i=mid+1;i<=r;i++) f[i]=(f[i]+A[i-l-1])%mod;
    DivideT(mid+1,r);
}
signed main(){
    n=read();f[0]=1;
    for(int i=1;i<n;i++) g[i]=read();
    DivideT(0,n-1);
    for(int i=0;i<n;i++) printf("%lld ",f[i]);
    return 0;
}

posted @ 2020-02-11 23:18  DQY_dqy  阅读(117)  评论(2编辑  收藏  举报