多项式板子

ntt版本

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int p=998244353,G=3;
const int N=1e6+5;
int rev[N];
int qpow(int a,int b) {
	int ans=1;
	while(b) {
		if(b&1) ans=ans*a%p;
		a=a*a%p;
		b>>=1;
	}
	return ans;
}
void prep(int len) {
    for(int i=0; i<len; i++) {
		rev[i]=rev[i>>1]>>1;
		if(i&1) rev[i]|=(len>>1);
	}
}
void change(vector<int> &x,int len) {for(int i=0; i<len; i++) if(rev[i]>i) swap(x[i],x[rev[i]]);}
void ntt(vector<int> &x,int len,int flag) {
	change(x,len);
	for(int i=1; i<len; i<<=1) {
		int gn=qpow(flag==1?G:qpow(G,p-2),(p-1)/(i<<1));
		for(int j=0; j<len; j+=(i<<1)) {
			int g=1;
			for(int k=j; k<j+i; k++) {
				int l=x[k],r=g*x[k+i]%p;
				x[k]=(l+r>=p?l+r-p:l+r),x[k+i]=(l-r<0?l-r+p:l-r);
				g=g*gn%p;
			}
		}
	}
	if(flag==-1) {
		int inv=qpow(len,p-2);
		for(int i=0; i<len; i++) x[i]=x[i]*inv%p;
	}
}
struct poly:vector<int> {
	friend void read(poly &x,int n) {x.resize(n);for(int i=0; i<n; i++) cin>>x[i];}
	friend void write(poly &x,int n=-1) {x.resize(n);if(!~n) n=x.size();for(int i=0; i<n; i++) cout<<x[i]<<" ";cout<<"\n";}
	poly operator+(const poly &o) const {
		poly a=*this,b=o;
		int len=max(a.size(),b.size());
		a.resize(len),b.resize(len);
		for(int i=0; i<len; i++) a[i]=(a[i]+b[i])%p;
		return a;
	}
	poly operator-(const poly &o) const {
		poly a=*this,b=o;
		int len=max(a.size(),b.size());
		a.resize(len),b.resize(len);
		for(int i=0; i<len; i++) a[i]=(a[i]-b[i]+p)%p;
		return a;
	}
	poly operator*(const poly &o) const {
		poly a=*this,b=o;
		int len=1,tmp=a.size()+b.size()-1;
		while(len<tmp) len<<=1;
		a.resize(len),b.resize(len);
        prep(len);
		ntt(a,len,1),ntt(b,len,1);
		for(int i=0; i<len; i++) a[i]=a[i]*b[i]%p;
		ntt(a,len,-1);
		a.resize(tmp);
		return a;
	}
	poly operator*(const int &o) const {
		poly a=*this;
		for(int i=0; i<a.size(); i++) a[i]=(a[i]*o%p+p)%p;
		return a;
	}
	poly operator<<(const int &o) const {
		poly a=*this,b;
		b.resize(a.size()+o);
		for(int i=0; i<o; i++) b[i]=0;
		for(int i=o; i<b.size(); i++) b[i]=a[i-o];
		return b;
	}
	poly operator>>(const int &o) const {
		poly a=*this,b;
		if(o>=a.size()) return b;
		b.resize(a.size()-o);
		for(int i=0; i<b.size(); i++) b[i]=a[i+o];
		return b;
	}
	poly inv() {
		poly a=*this,b,s,t;
		int n=a.size();
		a.resize(n<<1),b.resize(n<<1);
		b[0]=qpow(a[0]%p,p-2);
		int len=2;
		while(len<(n<<1)) {
			s.resize(len),t.resize(len);
			for(int i=0; i<len; i++) s[i]=a[i];
			for(int i=0; i<len; i++) t[i]=b[i];
			for(int i=0; i<len; i++) b[i]=2*t[i]%p;
			s=s*t*t;
			for(int i=0; i<len; i++) b[i]=(b[i]-s[i]+p)%p;
			len<<=1;
		}
		b.resize(n);
		return b;
	}
	poly operator/(const poly &o) const {
		poly a=*this,b=o;
		int n=a.size(),m=b.size();
		reverse(a.begin(),a.end());
		reverse(b.begin(),b.end());
		b.resize(n-m+2);
		a=a*b.inv();
		a.resize(n-m+1);
		reverse(a.begin(),a.end());
		return a;
	}
	poly operator%(const poly &o) const {
		poly a=*this,b=o;
		poly c=a-a/b*b;
		c.resize(b.size()-1);
		return c;
	}
	poly dao() {
		poly a=*this;
		for(int i=0; i<a.size()-1; i++) a[i]=a[i+1]*(i+1)%p;
		a.pop_back();
		return a;
	}
	poly ji() {
		poly a=*this;a.push_back(0);
		for(int i=a.size()-1; i>=1; i--) a[i]=a[i-1]*qpow(i,p-2)%p;
		a[0]=0;
		return a;
	}
	poly ln() {
		poly a=*this;
		int n=a.size();
		a=(a.dao()*a.inv()).ji();
		a.resize(n);
		return a;
	}
	poly exp() {
		poly a=*this,b,s,t;
		int n=a.size();
		b.resize(n<<1);
		b[0]=1;
		int len=2;
		while(len<(n<<1)) {
			s.resize(len),t.resize(len);
			for(int i=0; i<len; i++) s[i]=a[i];
			for(int i=0; i<len; i++) t[i]=b[i];
			s=t.ln()*(-1)+s;
			s[0]=(s[0]+1)%p;
			s=s*t;
			for(int i=0; i<len; i++) b[i]=s[i];
			len<<=1;
		}
		b.resize(n);
		return b;
	}
	poly sqrt() {
		poly a=*this,b,s,t;
		int n=a.size();
		b.resize(n<<1);
		b[0]=1;
		int len=2;
		while(len<(n<<1)) {
			s.resize(len),t.resize(len);
			for(int i=0; i<len; i++) s[i]=a[i];
			for(int i=0; i<len; i++) t[i]=b[i];
			s=(s+t*t)*(t*2).inv();
			for(int i=0; i<len; i++) b[i]=s[i];
			len<<=1; 
		}
		b.resize(n);
		return b;
	}
};

