多项式

傅里叶变换(FFT)学习笔记

NTT与多项式全家桶

FFT

常量与变量

  • const double Pi\texttt{const double Pi}π的值\pi 的值
  • int rev[i]\texttt{int rev[i]}aia_i 变换之后的位置。
  • int lim\texttt{int lim}:大于 2n2n 的最小的 22 的次幂。
  • int len\texttt{int len}log2lim\log_2lim

函数

  • struct Complex\texttt{struct Complex}:功能简陋的虚数结构体。
  • void init(int n)\texttt{void init(int n)}:最高项次数为 nn 时的初始化。
  • void fft(Complex *a,double flag)\texttt{void fft(Complex *a,double flag)}flag=1flag=1 时,将多项式 aa 由系数转为点值表示,flag=1flag=-1 时,将多项式 aa 由点值转为系数表示。

代码

struct Complex{
	double x,y;
	Complex(double xx=0,double yy=0){
		x=xx,y=yy;
	}
	Complex operator*(const Complex q)const{
		return Complex(x*q.x-y*q.y,x*q.y+y*q.x);
	}
	Complex operator+(const Complex q)const{
		return Complex(x+q.x,y+q.y);
	}
	Complex operator-(const Complex q)const{
		return Complex(x-q.x,y-q.y);
	}
}a[N],b[N];
struct FFT{
	const double Pi=acos(-1.0);
	int rev[N],lim,len;
	void init(int n){
		lim=1,len=0;
		while(lim<2*n)lim<<=1,len++;
		for(int i=1;i<=lim;i++)
			rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
	}
	void fft(Complex *a,double flag){
		for(int i=0;i<lim;i++)
			if(i<rev[i])swap(a[i],a[rev[i]]);
		for(int i=1;i<lim;i<<=1){
			Complex w1(cos(Pi/i),flag*sin(Pi/i));
			for(int j=0;j<lim;j+=(i<<1)){
				Complex w(1,0),x,y;
				for(int k=0;k<i;k++,w=w*w1){
					x=a[j+k],y=a[i+j+k]*w;
					a[j+k]=x+y;a[i+j+k]=x-y;
				}
			}
		}
		if(flag<0){
			for(int i=0;i<lim;i++)
				a[i].x/=lim,a[i].y/=lim;
		}
	}
}f;

NTT

其中开方运算若多项式常数项不为 11 则要用到二次剩余求解。

