如何速通卷积?

可能更好的阅读体验

有的题里面卷积是必要的,不会卷积就可能被暴打。

本文旨在帮助和我一样没怎么学多项式的人速通卷积。

其中可能有一些定义和结论,你不需要关心其证明也可以学会卷积,因此本文中不会证明结论。

点值表示法

通过系数表示法给出两个多项式(即给出各项系数) \(f(x)=a_0x^0+\dots+a_nx^n,g(x)=b_0x^0+\dots+b_mx^m\),求 \(h(x)=f(x)g(x)=c_0x^0+\dots+c_{n+m}x^{n+m}\) 即其乘积的各项系数。

结论 1:根据 \(n\) 次多项式 \(f(x)\)\(n+1\) 个不同 \(x\) 处的取值 \((x_1,y_1),(x_2,y_2),\dots,(x_{n+1},y_{n+1})\) 可以唯一确定 \(f(x)\)

定义 1:根据结论 1 可以用 \(n+1\) 个不同 \(x\) 处的取值表示一个 \(n\) 次多项式,将这种表示方法称为点值表示法。

因此可以先求出 \(f(x),g(x)\)\(n+m+1\) 个不同 \(x\) 处的取值,然后相乘即可得到 \(h(x)\)\(n+m+1\) 个不同 \(x\) 处的取值,再根据这些值求出 \(h(x)\) 的各项系数。

于是现在问题变为了在系数表示法和点值表示法之间快速转化。

系数表示法 -> 点值表示法

直接暴力算即可做到 \(O((n+m)^2)\),但是显然不够快。

\(f(x)=f_0(x^2)+xf_1(x^2)\),即将其偶数次系数和奇数次系数分别拿出来组成新的多项式 \(f_0(x),f_1(x)\)

那么只要快速合并即可分治,为了分治可以将项数补到最小且 \(>n+m\)\(2\) 的整数次幂 \(2^p\),但是合并好像很难。

单位根

不过注意到选的数是没有任何限制的,所以不妨找一些有特殊性质的数使其能够快速合并。

定义 2:令平面直角坐标系上的点 \((x,y)\) 表示 \(x+iy\),其中 \(i\) 是虚数单位满足 \(i^2=-1\),将这个平面直角坐标系称为复平面。

复数运算:

typedef double db;
struct cpx{
	db x,y;
};
cpx operator + (const cpx &a,const cpx &b){
	return {a.x+b.x,a.y+b.y};
}
cpx operator - (const cpx &a,const cpx &b){
	return {a.x-b.x,a.y-b.y};
}
cpx operator * (const cpx &a,const cpx &b){
	return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
}
cpx operator / (const cpx &a,const cpx &b){
	return {(a.x*b.x+a.y*b.y)/(b.x*b.x+b.y*b.y),(a.y*b.x-a.x*b.y)/(b.x*b.x+b.y*b.y)};
}

定义 3:将平面直角坐标系上以原点为圆心单位长度为半径的圆称为单位圆。

定义 4:将复平面上的单位圆平均分为 \(n(n\ge2)\) 段且 \((1,0)\) 为其中一个分段点,将从 \((1,0)\) 开始逆时针走到的第 \(2\) 个分段点表示的数称为 \(\omega_n\)。根据三角函数基础知识,可知 \(\omega_n=\cos\frac{2\pi}{2^p}+i\sin\frac{2\pi}{2^p}\)

结论 2:\(\omega_n^k\) 对应从 \((1,0)\) 开始逆时针走到的第 \(k+1\) 个分段点。

结论 3:当 \(2\mid n\) 时,\(-\omega_n^{k+\frac{n}{2}}=\omega_n^k\)

快速合并

不难发现 \(\omega_n^k\) 有一些良好性质,因此考虑令 \(x_i=\omega_{2^p}^{i-1}\)

于是可以注意到当 \(j>2^{p-1}\) 时,\(f(x_j)=f_0(x_j^2)+x_jf_1(x_j^2)=f_0(x_{j-2^{p-1}}^2)-x_{j-2^{p-1}}f_1(x_{j-2^{p-1}}^2)\)

