多项式全家桶学习笔记
一、数组版本
数组版本和 poly 版本都只涵盖目录中第 \(4\sim 10\) 部分。
namespace Poly
{
int p[maxn],q[maxn],r[maxn],w[maxn];
int inum[maxn];
int qpow(int a,int k)
{
int res=1;
for(;k;a=1ll*a*a%mod,k>>=1) if(k&1) res=1ll*res*a%mod;
return res;
}
int add(int x,int y)
{
return x+y>=mod?x+y-mod:x+y;
}
int dec(int x,int y)
{
return x-y<0?x-y+mod:x-y;
}
int extend(int n)
{
return n!=1?1<<(__lg(n-1)+1):1;
}
void get_r(int n)
{
for(int i=0;i<n;i++) r[i]=(r[i>>1]>>1)|(i&1?n>>1:0);
}
static auto init=[]()
{
for(int k=2,m=1;k<=maxn;k<<=1,m<<=1)
{
w[m]=1;
for(int i=m+1,x=qpow(3,(mod-1)/k);i<k;i++) w[i]=1ll*w[i-1]*x%mod;
}
for(int i=1;i<maxn;i++) inum[i]=qpow(i,mod-2);
return 0;
}();
void print(int *a,int n)
{
for(int i=0;i<n;i++) printf("%d%c",a[i]," \n"[i==n-1]);
}
void ntt(int *a,int n,int op)
{
for(int i=0;i<n;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(int k=2,m=1;k<=n;k<<=1,m<<=1)
for(int i=0;i<n;i+=k)
for(int j=i,*x=w+m;j<i+m;j++,x++)
{
int v=1ll*a[j+m]**x%mod;
a[j+m]=dec(a[j],v),a[j]=add(a[j],v);
}
if(op==-1)
{
reverse(a+1,a+n);
for(int i=0,v=qpow(n,mod-2);i<n;i++) a[i]=1ll*a[i]*v%mod;
}
}
void mul(int *a,int *b,int n,int m)
{
if(!n||!m) return ;
int len=extend(n+m-1);
for(int i=0;i<len;i++) p[i]=i<n?a[i]:0,q[i]=i<m?b[i]:0;
get_r(len),ntt(p,len,1),ntt(q,len,1);
for(int i=0;i<len;i++) a[i]=1ll*p[i]*q[i]%mod;
ntt(a,len,-1);
}
void inv(int *a,int *b,int n)
{
static int c[maxn],d[maxn];
n=extend(n),memset(b,0,4*n),b[0]=qpow(a[0],mod-2);
for(int k=2;k<=n;k<<=1)
{
for(int i=0;i<k<<1;i++) c[i]=i<k?a[i]:0,d[i]=i<k>>1?b[i]:0;
get_r(k),ntt(d,k,1);
for(int i=0;i<k;i++) d[i]=1ll*d[i]*d[i]%mod;
ntt(d,k,-1),get_r(k<<1),ntt(c,k<<1,1),ntt(d,k<<1,1);
for(int i=0;i<k<<1;i++) c[i]=1ll*c[i]*d[i]%mod;
ntt(c,k<<1,-1);
for(int i=0;i<k;i++) b[i]=(2ll*b[i]-c[i]+mod)%mod;
}
}
void diff(int *a,int *b,int n)
{
for(int i=1;i<n;i++) b[i-1]=1ll*i*a[i]%mod;
b[n-1]=0;
}
void integ(int *a,int *b,int n)
{
for(int i=1;i<n;i++) b[i]=1ll*inum[i]*a[i-1]%mod;
b[0]=0;
}
void ln(int *a,int *b,int n)
{
static int c[maxn],d[maxn];
assert(a[0]==1);
n=extend(n),inv(a,c,n),diff(a,d,n),mul(c,d,n,n),integ(c,b,n);
}
void exp(int *a,int *b,int n)
{
static int c[maxn];
assert(a[0]==0);
n=extend(n),memset(b,0,4*n),memset(c,0,4*n),b[0]=1;
for(int k=2;k<=n;k<<=1)
{
ln(b,c,k);
for(int i=0;i<k;i++) c[i]=dec(a[i],c[i]);
c[0]++,mul(b,c,k,k);
}
}
void sqrt(int *a,int *b,int n)
{
static const int inv2=(mod+1)>>1;
static int c[maxn],d[maxn];
assert(a[0]==1);
n=extend(n),memset(b,0,4*n),b[0]=1;
for(int k=2;k<=n;k<<=1)
{
memcpy(c,a,4*k),inv(b,d,k),mul(c,d,k,k);
for(int i=0;i<k;i++) b[i]=1ll*(b[i]+c[i])*inv2%mod;
}
}
/** enough for k<mod
void pow(int *a,int *b,int n,int k)
{
static int c[maxn];
memcpy(c,a,4*n),memset(b,0,4*n),b[0]=1;
for(;k;mul(c,c,n,n),k>>=1) if(k&1) mul(b,c,n,n);
}
**/
void pow(int *a,int *b,int n,string k)
{
static int c[maxn],d[maxn];
int u=0,k1=0,k2=0;
memset(b,0,4*n);
for(int i=0;i<k.size();i++) k1=(10ll*k1+k[i]-'0')%mod,k2=(10ll*k2+k[i]-'0')%(mod-1);
for(u=0;u<n&&!a[u];u++) ;
if((u&&k.size()>=5)||1ll*u*k1>=n) return ;
for(int i=u,x=qpow(a[u],mod-2);i<n;i++) c[i-u]=1ll*a[i]*x%mod;
ln(c,d,n-u);
for(int i=0;i<n-u;i++) d[i]=1ll*d[i]*k1%mod;
exp(d,c,n-u);
for(int i=u*k1,x=qpow(a[u],k2);i<n;i++) b[i]=1ll*c[i-u*k1]*x%mod;
}
void div(int *a,int *b,int *q,int *r,int n,int m)
{///len(q)=n-m+1,len(r)<=m-1
static int c[maxn],d[maxn],e[maxn];
int len=n-m+1;
for(int i=0;i<len;i++) c[i]=a[n-1-i];
for(int i=0;i<len;i++) d[i]=i<m?b[m-1-i]:0;
inv(d,e,len),mul(c,e,len,len),reverse(c,c+len),memcpy(q,c,4*len),mul(c,b,len,m);
for(int i=0;i<m;i++) r[i]=dec(a[i],c[i]);
}
}
温馨提示:
- 调用
Poly::init()函数进行初始化。 - 长度上限
maxn一般取extend(n<<1),如 \(n=10^5\) 时,maxn=1<<18。 - 系数不能有负数,否则
add,dec失效后ntt函数会爆int。 - 传参数组长度类似
vector的size函数, \(\deg=n-1\) 。 - 多项式快速幂更推荐使用注释中的方法,双 \(\log\) 用时仅为单 \(\log\) 的三倍,但代码简洁许多。
二、vector 版本
namespace Poly
{
int r[maxn],w[maxn],inum[maxn];
int qpow(int a,int k)
{
int res=1;
for(;k;a=1ll*a*a%mod,k>>=1) if(k&1) res=1ll*res*a%mod;
return res;
}
int add(int x,int y)
{
return x+y>=mod?x+y-mod:x+y;
}
int dec(int x,int y)
{
return x-y<0?x-y+mod:x-y;
}
int extend(int n)
{
return n!=1?1<<(__lg(n-1)+1):1;
}
void get_r(int n)
{
for(int i=0;i<n;i++) r[i]=(r[i>>1]>>1)|(i&1?n>>1:0);
}
static auto init=[]()
{
for(int k=2,m=1;k<=maxn;k<<=1,m<<=1)
{
w[m]=1;
for(int i=m+1,x=qpow(3,(mod-1)/k);i<k;i++) w[i]=1ll*w[i-1]*x%mod;
}
for(int i=1;i<maxn;i++) inum[i]=qpow(i,mod-2);
return 0;
}();
void print(poly a,int n)
{
a.resize(n);
for(int i=0;i<n;i++) printf("%d%c",a[i]," \n"[i==n-1]);
}
void ntt(poly &a,int n,int op)
{
for(int i=0;i<n;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(int k=2,m=1;k<=n;k<<=1,m<<=1)
for(int i=0;i<n;i+=k)
for(int j=i,*x=w+m;j<i+m;j++,x++)
{
int v=1ll*a[j+m]**x%mod;
a[j+m]=dec(a[j],v),a[j]=add(a[j],v);
}
if(op==-1)
{
reverse(a.begin()+1,a.begin()+n);
for(int i=0,v=qpow(n,mod-2);i<n;i++) a[i]=1ll*a[i]*v%mod;
}
}
poly operator*(poly a,int b)
{
for(int i=0;i<a.size();i++) a[i]=1ll*a[i]*b%mod;
return a;
}
poly operator+(poly a,poly b)
{
int n=max(a.size(),b.size());a.resize(n),b.resize(n);
for(int i=0;i<n;i++) a[i]=add(a[i],b[i]);
return a;
}
poly operator-(poly a,poly b)
{
int n=max(a.size(),b.size());a.resize(n),b.resize(n);
for(int i=0;i<n;i++) a[i]=dec(a[i],b[i]);
return a;
}
poly operator*(poly a,poly b)
{
if(a.empty()||b.empty()) return {};
int n=a.size(),m=b.size(),len=extend(n+m-1);
a.resize(len),b.resize(len),get_r(len),ntt(a,len,1),ntt(b,len,1);
for(int i=0;i<len;i++) a[i]=1ll*a[i]*b[i]%mod;
return ntt(a,len,-1),a.resize(n+m-1),a;
}
void operator*=(poly &a,int b)
{
a=a*b;
}
void operator+=(poly &a,poly b)
{
a=a+b;
}
void operator-=(poly &a,poly b)
{
a=a-b;
}
void operator*=(poly &a,poly b)
{
a=a*b;
}
poly inv(poly a,int n)
{
n=extend(n),a.resize(n);
poly b(n);b[0]=qpow(a[0],mod-2);
for(int k=2;k<=n;k<<=1)
{
poly c(a.begin(),a.begin()+k),d(b.begin(),b.begin()+k);
get_r(k),ntt(d,k,1);
for(int i=0;i<k;i++) d[i]=1ll*d[i]*d[i]%mod;
ntt(d,k,-1),c*=d;
for(int i=0;i<k;i++) b[i]=(2ll*b[i]-c[i]+mod)%mod;
}
return b;
}
poly diff(poly a,int n)
{
poly b(n);
for(int i=1;i<n;i++) b[i-1]=1ll*i*a[i]%mod;
return b[n-1]=0,b;
}
poly integ(poly a,int n)
{
poly b(n);
for(int i=1;i<n;i++) b[i]=1ll*inum[i]*a[i-1]%mod;
return b[0]=0,b;
}
poly ln(poly a,int n)
{
assert(a[0]==1);
n=extend(n),a.resize(n);
return integ(inv(a,n)*diff(a,n),n);
}
poly exp(poly a,int n)
{
assert(a[0]==0);
n=extend(n),a.resize(n);
poly b={1,0};
for(int k=2;k<=n;k<<=1)
{
poly c=ln(b,k);
for(int i=0;i<k;i++) c[i]=dec(a[i],c[i]);
c[0]++,b=b*c;
}
return b;
}
poly sqrt(poly a,int n)
{
static const int inv2=(mod+1)>>1;
assert(a[0]==1);
n=extend(n),a.resize(n);
poly b(n);b[0]=1;
for(int k=2;k<=n;k<<=1)
{
poly c(a.begin(),a.begin()+k);c*=inv(b,k);
for(int i=0;i<k;i++) b[i]=1ll*(b[i]+c[i])*inv2%mod;
}
return b;
}
poly operator<<(poly a,int k)
{
poly b(a.size()+k);
for(int i=0;i<a.size();i++) b[i+k]=a[i];
return b;
}
poly operator>>(poly a,int k)
{
if(a.size()<=k) return {0};
poly b(a.size()-k);
for(int i=k;i<a.size();i++) b[i-k]=a[i];
return b;
}
void operator<<=(poly &a,int k)
{
a=a<<k;
}
void operator>>=(poly &a,int k)
{
a=a>>k;
}
poly pow(poly a,int n,string k)
{
int u=0,k1=0,k2=0;
for(int i=0;i<k.size();i++) k1=(10ll*k1+k[i]-'0')%mod,k2=(10ll*k2+k[i]-'0')%(mod-1);
for(u=0;u<n&&!a[u];u++) ;
if((u&&k.size()>5)||1ll*u*k1>=n) return poly(n);
poly b(n),c(n-u);
for(int i=u,x=qpow(a[u],mod-2);i<n;i++) c[i-u]=1ll*a[i]*x%mod;
c=ln(c,n-u);
for(int i=0;i<n-u;i++) c[i]=1ll*c[i]*k1%mod;
c=exp(c,n-u);
for(int i=u*k1,x=qpow(a[u],k2);i<n;i++) b[i]=1ll*c[i-u*k1]*x%mod;
return b;
}
pair<poly,poly> div(poly a,poly b,int n,int m)
{///len(q)=n-m+1,len(r)<=m-1
a.resize(n),b.resize(m);
if(n<m) return make_pair(poly{0},a);
poly c=a,d=b;reverse(c.begin(),c.end()),c.resize(n-m+1),reverse(d.begin(),d.end());
poly q=c*inv(d,n-m+1);q.resize(n-m+1),reverse(q.begin(),q.end());
poly r=a-q*b;r.resize(m-1);
return make_pair(q,r);
}
poly operator/(poly a,poly b)
{
return div(a,b,a.size(),b.size()).first;
}
poly operator%(poly a,poly b)
{
return div(a,b,a.size(),b.size()).second;
}
void operator/=(poly &a,poly b)
{
a=a/b;
}
void operator%=(poly &a,poly b)
{
a=a%b;
}
}
using namespace Poly;
温馨提示:
-
调用
Poly::init()函数进行初始化。 -
如果不涉及到分治 \(\texttt{NTT}\) ,个人不太推荐使用
vector版本的板子。vector版本代码长度比数组版本高 \(20\%\) ,运行效率比数组版本低 \(10\%\) ,很大程度上是因为构造新的vector时间开销较大。vector版本由于涉及到重载运算符操作,所以不得不将Poly命名空间放入全局空间,但这很有可能导致变量重名。
三、快速傅里叶变换(FFT)
\(\texttt{FFT}\) 的核心思想:系数多项式转点值多项式(\(\texttt{DFT}\)),点值多项式可以直接相乘,点值多项式再转系数多项式(\(\texttt{IDFT}\))。
暴力求点值时间复杂度依然是 \(\mathcal O(n^2)\) ,但是如果将 \(n\) 补成 \(2\) 的方幂,并且利用单位根的性质,我们可以做到 \(\mathcal O(n\log n)\) 。
记 \(w_n^k=\cos\frac{2k\pi}n+i\sin\frac{2k\pi}n\) ,目标求 \(f(w_n^0),\cdots,f(w_n^{n-1})\) 。
令:
容易发现 \(f(x)=f_1(x^2)+x\cdot f_2(x^2)\) ,对 \(0\le k\lt\frac n2\) ,分别代入 \(w_n^k\) 和 \(w_n^{k+\frac n2}\) :
时间复杂度 \(T(n)=2\cdot T(\frac n2)+\mathcal O(n)\) ,解得 \(T(n)=\mathcal O(n\log n)\) 。
递归版本常数太大,我们希望改成迭代版本。
观察一下使用 \(a_k\) 的次序,每次我们将偶数项(二进制下末位为 \(0\))放在前面,将奇数项(二进制下末位为 \(1\))放在后面。
设 \(2^d=n\) ,那么 \(d-1\) 轮操作等价于将 \(a_k\) 的二进制位翻转。
对 \(\forall 2\le i\le d\) ,维护自底向上迭代 \(i\) 次后的结果,每次通过 \(f_1(w_{2^{i-1}}^k)\) 和 \(f_2(w_{2^{i-1}}^k)\) 计算 \(f(w_{2^i}^k)\) 和 \(f(w_{2^i}^{k+2^{i-1}})\) 。
上述过程也被称为蝴蝶变换。
记 \(w=w_n\) , \(\texttt{DFT}\) 的本质是:
读者可以验证转移矩阵的逆矩阵为:
于是有了 \(\texttt{IDFT}\) 的第一种方法:用 \(w^{-1}\) 代替 \(w\) 重新做一遍 \(\texttt{FFT}\) ,最后除以 \(n\) 。
void fft(complex<double> *a,int n,int op)
{
for(int i=0;i<n;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(int k=2,m=1;k<=n;k<<=1,m<<=1)
{
complex<double> x(cos(2*pi/k),op*sin(2*pi/k)),w(1,0),v;
for(int i=0;i<n;i+=k,w=1)
for(int j=i;j<i+m;j++)
v=w*a[j+m],a[j+m]=a[j]-v,a[j]=a[j]+v,w*=x;
}
if(op==-1) for(int i=0;i<n;i++) a[i]/=n;
}
但是上面这种方法在 \(\texttt{NTT}\) 中不利于卡常,究其原因是我们要同时预处理 \(w\) 和 \(w^{-1}\) 的方幂。
于是有了第二种看待 \(\texttt{DFT}\) 的视角:
根据上面的分析,可以得到:
如果将 \(g(x)=\texttt{IDFT}[f(x)]\) 视为 \(n\) 次多项式(原本是 \(n-1\) 次),则:
对比两边每一项的系数:
从而避开 \(w^{-1}\) ,完美。
void fft(complex<double> *a,int n,int op)
{
for(int i=0;i<n;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(int k=2,m=1;k<=n;k<<=1,m<<=1)
{
complex<double> x(cos(2*pi/k),sin(2*pi/k)),w(1,0),v;
for(int i=0,m=k>>1;i<n;i+=k,w=1)
for(int j=i;j<i+m;j++)
v=w*a[j+m],a[j+m]=a[j]-v,a[j]=a[j]+v,w*=x;
}
if(op==-1)
{
reverse(a+1,a+n);
for(int i=0;i<n;i++) a[i]/=n;
}
}
四、快速数论变换(NTT)
\(\texttt{NTT}\) 使用条件: \(n\le 2^{v_2(p-1)}\) 。
单位根的性质原根全部满足,只需将 \(w_n\) 换成 \(g_n\) 即可。
卡常的重点在于预处理每一轮蝴蝶变换时 \(w\) 的值(共 \(m\) 个),存放在 w[m]~w[k-1] 的位置,这样可以节约 \(\mathcal O(n\log n)\) 次乘法。
从这里到第十部分都不再展示源码,请到第一、二部分自行查看。
五、多项式求逆(inv)
题目描述:给定 \(f(x)\) ,求 \(g(x)\) 满足 \(f(x)\cdot g(x)\equiv 1\pmod{x^n}\) 。
考虑倍增,假设已知 \(f(x)\cdot g_0(x)\equiv 1\pmod{x^\frac n2}\) 。
六、多项式对数函数(ln)
题目描述:给定 \(f(x)\) 满足常数项为 \(1\) ,求 \(g(x)\) 满足 \(g(x)\equiv\ln f(x)\pmod{x^n}\) 。
两边对 \(x\) 求导:
求导求逆乘起来,再积分回去,即可得到 \(g(x)\) 。
七、多项式指数函数(exp)
题目描述:给定 \(f(x)\) 满足常数项为 \(0\) ,求 \(g(x)\) 满足 \(g(x)\equiv e^{f(x)}\pmod{x^n}\) 。
前置知识:牛顿迭代。
牛顿迭代用于求函数零点。
任取 \(x_0\) 作为初始值,取 \(x_1=x_0-\frac{f(x_0)}{f'(x_0)}\) 为 \(x_0\) 处的切线与 \(x\) 轴的交点,依此类推。
牛顿迭代收敛速度非常快,对于多项式,每做一次精度就会加倍。
回到原题,题目等价于求 \(F(g(x))=\ln g(x)-f(x)\) 的零点,这里 \(f(x)\) 为常数, \(g(x)\) 为变量。
考虑倍增,假设 \(g_0(x)\) 为 \(\bmod x^\frac n2\) 下的零点,则:
八、多项式开根(sqrt)
题目描述:给定 \(f(x)\) 满足 常数项为 \(1\) ,求 \(g(x)\) 满足 \(g^2(x)\equiv f(x)\pmod{x^n}\) 。
考虑倍增,假设 \(g_0^2(x)\equiv f(x)\pmod{x^\frac n2}\) 。
九、多项式快速幂(pow)
题目描述:给定 \(f(x)\) ,求 \(g(x)\) 满足 \(g(x)\equiv f(x)^k\pmod{x^n}\) 。普通版保证 \(a_0=1\) ,加强版无此限制。
方法一
指数 \(k\) 可对 \(p\) 取模,从而保证 \(k\lt p\) 。
\(\text Q\) :为什么当 \(a_0=1\) 且 \(n\lt p\) 时,有 \(f(x)^k\pmod{x^n}=f(x)^{k\bmod p}\pmod{x^n}\) ?
\(\text A\) :只需证明 \(f(x)^p\equiv 1\pmod{x^n}\) 。
事实上,对于 \(f(x)^p\) 展开式中的每一项 \(\binom p{j_0,\cdots,j_{n-1}}\prod_{i=0}^{n-1}(a_ix^i)^{j_i}\) ,只要 \(j_i\) 中有超过两项非零,则前面的组合数系数是 \(p\) 的倍数,可以直接扔掉。
因此 \(f(x)^p\equiv\sum_{i=0}^{n-1}(a_ix^i)^p\pmod{x^n}\) ,其中第一项为 \(1\) ,后面每一项次数都超过了 \(n\) ,可以直接扔掉。
类比普通快速幂,时间复杂度 \(\mathcal O(n\log n\log p)\) ,但常数小。
方法二
如果 \(f(x)\) 常数项为 \(1\) ,先取 \(\ln\) ,乘上 \(k\) 倍后再 \(\exp\) 即可。
如果 \(f(x)\) 常数项不为 \(1\) ,设 \(f(x)\) 的最低次项为 \(a_mx^m\) ,则:
然后转化为普通版。
温馨提示:如果 \(p\ge n\),那么计算 \(a_m^k\) 时 \(k\) 对 \(\varphi(p)\) 取模,计算 \(f(x)^k\) 时 \(k\) 对 \(p\) 取模。
十、多项式除法 & 取模(div)
给定 \(f(x),g(x)\) 满足 \(\deg f=n\ge m=\deg g\) ,求 \(q(x),r(x)\) 满足 \(\deg q=n-m,\deg r\le m-1\) ,且 \(f(x)=q(x)\cdot g(x)+r(x)\) 。
对于次数为 \(k\) 的多项式 \(A(x)\) ,记 \(A_R(x)=x^kA(\frac 1x)\) 。
为方便起见,记 \(\deg r=m-1\) 。
至此我们可以求出 \(q(x)\) ,最后令 \(r(x)=f(x)-q(x)\cdot g(x)\) 即可。
温馨提示:求逆是在 \(\bmod x^{n-m+1}\) 而不是 \(x^m\) 意义下进行。
十一、分治 FFT
半在线卷积
题目描述:给定多项式 \(g\) ,初始值 \(f_0=1\) , \(f_i=\sum_{j=1}^if_{i-j}g_j\) ,求 \(f_0\sim f_{n-1}\) 。
由于初始 \(f\) 未知,常规的 \(\texttt{FFT/NTT}\) 无法处理,我们只能用算过 \(f\) 来更新后面的 \(f\) ,这也是名字 "半在线卷积" 的由来。
考虑 \(\texttt{cdq}\) 分治,递归完左区间后我们已经获取 \(f[l\sim mid]\) 的真实值,然后执行 \(f[l\sim mid]\cdot g[1\sim r-l]\to f[mid+1,r]\) ,再递归右区间即可。
注意到 \(j\) 对 \(i\) 的贡献会在 \(j,i\) 分别落在左、右子区间时计算,可以保证不重不漏。
时间复杂度 \(\mathcal O(n\log^2n)\) ,不过 \(\texttt{cdq}\) 分治的常数很小。
#include<bits/stdc++.h>
using namespace std;
const int maxn=1<<18,mod=998244353;
int n;
int c[maxn],f[maxn],g[maxn];
int p[maxn],q[maxn],r[maxn],w[maxn];
int qpow(int a,int k)
{
int res=1;
for(;k;a=1ll*a*a%mod,k>>=1) if(k&1) res=1ll*res*a%mod;
return res;
}
int add(int x,int y)
{
return x+y>=mod?x+y-mod:x+y;
}
int dec(int x,int y)
{
return x-y<0?x-y+mod:x-y;
}
int extend(int n)
{
return n!=1?1<<(__lg(n-1)+1):1;
}
void get_r(int n)
{
for(int i=0;i<n;i++) r[i]=(r[i>>1]>>1)|(i&1?n>>1:0);
}
static auto init=[]()
{
for(int k=2,m=1;k<=maxn;k<<=1,m<<=1)
{
w[m]=1;
for(int i=m+1,x=qpow(3,(mod-1)/k);i<k;i++) w[i]=1ll*w[i-1]*x%mod;
}
return 0;
}();
void ntt(int *a,int n,int op)
{
for(int i=0;i<n;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(int k=2,m=1;k<=n;k<<=1,m<<=1)
for(int i=0;i<n;i+=k)
for(int j=i,*x=w+m;j<i+m;j++,x++)
{
int v=1ll*a[j+m]**x%mod;
a[j+m]=dec(a[j],v),a[j]=add(a[j],v);
}
if(op==-1)
{
reverse(a+1,a+n);
for(int i=0,v=qpow(n,mod-2);i<n;i++) a[i]=1ll*a[i]*v%mod;
}
}
void mul(int *a,int *b,int *c,int n,int m)
{
if(!n||!m) return ;
int len=extend(n+m-1);
for(int i=0;i<len;i++) p[i]=i<n?a[i]:0,q[i]=i<m?b[i]:0;
get_r(len),ntt(p,len,1),ntt(q,len,1);
for(int i=0;i<len;i++) c[i]=1ll*p[i]*q[i]%mod;
ntt(c,len,-1);
}
void cdq(int l,int r)
{
if(l==r) return ;
int mid=(l+r)>>1;
cdq(l,mid),mul(f+l,g+1,c,mid-l+1,r-l);
for(int i=mid+1;i<=r;i++) f[i]=add(f[i],c[i-l-1]);
cdq(mid+1,r);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++) scanf("%d",&g[i]);
f[0]=1,cdq(0,n-1);
for(int i=0;i<n;i++) printf("%d%c",f[i]," \n"[i==n-1]);
return 0;
}
多个低次式相乘
假如若干个多项式次数之和为 \(n\) ,我们可以在 \(\mathcal O(n\log^2 n)\) 的时间内求出它们的乘积。
第一种做法类似于线段树 pushup 函数。对每个区间 \([l,r]\) 维护第 \(l\sim r\) 个多项式乘积,由于线段树只有 \(\log n\) 层,因此涉及到的多项式次数之和(即占用空间大小)为 \(\mathcal O(n\log n)\) 级别。
void build(int p,int l,int r)
{
if(l==r) return ;
int mid=(l+r)>>1;
build(ls,l,mid),build(rs,mid+1,r);
pushup(p);
}
第二种做法类似 Huffman 树。每次选取两个次数最小的多项式相乘,然后重新塞进堆中。由启发式合并的时间复杂度计算方法可知,涉及到的多项式次数之和为 \(\mathcal O(n\log n)\) 级别。
bool operator<(const poly &a,const poly &b)
{
return a.size()>b.size();
}
priority_queue<poly> q;
poly work()
{
while(q.size()>=2)
{
poly a=q.top();q.pop();
poly b=q.top();q.pop();
q.push(a*b);
}
return q.top();
}
一般情况下,如果给的都是一次式(实战大部分都是这种情况),那么更推荐用第一种,不仅省去了堆的开销,保留的区间乘积说不定以后还用得到;如果次数不统一且只需要全局乘积,那么更推荐第二种。
十二、多项式多点求值
给定 \(n\) 次多项式 \(f(x)\) 和 \(m\) 个点值 \(a_i\) ,求 \(f(a_i)\) 。
注意到 \(f(a_i)=f(x)\bmod(x-a_i)\) ,第一遍分治预处理 \(G_{l,r}(x)=\prod_{i=l}^r(x-a_i)\) ,第二遍分治计算 \(F_{l,r}(x)\bmod G_{l,r}(x)\) 的结果并向下递归。
时间复杂度 \(\mathcal O(n\log^2n)\) ,但是用到了多项式取模,常数较大。
#include<bits/stdc++.h>
#define poly vector<int>
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int maxn=1<<17,mod=998244353;
int m,n;
int a[maxn];
poly f,g[maxn];
///为节约篇幅,此处省略 namespace Poly
void build(int p,int l,int r)
{
if(l==r) return g[p]={mod-a[l],1},void();
int mid=(l+r)>>1;
build(ls,l,mid),build(rs,mid+1,r);
g[p]=g[ls]*g[rs];
}
void solve(int p,int l,int r,poly f)
{
if(l==r) return printf("%d\n",f[0]),void();
int mid=(l+r)>>1;
solve(ls,l,mid,f%g[ls]),solve(rs,mid+1,r,f%g[rs]);
}
int main()
{
scanf("%d%d",&n,&m),f.resize(n+1);
for(int i=0;i<=n;i++) scanf("%d",&f[i]);
for(int i=1;i<=m;i++) scanf("%d",&a[i]);
build(1,1,m),solve(1,1,m,f);
return 0;
}
十三、多项式快速插值
题目描述:给定 \(n\) 个点 \((x_i,y_i)\) ,求 \(n-1\) 次多项式 \(f(x)\) ,满足 \(f(x_i)\equiv y_i\pmod{998244353}\) 。
根据拉格朗日插值公式,我们有:
令 \(g(x)=\prod_{j=1}^n(x-x_j)\) ,由洛必达法则:
用分治 \(\texttt{NTT}\) 求出 \(g\) ,然后多项式求导,再多项式多点求值即可求出 \(g'(x_i)\) 。
将 \(\frac{y_i}{g'(x_i)}\) 视为仅和 \(i\) 有关的常数 \(v_i\) ,则:
还是考虑分治,对区间 \([l,r]\) 维护:
更新时,只需用左区间的 \(f\) 乘以右区间 \(x-x_j\) 的连乘积,再加上右区间的 \(f\) 乘以左区间 \(x-x_j\) 的连乘积即可,注意 \(\prod_{j=l}^r(x-x_j)\) 在第一遍预处理 \(g\) 时就已经求过了。
时间复杂度 \(\mathcal O(n\log^2n)\) 。
#include<bits/stdc++.h>
#define poly vector<int>
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int maxn=1<<18,mod=998244353;
int n,m;
int x[maxn],y[maxn];
poly f[maxn],g[maxn];
///为节约篇幅,此处省略 namespace Poly
void dfs1(int p,int l,int r)
{
if(l==r) return g[p]={mod-x[l],1},void();
int mid=(l+r)>>1;
dfs1(ls,l,mid),dfs1(rs,mid+1,r);
g[p]=g[ls]*g[rs];
}
void dfs2(int p,int l,int r,poly a)
{
if(l==r) return y[l]=1ll*y[l]*qpow(a[0],mod-2)%mod,void();
int mid=(l+r)>>1;
dfs2(ls,l,mid,a%g[ls]),dfs2(rs,mid+1,r,a%g[rs]);
}
void dfs3(int p,int l,int r)
{
if(l==r) return f[p]={y[l]},void();
int mid=(l+r)>>1;
dfs3(ls,l,mid),dfs3(rs,mid+1,r);
f[p]=f[ls]*g[rs]+f[rs]*g[ls];
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d%d",&x[i],&y[i]);
dfs1(1,1,n);
dfs2(1,1,n,diff(g[1],n+1));
dfs3(1,1,n);
print(f[1],n);
return 0;
}
本文来自博客园,作者:peiwenjun,转载请注明原文链接:https://www.cnblogs.com/peiwenjun/p/19068538
浙公网安备 33010602011771号