拉格朗日插值
拉格朗日插值
首先,我们知道给出 \(n+1\) 个点 \((x_i,y_i)\) 可以唯一确定一个 \(n\) 次多项式。
问题:给出 \(n+1\) 个点,求出这个 \(n\) 次多项式在 \(k\) 处的取值,即 \(f(k)\)。
首先,我们可以列出 \((n+1)\) 个方程解出这个多项式的系数,但是这样是 \(O(n^3)\) 的。有没有更给力的?
实际上,我们有拉格朗日插值,可以在 \(O(n^2)\) 的复杂度内解决这个问题。
构造多项式:
\[g(x)=\sum\limits_{i=0}^n y_i \prod\limits_{j\ne i}\dfrac{k-x_i}{x_i-x_j}
\]
我们说明这个多项式其实就是 \(f(x)\)。
首先,\(g(x)\) 显然是一个 \(n\) 次多项式,我们只需说明 \(g(x)\) 与 \(f(x)\) 在 \(n\) 个位置处取值相同即可。
\[\forall x_k,i\ne k,k,\prod \limits_{j\ne i}\dfrac{x_k-x_i}{x_i-x_j}=(\prod \limits_{j\ne i \land j\ne k}\dfrac{x_k-x_i}{x_i-x_j})\times \dfrac{x_k-x_k}{x_k-x_j}=0
\]
\[\forall x_k,i= k,\prod \limits_{j\ne i}\dfrac{x_k-x_i}{x_i-x_j}=\prod \limits_{j\ne i}\dfrac{x_k-x_i}{x_k-x_j}=1
\]
因此
\[\begin{aligned} g(x_k)&=\sum\limits_{i=0}^n y_i \prod\limits_{j\ne i}\dfrac{k-x_i}{x_i-x_j} \\&= \sum\limits_{i=0}^{k-1} y_i\times 0 + \sum\limits_{i=k+1}^{n} y_i\times 0 +y_k\times 1 \\&=y_k\end{aligned}
\]
\(g(x)\) 与 \(f(x)\) 在 \(x_0,x_1\cdots x_n\) 处取值相等,因此,\(g(x)=f(x)\)。
这样,我们就可以在 \(O(n)\) 的时间复杂度内求出 \(f(k)\) 的值了。
注意:在写的时候不要求 \(n^2\) 次逆元,而是处理出来分母后再总共求 \(n\) 次,这样求逆元的总复杂度就是 \(O(n\log mod)\),不会成为瓶颈。
在 \(x\) 取值连续的时候,复杂度还可以优化。我们不妨设 \(x_i=i\),因为若 \(x_i\ne i\),由于满足取值连续,我们也可将图像整体平移。
\(g(x)=\sum\limits_{i=0}^n y_i \prod\limits_{j\ne i}\dfrac{k-x_i}{x_i-x_j}\)。
我们定义:
\[fac_i=\prod\limits_{j=1}^i j\\
pre_i=\prod\limits_{j=0}^i k-j\\suf_i=\prod\limits_{j=i}^n k-j
\]
那么,我们可以将 \(g(x)\) 改写为:
\[\begin{aligned} g(x)&=\sum\limits_{i=0}^n y_i \prod\limits_{j\ne i}\dfrac{k-i}{i-j}\\ &= \sum\limits_{i=0}^n y_i \dfrac{pre_{i-1}\times suf_{i+1}}{fac_{i-1}\times fac_{n-i}\times(-1)^{n-i}} \end{aligned}
\]
预处理 \(fac,pre,suf\) 即可 \(O(n)\) 计算。
贴上代码:
#include<bits/stdc++.h>
#define mod 998244353
using namespace std;
int n;
long long k,x[2010],y[2010];
long long fastpow(long long x,int y){
long long ans=1;
while(y){
if(y&1) ans=(ans*x)%mod;
x=(x*x)%mod;
y>>=1;
}
return ans;
}
long long F(long long k){
long long ans=0,a,b;
for(int i=0;i<=n;i++){
a=1,b=1;
for(int j=0;j<=n;++j){
if(i==j) continue;
a=a*(k-x[j])%mod;
b=b*(x[i]-x[j])%mod;
}
a=y[i]*a%mod*fastpow(b,mod-2);
ans=(ans+a)%mod;
}
return (ans+mod)%mod;
}
int main()
{
scanf("%d%lld",&n,&k);--n;
for(int i=0;i<=n;i++) scanf("%lld%lld",&x[i],&y[i]);
printf("%lld",F(k));
return 0;
}