多项式板子

ntt版本

#include<bits/stdc++.h>
using namespace std;
const int p=998244353,G=3,Gi=332748118;
const int N=8e6+5;
int rev[N],rt[N],Inv[N];
inline void add(int &x,int y) {x+=y;if(x>=p) x-=p;}
inline int qpow(int a,int b) {
	int ans=1;
	while(b) {
		if(b&1) ans=1ll*ans*a%p;
		a=1ll*a*a%p;
		b>>=1;
	}
	return ans;
}
inline void prep(int len) {
	for(int i=0; i<len; i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(len>>1):0);
	for(int i=1; i<len; i<<=1) {
		int val=qpow(G,(p-1)/(i<<1));
		rt[i]=1;for(int j=1; j<i; j++) rt[j+i]=1ll*rt[j+i-1]*val%p;
	}
}
inline void ntt(vector<int> &a,int len,int f) {
	if(f==-1) reverse(a.begin()+1,a.end());
	for(int i=1; i<len; i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int i=1; i<len; i<<=1) {
		for(int j=0; j<len; j+=(i<<1)) {
			for(int k=0; k<i; k++) {
				int x=a[j+k],y=1ll*rt[k+i]*a[j+k+i]%p;
				a[j+k]=(x+y>=p?x+y-p:x+y);
				a[j+k+i]=(x-y<0?x-y+p:x-y);
			}
		}
	}
	if(f==-1) {
		int inv=qpow(len,p-2);
		for(int i=0; i<len; i++) a[i]=1ll*a[i]*inv%p;
	}
}
struct poly:vector<int> {
	using vector<int>::vector;
	friend void read(poly &x,int n) {x.resize(n);for(int i=0; i<n; i++) cin>>x[i];}
	friend void write(poly &x,int n=-1) {x.resize(n);if(!~n) n=x.size();for(int i=0; i<n; i++) cout<<x[i]<<" ";cout<<"\n";}
	poly operator+(const poly &o) const {
		poly a=*this,b=o;
		int len=max(a.size(),b.size());
		a.resize(len),b.resize(len);
		for(int i=0; i<len; i++) a[i]=(a[i]+b[i])%p;
		return a;
	}
	poly operator-(const poly &o) const {
		poly a=*this,b=o;
		int len=max(a.size(),b.size());
		a.resize(len),b.resize(len);
		for(int i=0; i<len; i++) a[i]=(a[i]-b[i]+p)%p;
		return a;
	}
	poly operator*(const poly &o) const {
		poly a=*this,b=o;
		int need=a.size()+b.size()-1,len=1;
		while(len<need) len<<=1;
		a.resize(len),b.resize(len);
		prep(len);
		ntt(a,len,1),ntt(b,len,1);
		for(int i=0;i<len;i++) a[i]=1ll*a[i]*b[i]%p;
		ntt(a,len,-1);
		a.resize(need);
		return a;
	}
	poly operator*(const int &o) const {
		poly a=*this;
		for(int i=0; i<a.size(); i++) a[i]=(1ll*a[i]*o%p+p)%p;
		return a;
	}
	poly operator<<(const int &o) const {
		poly a=*this,b;
		b.resize(a.size()+o);
		for(int i=0; i<o; i++) b[i]=0;
		for(int i=o; i<b.size(); i++) b[i]=a[i-o];
		return b;
	}
	poly operator>>(const int &o) const {
		poly a=*this,b;
		if(o>=a.size()) return b;
		b.resize(a.size()-o);
		for(int i=0; i<b.size(); i++) b[i]=a[i+o];
		return b;
	}
	poly inv() {
		poly g=*this,f,a,b;
		int n=g.size(),len=2;g.resize(2*n);
		f.resize(1),f[0]=qpow(g[0]%p,p-2);
		while(len<2*n) {
			f.resize(len,0),a.resize(len<<1,0),b.resize(len<<1,0);
			for(int i=0; i<len; i++) a[i]=g[i],b[i]=f[i];
			for(int i=0; i<len; i++) f[i]=2ll*b[i]%p;
			prep(len<<1);
			ntt(a,len<<1,1),ntt(b,len<<1,1);
			for(int i=0; i<(len<<1); i++) a[i]=1ll*a[i]*b[i]%p*b[i]%p;
			ntt(a,len<<1,-1);
			for(int i=0; i<len; i++) add(f[i],p-a[i]);
			len<<=1;
		}
		f.resize(n);
		return f;
	}
	poly operator/(const poly &o) const {
		poly a=*this,b=o;
		int n=a.size(),m=b.size();
		if(n<m) return poly();
		reverse(a.begin(),a.end());
		reverse(b.begin(),b.end());
		b.resize(n-m+2);
		a=a*b.inv();
		a.resize(n-m+1);
		reverse(a.begin(),a.end());
		return a;
	}
	poly operator%(const poly &o) const {
		poly a=*this,b=o;
		poly c=a-a/b*b;
		c.resize(b.size()-1);
		return c;
	}
	poly deriv() {
		poly a=*this;
		for(int i=1; i<a.size(); i++) a[i-1]=1ll*a[i]*i%p;
		a.pop_back();
		return a;
	}
	poly integ() {
		poly a=*this;
		a.push_back(0);
		Inv[1]=1;
		for(int i=2; i<a.size(); i++) Inv[i]=1ll*(p-p/i)*Inv[p%i]%p;
		for(int i=a.size()-1; i>=1; i--) a[i]=1ll*a[i-1]*Inv[i]%p;
		a[0]=0;
		return a;
	}
	poly ln() {
		poly a=*this;
		int n=a.size();
		a=(a.deriv()*a.inv()).integ();
		a.resize(n);
		return a;
	}
	poly exp() {
		poly a=*this,b;
		int n=a.size(),len=1;
		b.resize(1),b[0]=1;
		while(len<n) {
			len<<=1;
			poly t=b;t.resize(len);
			poly s(a.begin(),a.begin()+min((int)a.size(),len)),ln_t=t.ln();
			s.resize(len);
			for(int i=0; i<len; i++) add(s[i],p-ln_t[i]);
			add(s[0],1);
			b=t*s,b.resize(len);
		}
		b.resize(n);
		return b;
	}
	poly sqrt() {
		poly a=*this,g;
		int n=a.size();
		g.resize(1),g[0]=1;
		int len=1,inv2=(p+1)/2;
		while(len<n) {
			len<<=1;
			poly t=g;t.resize(len);
			poly inv_t=t.inv(),s(a.begin(),a.begin()+min((int)a.size(),len));
			s.resize(len),s=s*inv_t;
			for(int i=0; i<len; i++) s[i]=1ll*(s[i]+t[i])*inv2%p;
			g=s,g.resize(len);
		}
		g.resize(n);
		return g;
	}
};