因此只需求出 \(f_0(x_1),\dots,f_0(x_{2^{p-1}}),f_1(x_1),\dots,f_1(x_{2^{p-1}})\) 即可,直接分治即可,时间复杂度 \(O(2^pp)=O((n+m)\log(n+m))\)

卡常

首先要把递归写成循环形式。

考虑将往下分的过程优化。(此过程中需要将偶数次系数和奇数次系数分到两边)

定义 5:将 \(i\) 在这个过程结束后移到的位置称为 \(to_i\)

结论 4:\(to_i\) 即为 \(i\) 的二进制表示将前 \(p\) 位 reverse 得到的数。

因此有递推式:\(to_i=\lfloor\frac{to_{\lfloor\frac{i}{2}\rfloor}}{2}\rfloor+[2\nmid i]2^{p-1}\)

于是可以将该过程优化到线性。

const db PI=acos(-1.0);
int to[N];
void fft(int len,cpx *a){
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
	for(int k=2;k<=len;k<<=1){
		cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
		for(int i=0;i<len;i+=k){
			cpx x={1,0};
			for(int j=0;j<(k>>1);j++){
				cpx p=a[i+j],q=a[i+j+(k>>1)]*x;
				a[i+j]=p+q,a[i+j+(k>>1)]=p-q;
				x=x*w;
			}
		}
	}
}

点值表示法 -> 系数表示法

直接根据上面代码倒推即可。

void ifft(int len,cpx *a){
	for(int k=len;k>=2;k>>=1){
		cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
		for(int i=0;i<len;i+=k){
			cpx x={1,0};
			for(int j=0;j<(k>>1);j++){
				cpx p=a[i+j],q=a[i+j+(k>>1)];
				a[i+j]=(p+q)/(cpx){2,0},a[i+j+(k>>1)]=(p-q)/(cpx){2,0}/x;
				x=x*w;
			}
		}
	}
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}

于是我们已经可以写出卷积代码了:

void convolution_fft(int n,ll *A,int m,ll *B,ll *C){
	int len=1;
	while(len<=(n+m))len<<=1;
	rep(i,0,len-1)a[i]={(db)A[i],0.0};
	rep(i,0,len-1)b[i]={(db)B[i],0.0};
	fft(len,a);
	fft(len,b);
	rep(i,0,len-1)c[i]=a[i]*b[i];
	ifft(len,c);
	rep(i,0,len-1)C[i]=(ll)round(c[i].x);
}

三次变两次优化

原理:\((a+bi)^2=(a^2-b^2)+(2ab)i\)

于是可以将 \(f(x),g(x)\) 的系数分别放在实部和虚部,求平方后虚部除以 \(2\) 便是 \(h(x)\)

cpx a[N];
void convolution_fft(int n,ll *A,int m,ll *B,ll *C){
	int len=1;
	while(len<=(n+m))len<<=1;
	rep(i,0,len-1)a[i]={(db)A[i],(db)B[i]};
	fft(len,a);
	rep(i,0,len-1)a[i]=a[i]*a[i];
	ifft(len,a);
	rep(i,0,len-1)C[i]=(ll)round(a[i].y/2.0);
}

考虑模意义

显然三角函数与浮点数运算会产生精度误差,同时大多数情况下都是在特定模意义下使用卷积,因此考虑使用整数代替这些浮点数运算,只需要在特定模意义中找到和单位根有类似性质的数即可。

可以将 \(mod\) 分解,使用 CRT 合并即可。

一般 \(p-1\)\(2\) 的较高整数次幂因子时可以使用。

原根

定义 6:对于奇质数 \(p\),将满足 \(g^1,\dots,g^{p-1}\) 互不相同的 \(g\) 称为其原根。

结论 5:若 \(n\) 存在原根,则其最小原根是 \(O(n^\frac{1}{4})\) 的。

结论 6:若 \(x\) 不为原根,则 \(\exists y,x^{\frac{p-1}{y}}\equiv 1 \pmod p\)

于是可以暴力枚举找最小原根。

\(998244353\) 的原根是 \(3\)

代替单位根

结论 7:\(g^{\frac{p-1}{2}}\equiv p-1\pmod p\)

