[浅谈] 拉格朗日插值

Introduce

给定 \(n\) 个点,那么可以确定一个不超过 \(n-1\) 项的多项式函数值。我们可以使用高斯消元,但是 \(O(n^3)\) 的时间复杂度和精度误差难以接受。

Principle

我们考虑构造函数 \(fi\) ,满足其在 \(x=x_i\) 时函数值为 \(1\) ,在 \(x=x_j(j\neq i,j\in[1,n])\)\(0\),这很好构造: $fi(x)=\prod_{j\neq i}\frac{x-x_j}{x_i-x_j} $ 。

然后再构造 \(f(x)=\sum_{i=1}^{n}y_ifi(x)\) ,那么这个函数就经过这 \(n\) 个点。

\(f(x)=\sum_{i=1}^{n}y_i\prod_{j\neq i}\frac{x-x_j}{x_i-x_j}\)

\(\color{purple}\text{P4781 【模板】拉格朗日插值}\)

给定 \(n\) 点确定多项式,求 \(f(k)\) ,直接代入上面的公式。

点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2e3+110,mod=998244353;
int read(){
	int x=0,f=1;char c=getchar();
	while(c>'9' || c<'0'){if(c=='-')f=-1;c=getchar();}
	while(c>='0' && c<='9'){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
	return x*f;
}
int n,k,x[N],y[N],ans;
int ksm(int x,int y){
	int sum=1;
	while(y){
		if(y&1)sum=(sum*x)%mod;
		y>>=1;
		x=(x*x)%mod;
	}
	return sum%mod;
}
signed main(){
	n=read(),k=read();
	for(int i=1;i<=n;i++)x[i]=read(),y[i]=read();
	for(int i=1;i<=n;i++){
		int tmp=y[i];
		for(int j=1;j<=n;j++)if(j!=i)tmp=tmp%mod*(k-x[j])%mod*ksm(x[i]-x[j],mod-2)%mod;
		ans=(ans+tmp)%mod;
	}
	printf("%lld\n",(ans+mod)%mod);
	return 0;
}

\(x\) 连续时优化

如果 \(x_i=i\)
\(f(x)=\sum_{i=1}^{n}y_i\prod_{j\neq i}\frac{x-x_j}{x_i-x_j}=\sum_{i=1}^{n}y_i\prod_{j\neq i}\frac{x-j}{i-j}=\sum_{i=1}^{n}y_i\frac{pre[i-1]suf[i+1]}{fac[i-1]fac[n-i](-1)^{n-i}}\)

此时 \(pre[i]=\prod_{j=1}^ix-j\)
\(suf[i]=\prod_{j=i}^nx-j\)

时间复杂度 \(O(n)\)

可能有锅
#include<bits/stdc++.h>
#define int long long

using namespace std;
const int N=2e6+110,mod=998244353;
int read(){
	int x=0,f=1;char c=getchar();
	while(c>'9' || c<'0'){if(c=='-')f=-1;c=getchar();}
	while(c>='0' && c<='9'){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
	return x*f;
}
int n,k,fac[N],pre[N],suf[N],y[N],ans;
int ksm(int x,int y){
	int sum=1;
	while(y){
		if(y&1)sum=(sum*x)%mod;
		y>>=1;
		x=(x*x)%mod;
	}
	return sum%mod;
}
void pre_opt(){
	fac[0]=pre[0]=suf[n+1]=1;
	for(int i=1;i<=n;i++)fac[i]=fac[i-1]*i%mod;
	for(int i=1;i<=n;i++)pre[i]=(pre[i-1]*(k-i))%mod;//,cout<<pre[i]<<" ";cout<<endl;
	for(int i=n;i>=1;i--)suf[i]=(suf[i+1]*(k-i))%mod;//,cout<<suf[i]<<" ";cout<<endl;
	return;
}
signed main(){
	n=read(),k=read()%mod;
	for(int i=1;i<=n;i++)y[i]=read()%mod;
	pre_opt();
	for(int i=1;i<=n;i++){
		int tmp=y[i];
		tmp=tmp*pre[i-1]%mod*suf[i+1]%mod;
		tmp=tmp*ksm(fac[i-1]*fac[n-i]%mod*ksm(-1,n-i)%mod,mod-2)%mod;
		ans=(ans+tmp)%mod;
	}
	printf("%lld\n",(ans+mod)%mod);
	return 0;
}

\(\color{purple}\text{The Sum of the k-th Powers}\)

\(\sum_{i=1}^ni^k\)\(n\leq 10^9,k\leq 10^6\)
把这个式子展开其实是个 \(k+1\) 项式。那么我们把它看成一个函数 \(f(n)\) ,带入 \(k+2\) 个点即可。选择连续的点可以做到 \(O(n \log 10^9)\) 。(逆元带来的 \(\log\) )。

最后提醒一句:分数取模不要漏模,多模,减法时注意 \(+mod\)

点击查看代码
#include<bits/stdc++.h>
#define int long long

using namespace std;
const int N=2e6+110,mod=1e9+7;
int read(){
	int x=0,f=1;char c=getchar();
	while(c>'9' || c<'0'){if(c=='-')f=-1;c=getchar();}
	while(c>='0' && c<='9'){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
	return x*f;
}
int n,k,fac[N],pre[N],suf[N],y,ans;
int ksm(int x,int y){
	int sum=1;
	while(y){
		if(y&1)sum=(sum*x)%mod;
		y>>=1;
		x=(x*x)%mod;
	}
	return sum%mod;
}
void pre_opt(){
	fac[0]=pre[0]=suf[k+3]=1;
	for(int i=1;i<=k+2;i++)fac[i]=fac[i-1]*i%mod;
	for(int i=1;i<=k+2;i++)pre[i]=(pre[i-1]*(n-i))%mod;//,cout<<pre[i]<<" ";cout<<endl;
	for(int i=k+2;i>=1;i--)suf[i]=(suf[i+1]*(n-i))%mod;//,cout<<suf[i]<<" ";cout<<endl;
	return;
}
signed main(){
	n=read(),k=read();
	pre_opt();
	for(int i=1;i<=k+2;i++){
		y=(y+ksm(i,k))%mod;int tmp=y;
		int a=y*pre[i-1]%mod*suf[i+1]%mod;
		int b=fac[i-1]*fac[m+3-i]%mod*ksm(-1,m+3-i)%mod;
		ans=(ans+a*ksm(b,mod-2)%mod+mod)%mod;
	}
	printf("%lld\n",(ans+mod)%mod);
	return 0;
}
posted @ 2023-06-07 16:00  FJOI  阅读(24)  评论(0)    收藏  举报