fft版本

#include<bits/stdc++.h>
#define int long long
using namespace std;
const double pi=acos(-1.0);
const int N=4e6+5;
int rev[N];
struct Complex {
	double x,y;
	Complex(double X=0.0,double Y=0.0) {x=X,y=Y;}
	Complex operator+(const Complex &t) const {return Complex(x+t.x,y+t.y);}
	Complex operator-(const Complex &t) const {return Complex(x-t.x,y-t.y);}
	Complex operator*(const Complex &t) const {return Complex(x*t.x-y*t.y,x*t.y+y*t.x);}
};
void change(vector<Complex> &x,int len) {
	for(int i=0; i<len; i++) {
		rev[i]=rev[i>>1]>>1;
		if(i&1) rev[i]|=(len>>1);
	}
	for(int i=0; i<len; i++) if(rev[i]>i) swap(x[i],x[rev[i]]);
}
void fft(vector<Complex> &x,int len,int flag) {
	change(x,len);
	for(int i=2; i<=len; i*=2) {
		Complex wn=Complex(cos(2*pi/i),sin(flag*2*pi/i));
		for(int j=0; j<len; j+=i) {
			Complex w=Complex(1,0);
			for(int k=j; k<j+i/2; k++) {
				Complex l=x[k],r=w*x[k+i/2];
				x[k]=l+r,x[k+i/2]=l-r;
				w=w*wn;
			}
		}
	}
	if(flag==-1) for(int i=0; i<len; i++) x[i].x/=len;
}
struct poly {
	mutable vector<Complex> x;
	void read(int n) {x.resize(n);for(int i=0; i<n; i++) {int a;cin>>a;x[i]=Complex(a,0);}}
	void write(int n=-1) {if(!~n) n=x.size();for(int i=0; i<n; i++) cout<<(int)round(x[i].x)<<" ";cout<<"\n";}
	poly operator*(const poly &o) const {
		poly a=*this,b=o;
		int len=1,tmp=a.x.size()+b.x.size()-1;
		while(len<tmp) len<<=1;
		a.x.resize(len),b.x.resize(len);
		fft(a.x,len,1),fft(b.x,len,1);
		for(int i=0; i<len; i++) a.x[i]=a.x[i]*b.x[i];
		fft(a.x,len,-1);
		a.x.resize(tmp);
		return a;
	}
};
posted @ 2025-02-04 17:07  System_Error  阅读(22)  评论(0)    收藏  举报