fft版本

#include<bits/stdc++.h>
#define int long long
using namespace std;
const double pi=acos(-1.0);
const int N=4e6+5;
int rev[N];
struct Complex {
	double x,y;
	Complex(double X=0.0,double Y=0.0) {x=X,y=Y;}
	Complex operator+(const Complex &t) const {return Complex(x+t.x,y+t.y);}
	Complex operator-(const Complex &t) const {return Complex(x-t.x,y-t.y);}
	Complex operator*(const Complex &t) const {return Complex(x*t.x-y*t.y,x*t.y+y*t.x);}
};
void change(vector<Complex> &x,int len) {
	for(int i=0; i<len; i++) {
		rev[i]=rev[i>>1]>>1;
		if(i&1) rev[i]|=(len>>1);
	}
	for(int i=0; i<len; i++) if(rev[i]>i) swap(x[i],x[rev[i]]);
}
void fft(vector<Complex> &x,int len,int flag) {
	change(x,len);
	for(int i=2; i<=len; i*=2) {
		Complex wn=Complex(cos(2*pi/i),sin(flag*2*pi/i));
		for(int j=0; j<len; j+=i) {
			Complex w=Complex(1,0);
			for(int k=j; k<j+i/2; k++) {
				Complex l=x[k],r=w*x[k+i/2];
				x[k]=l+r,x[k+i/2]=l-r;
				w=w*wn;
			}
		}
	}
	if(flag==-1) for(int i=0; i<len; i++) x[i].x/=len;
}
struct poly {
	mutable vector<Complex> x;
	void read(int n) {x.resize(n);for(int i=0; i<n; i++) {int a;cin>>a;x[i]=Complex(a,0);}}
	void write(int n=-1) {if(!~n) n=x.size();for(int i=0; i<n; i++) cout<<(int)round(x[i].x)<<" ";cout<<"\n";}
	poly operator*(const poly &o) const {
		poly a=*this,b=o;
		int len=1,tmp=a.x.size()+b.x.size()-1;
		while(len<tmp) len<<=1;
		a.x.resize(len),b.x.resize(len);
		fft(a.x,len,1),fft(b.x,len,1);
		for(int i=0; i<len; i++) a.x[i]=a.x[i]*b.x[i];
		fft(a.x,len,-1);
		a.x.resize(tmp);
		return a;
	}
};
posted @ 2025-02-04 17:07  System_Error  阅读(34)  评论(1)    收藏  举报