拉格朗日插值求系数
更新(2024.9.8):更新了格式。
设有 \(n\) 个点,坐标为 \((x_i,y_i)\), 现在要求解它们所够成的 \(n-1\) 次多项式 \(F(x)\) 的系数。
先回顾一下一般拉格朗日插值:
定义
\[f_i(x)=\begin{cases}1,(x=x_i)\\0,(x=x_j,j\neq i)\end{cases}\\
y_i=F(x_i)
\]
\(F(x)\) 必须满足代入任意一个 \(x_i\), 得到一个对应的 \(y_i\)。
因此
\[F(x)=\sum_{i=1}^{n}y_i\cdot f_i(x)
\]
可以通过构造得
\[f_i(x)=\prod_{j=1,j\neq i}^{n}\frac{x-x_j}{x_i-x_j}
\]
那么
\[F(x)=\sum_{i=1}^{n}y_i\cdot f_i(x)=\sum_{i=1}^{n}y_i\cdot \prod_{j=1,j\neq i}^{n}\frac{x-x_j}{x_i-x_j}
\]
现在我们得到
\[F(x)=\sum_{i=1}^{n}y_i\cdot \prod_{j=1,j\neq i}^{n}\frac{x-x_j}{x_i-x_j}
\]
考虑如何得到 \(F(x)\) 的系数,可以先 \(\mathcal{O}(n^2)\), 求得
\[G(x)=\prod_{j=1}^{n}{x-x_j}
\]
这个多项式的所有系数。
但是我们发现这并不满足 \(j\neq i\) 这个条件,因此要想办法对每个 \(i\) 除去 \(x-x_i\),这就要用到多项式除法。
又因为 \(x\) 的系数为 \(1\),因此单次除法可以做到 \(\mathcal{O}(n)\)。
这样,我们就可以对每个 \(i\),每次 \(\mathcal{O}(n)\) 得到分子,而分母的 \(x_i-x_j\) 只与 \(i\) 有关,因此需要每次重新算。
然后,对每个 \(i\),我们将 \(\mathcal{O}(n)\) 得到的分子乘上分母的逆元,再乘上 \(y_i\),就得到了 \(F(x)\) 的一部分(它也是个多项式)。
最后,再将所有 \(i\) 得到的系数值对应相加就得到了 \(F(x)\)。
总时间复杂度为 \(\mathcal{O}(n^2)\)。
代码:
#include<bits/stdc++.h>
using namespace std;
const int N=2e3+5,MOD=998244353;
int n,X[N],Y[N],fz1[N],fz2[N],tmp[N],xs[N];
int ksm(int x,int y){
int res=1;
while(y){
if(y&1)res=1ll*res*x%MOD;
x=1ll*x*x%MOD;
y>>=1;
}
return res;
}
inline int inc(int x,int y){return (x+y>=MOD)?(x+y-MOD):(x+y);}
inline int dec(int x,int y){return (x-y<0)?(x-y+MOD):(x-y);}
void pmul(int *A,int deg,int xi){//系数从下标1开始,deg表示多项式的度数
for(int i=deg+1;i>=1;i--)
tmp[i]=A[i],A[i]=A[i-1];
for(int i=1;i<=deg+1;i++)
A[i]=inc(A[i],1ll*tmp[i]*xi%MOD);
}
void pdiv(int *A,int *res,int deg,int xi){
for(int i=1;i<=deg+1;i++)tmp[i]=A[i];
for(int i=deg;i>=1;i--)
res[i]=tmp[i+1],tmp[i]=dec(tmp[i],1ll*tmp[i+1]*xi%MOD);
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d%d",&X[i],&Y[i]);
fz1[1]=1;
for(int i=1;i<=n;i++)
pmul(fz1,i,dec(0,X[i]));
for(int i=1;i<=n;i++){
int fm=1;
for(int j=1;j<=n;j++)
if(i!=j)fm=1ll*fm*dec(X[i],X[j])%MOD;
pdiv(fz1,fz2,n,dec(0,X[i]));
fm=1ll*Y[i]*ksm(fm,MOD-2)%MOD;
for(int j=1;j<=n;j++)
xs[j]=inc(xs[j],1ll*fm*fz2[j]%MOD);
}
for(int i=1;i<=n;i++)
printf("%d ",xs[i]);
return 0;
}

浙公网安备 33010602011771号