[2019.3.25]多项式求逆

多项式求逆是什么

对于一个\(n\)次多项式\(F(x)\),要求一个小于等于\(n\)次的多项式\(G(x)\),满足

\(F(x)G(x)\equiv1(mod\ x^n)\)

\(mod\ x^n\)即只考虑所有多项式的前n项。

怎么做多项式求逆

显然,当\(F(x)\)次数为0,即只有常数项时,它的逆元就是常数项的逆元。

对于次数大于0的多项式我们假设我们已经递归求出\(F(x)\)\(mod\ x^{\lceil\frac{n}{2}\rceil}\)意义下的逆\(H(x)\)

也就是我们有

\(F(x)H(x)\equiv1(mod\ x^{\lceil\frac{n}{2}\rceil})\)

\(F(x)G(x)\equiv1(mod\ x^n)\)易知\(F(x)G(x)\equiv1(mod\ x^{\lceil\frac{n}{2}\rceil})\)

两式相减得

\(F(x)[H(x)-G(x)]\equiv0(mod\ x^{\lceil\frac{n}{2}\rceil})\)

我们有\(F(x)\not=0\),即\(H(x)-G(x)\equiv0(mod\ x^{\lceil\frac{n}{2}\rceil})\)

由于我们有若\(a\equiv b(mod\ p)\),则\(a^2\equiv b^2(mod\ x^2)\)

则两边平方得

\(H(x)^2-2G(x)H(x)+G(x)^2\equiv0(mod\ x^{2\times\lceil\frac{n}{2}\rceil})\)

.因为\(2\times\lceil\frac{x}{2}\rceil\ge n\)所以\(H(x)^2-2G(x)H(x)+G(x)^2\equiv0(mod\ x^n)\)

两边同乘\(F(x)\),由于\(F(x)G(x)\equiv1(mod\ x^n)\)

\(F(x)H(x)^2-2H(x)+G(x)\equiv0(mod\ x^n)\)

\(G(x)\equiv 2H(x)-F(x)H(x)^2(mod\ x^n)\)

\(G(x)\equiv H(x)[2-F(x)H(x)](mod\ x^n)\)

我们可以NTT实现多项式乘法,时间复杂度\(O(n\log^2n)\)

code:

#include<bits/stdc++.h>
#define ci const int&
#define VAL(p,n,i) (i<n?p[i]:0)
using namespace std;
const int mod=998244353;
const int g=3;
int cpy[600010];
int POW(int x,int y){
	int tot=1;
	while(y)y&1?tot=1ll*tot*x%mod:0,x=1ll*x*x%mod,y>>=1;
	return tot;
}
void NTT(vector<int>&f,ci l,ci len,ci op){
	if(len&1)return;
	for(int i=l;i<l+len;++i)cpy[i]=f[i];
	int nw=l-1,ln=len>>1;
	for(int i=l;i<l+ln;++i)f[i]=cpy[++nw],f[i+ln]=cpy[++nw];
	NTT(f,l,ln,op),NTT(f,l+ln,ln,op);
	int rt=POW(g,(mod-1)/len),t;
	op?rt=POW(rt,mod-2):0;
	nw=1;
	for(int i=l;i<l+len;++i)cpy[i]=f[i];
	for(int i=l;i<l+ln;++i,nw=1ll*nw*rt%mod)t=1ll*nw*cpy[i+ln]%mod,f[i]=(cpy[i]+t)%mod,f[i+ln]=(cpy[i]-t+mod)%mod;
}
vector<int>F;
vector<int>T;
vector<int>tmp;
vector<int>a;
vector<int>b;
vector<int>c;
int ts,sz,tg,inv;
void print(const vector<int>&x){
	for(int i=0;i<x.size();++i)printf("%d ",x[i]);
}
vector<int>calc(const vector<int>&x,const vector<int>&y){//2y-x*y^2
	ts=x.size()+y.size()+y.size()-2,sz=1,a.clear(),b.clear(),c.clear();
	while(sz<ts)sz<<=1;
	for(int i=0;i<x.size();++i)a.push_back(x[i]);
	for(int i=0;i<y.size();++i)b.push_back(y[i]);
	a.resize(sz),b.resize(sz),NTT(a,0,sz,0),NTT(b,0,sz,0);
	for(int i=0;i<sz;++i)c.push_back((2-1ll*a[i]*b[i]%mod+mod)%mod*b[i]%mod);
	NTT(c,0,sz,1),inv=POW(sz,mod-2);
	for(int i=0;i<sz;++i)c[i]=1ll*c[i]*inv%mod;
	return c;
}
int n,v;
vector<int>INV(const vector<int>&x){
	if(x.size()==1)return T.resize(1),T[0]=POW(x[0],mod-2),T;
	vector<int>G=x;
	G.resize((x.size()+1)>>1),G=calc(x,INV(G)),G.resize(x.size());
	return G;
}
int main(){
	scanf("%d",&n);
	for(int i=0;i<n;++i)scanf("%d",&v),F.push_back(v);
	print(INV(F));
	return 0;
}
posted @ 2019-03-25 11:32  xryjr233  阅读(287)  评论(0编辑  收藏  举报