因此考虑令 \(x_i=(g^{\frac{mod-1}{2^p}})^{i-1}\)

于是可以注意到当 \(j>2^{p-1}\) 时,\(f(x_j)\equiv f_0(x_j^2)+x_jf_1(x_j^2)\equiv f_0(x_{j-2^{p-1}}^2)-x_{j-2^{p-1}}f_1(x_{j-2^{p-1}}^2)\)

因此只需求出 \(f_0(x_1),\dots,f_0(x_{2^{p-1}}),f_1(x_1),\dots,f_1(x_{2^{p-1}})\) 即可,直接分治即可,时间复杂度 \(O(2^pp)=O((n+m)\log(n+m))\)

const ll mod=998244353;
const ll I2=(mod+1)/2;
const ll G=3;
ll ksm(ll a,ll b,ll p){
	a=a%p;
	ll r=1;
	while(b){
		if(b&1)r=r*a%p;
		a=a*a%p;
		b>>=1;
	}
	return r%p;
}
const ll IG=ksm(G,mod-2,mod);
void ntt(int len,ll *a){
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
	for(int k=2;k<=len;k<<=1){
		ll w=ksm(G,(mod-1)/k,mod);
		for(int i=0;i<len;i+=k){
			ll x=1;
			for(int j=0;j<(k>>1);j++){
				ll p=a[i+j],q=a[i+j+(k>>1)]*x%mod;
				a[i+j]=(p+q)%mod,a[i+j+(k>>1)]=(p-q+mod)%mod;
				x=x*w%mod;
			}
		}
	}
}
void intt(int len,ll *a){
	for(int k=len;k>=2;k>>=1){
		ll w=ksm(IG,(mod-1)/k,mod);
		for(int i=0;i<len;i+=k){
			ll x=1;
			for(int j=0;j<(k>>1);j++){
				ll p=a[i+j],q=a[i+j+(k>>1)];
				a[i+j]=(p+q)*I2%mod,a[i+j+(k>>1)]=(p-q+mod)*I2%mod*x%mod;
				x=x*w%mod;
			}
		}
	}
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}
ll ntt_a[N],ntt_b[N],ntt_c[N];
void convolution_ntt(int n,ll *A,int m,ll *B,ll *C){
	int len=1;
	while(len<=(n+m))len<<=1;
	rep(i,0,len-1)ntt_a[i]=A[i];
	rep(i,0,len-1)ntt_b[i]=B[i];
	ntt(len,ntt_a);
	ntt(len,ntt_b);
	rep(i,0,len-1)ntt_c[i]=ntt_a[i]*ntt_b[i]%mod;
	intt(len,ntt_c);
	rep(i,0,len-1)C[i]=ntt_c[i];
}

模板题代码

题目链接