#define clr(f,n) memset(f,0,sizeof(ll)*(n))
#define cpy(f,g,n) memcpy(f,g,sizeof(ll)*(n))
struct Poly{
	const ll G=3,mod=998244353;
	ll qmi(ll a,int b=998244351){
		ll ans=1;
		while(b){
			if(b&1)ans=ans*a%mod;
			a=a*a%mod;b>>=1;
		}
		return ans;
	}
	const ll invG=qmi(G);
	int rv[N<<1],trv,inv[N<<1];
	void init(int n){
		if(trv==n)return;trv=n;
		inv[1]=1;
		for(int i=0;i<n;i++)
			rv[i]=(rv[i>>1]>>1)|((i&1)?n>>1:0);
		for(int i=2;i<=trv;i++)
			inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
	}
	void print(ll *f,int n){
		for(int i=0;i<n;i++)printf("%lld ",f[i]);puts("");
	}
	void NTT(ll *g,bool op,int n){
		init(n);
		static ull f[N<<1],w[N<<1];
		w[0]=1;
		for(int i=0;i<n;i++)f[i]=((mod<<5)+g[rv[i]])%mod;
		for(int l=1;l<n;l<<=1){
			ull w1=qmi(op?G:invG,(mod-1)/(l<<1));
			for(int i=1;i<l;i++)w[i]=w[i-1]*w1%mod;
			for(int i=0;i<n;i+=(l<<1)){
				for(int j=0,tt;j<l;j++){
					tt=w[j]*f[i|l|j]%mod;
					f[i|l|j]=f[i|j]+mod-tt;
					f[i|j]+=tt;
				}
			}
			if(l==(1<<10))
				for(int i=0;i<n;i++)f[i]%=mod;
		}
		if(!op){
			ull invn=qmi(n);
			for(int i=0;i<n;i++)
				g[i]=f[i]%mod*invn%mod;
		}
		else for(int i=0;i<n;i++)g[i]=f[i]%mod;
	}
	void pointx(ll *f,ll *g,int n){
		for(int i=0;i<n;i++)f[i]=f[i]*g[i]%mod;
	}
	void times(ll *f,ll *g,int len,int lim){
		static ll sav[N<<1];
		int n=1;for(n;n<(len<<1);n<<=1);
		clr(f+len,n-len);clr(g+len,n-len);/*ex*/
		clr(sav,n);cpy(sav,g,n);
		NTT(f,1,n);NTT(sav,1,n);
		pointx(f,sav,n);NTT(f,0,n);
		clr(f+lim,n-lim);clr(sav,n);
	}
	void invp(ll *f,int m){
		int n;for(n=1;n<m;n<<=1);
		static ll w[N<<1],r[N<<1],sav[N<<1];
		w[0]=qmi(f[0]);
		for(int len=2;len<=n;len<<=1){
			for(int i=0;i<(len>>1);i++)r[i]=2*w[i]%mod;
			cpy(sav,f,len);
			NTT(w,1,len<<1);pointx(w,w,len<<1);
			NTT(sav,1,len<<1);pointx(w,sav,len<<1);
			NTT(w,0,len<<1);clr(w+len,len);
			for(int i=0;i<len;i++)w[i]=(r[i]-w[i]+mod)%mod;
		}
		cpy(f,w,m);clr(sav,n<<1);clr(r,n<<1);clr(w,n<<1);
	}
	void dao(ll *f,int m){
		for(int i=1;i<m;i++)f[i-1]=f[i]*i%mod;
		f[m-1]=0;
	}
	void jifen(ll *f,int m){
		for(int i=m;i;i--)f[i]=f[i-1]*inv[i]%mod;
		f[0]=0;
	}
	void lnp(ll *f,int n){
		static ll f_[N<<1];
		cpy(f_,f,n);dao(f_,n);invp(f,n);
		times(f,f_,n,n-1);jifen(f,n-1);clr(f_,n);
	}
	void exp(ll *f,int n){
		static ll a[N<<1],b[N<<1];
		int len=1;for(;len<n;len<<=1);
		cpy(a,f,n);clr(f,len);b[0]=f[0]=1;
		for(int l=2;l<=len;l<<=1){
			cpy(b,f,l>>1);lnp(f,l);
			for(int i=0;i<l;i++)
				f[i]=(a[i]-f[i]+mod)%mod;
			f[0]=(f[0]+1)%mod;
			times(f,b,l,l);
		}
		clr(a,len);clr(b,len);
	}
	void sqrtp(ll *f,int n){
		static ll a[N<<1],b[N<<1],n2=qmi(2);
		int len=1;for(;len<n;len<<=1);
		cpy(a,f,n);clr(f,len);f[0]=1;
		for(int l=2;l<=len;l<<=1){
			cpy(b,f,l>>1);times(f,f,l,l);
			for(int i=0;i<l;i++)
				f[i]=(f[i]+a[i]+mod)%mod*n2%mod;
			invp(b,l);times(f,b,l,l);
		}
		clr(a,len);clr(b,len);
	}
	void qmip(ll *f,ll k1,ll k2,int n){
		ll m=0;
		for(;m<n;m++)if(f[m])break;
		if(f[m]==0)return;
		for(int i=m;i<n;i++)f[i-m]=f[i];
		ll xk=qmi(f[0],k2),nx=qmi(f[0]);
		for(int i=0;i<n-m;i++)f[i]=f[i]*nx%mod;
		lnp(f,n-m);
		for(int i=0;i<n-m;i++)f[i]=f[i]*k1%mod;
		exp(f,n-m);
		for(int i=0;i<n-m;i++)f[i]=f[i]*xk%mod;
		m=m*k1;
		for(int i=n-1;i>=m;i--)f[i]=f[i-m];
		for(int i=0;i<m&&i<n;i++)f[i]=0;
	}
	void divp(ll *f,ll *g,ll *q,ll *r,int n,int m){
		clr(q,n);clr(r,n);
		for(int i=0;i<n-m+1;i++)q[i]=f[n-1-i];
		for(int i=0;i<n-m+1&&i<m;i++)r[i]=g[m-1-i];
		invp(r,n-m+1);times(q,r,n-m+1,n-m+1);
		for(int i=0;i*2<n-m+1;i++)swap(q[i],q[n-m-i]);
		clr(g+m-1,n-m+1);clr(r,n-m+1);cpy(r,q,m-1);
		times(g,r,m-1,m-1);
		for(int i=0;i<m-1;i++)r[i]=(f[i]-g[i]+mod)%mod;
	}
}poly;
posted @ 2023-12-15 13:39  luckydrawbox  阅读(9)  评论(0)    收藏  举报  来源