MTT - 拆系数 FFT / 任意模数 NTT

更新日志 2025/06/25:开工。

前言

MTT,即为任意模数 NTT,其模数并不具有特殊性质使得 NTT 无效,而 FFT 则面临严重的精度问题。

事实上 MTT 有两种实现方式——拆系数 FFT 与三模数 NTT。

其实三模数 NTT 不难,就是选三个合适的模数跑三遍 NTT,然后用 CRT 求即可。这里详细讲解前者实现,因为常数更小,且无需 CRT 等数论知识(懒)。

首先你应当会 FFT。如果感兴趣的话也可以看看 NTT,但对于本篇文章内容来说非必要。

整合封装模板可能需要搭配缺省源使用,可以去我首页置顶的缺省源归档复制一份。

问题

给定两个多项式与任意一个模数,求两个多项式的积,各系数对给出的模数取模。

实现

我们考虑减小 FFT 的精度误差。

\(M=\lceil\sqrt {mod} \rceil\),并把两个多项式每个系数都拆成 \(kM+b\) 的形式,使 \(a<M\land b<M\)

然后我们暴力展开乘法运算:

\[(k_1M+b_1)(k_2M+b_2)=k_1k_2M^2+(k_1b_2+k_2b_1)M+b_1b_2 \]

这样运算过程中的每一个数都不会超过 \(10^9\),精度问题大大减小。

但直接对上面的式子进行 FFT 的话,我们需要进行 \(4\) 次 DFT 与 \(4\) 次 IDFT,足足 \(8\) 次 FFT,常数未免太大了。

我在 FFT 那篇文章已经介绍过一种缩减常数的方法,这里挂个链接,这个优化在这里就很有用了。

我们将 DFT 和 IDFT 操作依次两两配对,就能把 FFT 次数减为 \(4\) 次。还有更优的做法,鉴于过于复杂且实际用途不大以及我不会这里就不做讲解了。

特别的,进行 IDFT 操作后,无需执行其余运算,结果的实部和虚部分别除 \(lim\) 就是存在两个位置各自的结果了。这是因为这两次 IDFT 操作后的结果必然是实数,实部虚部互不影响。

整合封装模板

namespace MTT{
    int P;
    struct clx{
        flt x,y;
        clx(flt a=0,flt b=0){x=a,y=b;}
        inline clx operator+=(const clx &b){return x+=b.x,y+=b.y,*this;}
        friend inline clx operator+(clx a,clx b){return a+=b;}
        inline clx operator-=(const clx &b){return x-=b.x,y-=b.y,*this;}
        friend inline clx operator-(clx a,clx b){return a-=b;}
        inline clx operator*=(const clx &b){return *this=clx(x*b.x-y*b.y,x*b.y+y*b.x);}
        friend inline clx operator*(clx a,clx b){return a*=b;}
        inline clx operator!(){return clx(x,-y);}
    };
    typedef vec<ll> poly;
    typedef vec<clx> Poly;
    const flt Pi=acos(-1);
    vec<int> Rt;
    inline void fft(int lim,Poly &a,int type){
        repl(i,0,lim)if(i<Rt[i])swap(a[i],a[Rt[i]]);
        for(int mid=1;mid<lim;mid<<=1){
            clx w1(cos(Pi/mid),type*sin(Pi/mid));
            for(int j=0;j<lim;j+=(mid<<1)){
                clx w(1,0);
                repl(k,0,mid){
                    clx x=a[j+k],y=w*a[j+mid+k];
                    a[j+k]=x+y;
                    a[j+mid+k]=x-y;
                    w=w*w1;
                }
            }
        }
    }
    inline void operator*=(poly &x,poly y){
        int n=x.size(),m=y.size(),M=ceil(sqrt(P));
        int lim=1,l=0,len=n+m-1;
        while(lim<len)lim<<=1,l++;
        Rt.resize(lim);
        repl(i,0,lim)Rt[i]=(Rt[i>>1]>>1)|((i&1)<<(l-1));
        Poly p(lim),q(lim),s(lim),t(lim);
        x.resize(lim),y.resize(lim);
        repl(i,0,lim)p[i]=clx(x[i]/M,x[i]%M),q[i]=clx(y[i]/M,y[i]%M);
        fft(lim,p,1);fft(lim,q,1);
        repl(i,0,lim){
            clx ka=(p[i]+!p[i?lim-i:i])*clx(0.5,0);
            clx ba=(p[i]-!p[i?lim-i:i])*clx(0,-0.5);
            clx kb=(q[i]+!q[i?lim-i:i])*clx(0.5,0);
            clx bb=(q[i]-!q[i?lim-i:i])*clx(0,-0.5);
            s[i]=ka*kb+ka*bb*clx(0,1);
            t[i]=ba*kb+ba*bb*clx(0,1);
        }
        fft(lim,s,-1);fft(lim,t,-1);
        repl(i,0,lim){
            ll a=(ll)(s[i].x/lim+0.5)%P;
            ll b=(ll)(s[i].y/lim+0.5)%P;
            ll c=(ll)(t[i].x/lim+0.5)%P;
            ll d=(ll)(t[i].y/lim+0.5)%P;
            x[i]=a*M*M%P+(b+c)*M%P+d,x[i]%=P;
        }
        x.resize(len);
        for(auto &i:x)i=(i+P)%P;
    }
    inline void operator+=(poly &a,poly b){
        if(a.size()<b.size())a.resize(b.size());
        repl(i,0,b.size())a[i]=a[i]+b[i],a[i]%=P;
    }
    inline void operator-=(poly &a,poly b){
        if(a.size()<b.size())a.resize(b.size());
        repl(i,0,b.size())a[i]=a[i]-b[i]+P,a[i]%=P;
    }
    poly operator*(poly a,poly b){return a*=b,a;}
    poly operator+(poly a,poly b){return a+=b,a;}
    poly operator-(poly a,poly b){return a-=b,a;}
}using namespace MTT;
posted @ 2025-06-25 18:39  LastKismet  阅读(71)  评论(0)    收藏  举报