#include<bits/stdc++.h>
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define per(i,r,l) for(int i=(r);i>=(l);i--)
#define repll(i,l,r) for(ll i=(l);i<=(r);i++)
#define perll(i,r,l) for(ll i=(r);i>=(l);i--)
#define pb push_back
#define ins insert
#define clr clear
using namespace std;
namespace ax_by_c{
typedef long long ll;
const int N=4e6+5;
namespace Bpoly{
typedef double db;
struct cpx{
	db x,y;
};
cpx operator + (const cpx &a,const cpx &b){
	return {a.x+b.x,a.y+b.y};
}
cpx operator - (const cpx &a,const cpx &b){
	return {a.x-b.x,a.y-b.y};
}
cpx operator * (const cpx &a,const cpx &b){
	return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
}
cpx operator / (const cpx &a,const cpx &b){
	return {(a.x*b.x+a.y*b.y)/(b.x*b.x+b.y*b.y),(a.y*b.x-a.x*b.y)/(b.x*b.x+b.y*b.y)};
}
const db PI=acos(-1.0);
int to[N];
void fft(int len,cpx *a){
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
	for(int k=2;k<=len;k<<=1){
		cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
		for(int i=0;i<len;i+=k){
			cpx x={1,0};
			for(int j=0;j<(k>>1);j++){
				cpx p=a[i+j],q=a[i+j+(k>>1)]*x;
				a[i+j]=p+q,a[i+j+(k>>1)]=p-q;
				x=x*w;
			}
		}
	}
}
void ifft(int len,cpx *a){
	for(int k=len;k>=2;k>>=1){
		cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
		for(int i=0;i<len;i+=k){
			cpx x={1,0};
			for(int j=0;j<(k>>1);j++){
				cpx p=a[i+j],q=a[i+j+(k>>1)];
				a[i+j]=(p+q)/(cpx){2,0},a[i+j+(k>>1)]=(p-q)/(cpx){2,0}/x;
				x=x*w;
			}
		}
	}
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}
cpx fft_a[N];
void convolution_fft(int n,ll *A,int m,ll *B,ll *C){
	int len=1;
	while(len<=(n+m))len<<=1;
	rep(i,0,len-1)fft_a[i]={(db)A[i],(db)B[i]};
	fft(len,fft_a);
	rep(i,0,len-1)fft_a[i]=fft_a[i]*fft_a[i];
	ifft(len,fft_a);
	rep(i,0,len-1)C[i]=(ll)round(fft_a[i].y/2.0);
}
const ll mod=998244353;
const ll I2=(mod+1)/2;
const ll G=3;
ll ksm(ll a,ll b,ll p){
	a=a%p;
	ll r=1;
	while(b){
		if(b&1)r=r*a%p;
		a=a*a%p;
		b>>=1;
	}
	return r%p;
}
const ll IG=ksm(G,mod-2,mod);
void ntt(int len,ll *a){
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
	for(int k=2;k<=len;k<<=1){
		ll w=ksm(G,(mod-1)/k,mod);
		for(int i=0;i<len;i+=k){
			ll x=1;
			for(int j=0;j<(k>>1);j++){
				ll p=a[i+j],q=a[i+j+(k>>1)]*x%mod;
				a[i+j]=(p+q)%mod,a[i+j+(k>>1)]=(p-q+mod)%mod;
				x=x*w%mod;
			}
		}
	}
}
void intt(int len,ll *a){
	for(int k=len;k>=2;k>>=1){
		ll w=ksm(IG,(mod-1)/k,mod);
		for(int i=0;i<len;i+=k){
			ll x=1;
			for(int j=0;j<(k>>1);j++){
				ll p=a[i+j],q=a[i+j+(k>>1)];
				a[i+j]=(p+q)*I2%mod,a[i+j+(k>>1)]=(p-q+mod)*I2%mod*x%mod;
				x=x*w%mod;
			}
		}
	}
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}
ll ntt_a[N],ntt_b[N],ntt_c[N];
void convolution_ntt(int n,ll *A,int m,ll *B,ll *C){
	int len=1;
	while(len<=(n+m))len<<=1;
	rep(i,0,len-1)ntt_a[i]=A[i];
	rep(i,0,len-1)ntt_b[i]=B[i];
	ntt(len,ntt_a);
	ntt(len,ntt_b);
	rep(i,0,len-1)ntt_c[i]=ntt_a[i]*ntt_b[i]%mod;
	intt(len,ntt_c);
	rep(i,0,len-1)C[i]=ntt_c[i];
}
};
int n,m;
ll a[N],b[N],c[N];
void slv(int _csid,int _csi){
	scanf("%d %d",&n,&m);
	rep(i,0,n)scanf("%lld",&a[i]);
	rep(i,0,m)scanf("%lld",&b[i]);
//	Bpoly::convolution_fft(n,a,m,b,c);
	Bpoly::convolution_ntt(n,a,m,b,c);
	rep(i,0,n+m)printf("%lld ",c[i]);
}
void main(){
//	ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
	int T=1,csid=0;
//	scanf("%d",&csid);
//	scanf("%d",&T);
	rep(i,1,T)slv(csid,i);
}
}
int main(){
	string __name="";
	if(__name!=""){
		freopen((__name+".in").c_str(),"r",stdin);
		freopen((__name+".out").c_str(),"w",stdout);
	}
	ax_by_c::main();
	return 0;
}
posted @ 2025-06-16 16:14  ax_by_c  阅读(59)  评论(0)    收藏  举报