拉格朗日插值

引入

\(\;\)
假设现在我们得到了一个\(n\)次多项式\(f(x)=a_0+a_1x+a_2x^2+\cdots+a_nx^n\)\(n+1\)个条件形如\((x_i,y_i)\),表示当\(x=x_i\)时,多项式的值为\(y_i\)
求:\(a_0,a_1,\cdots,a_n\)
对于这类问题,我们当然可以根据\((x_i,y_i)\)写出\(n+1\)\(n+1\)次方程组成的线性方程组,然后再\(O(n^3)\)的时间内用高斯消元求解
但对于\(n\leq 2000\)这样的范围,高斯消元就不适用了。
接下来介绍的拉格朗日插值法,时间复杂度是\(O(n^2)\)
\(\;\)

拉格朗日插值

\(\;\)
我们考虑构造出\(n+1\)个多项式,满足第\(i\)个多项式只有当自变量取值为\(x_i\)时,其值为1,否则为0
那么对于第\(i\)个多项式,其形式即为:

\[fi(k)=\prod_{i\neq j} \frac{k-x_j}{x_i-x_j} \]

显然,上式是满足条件的。
那么对于原多项式,显然:

\[f(k)=\sum_{i=0}^n y_i fi(k) \]

如果题目只是要求这个多项式在给定\(k\)下的函数值,显然可以\(O(n^2)\)来解决。
但如果要求每一项系数,仍然是\(O(n^3)\)
\(\;\)

特殊情况

\(\;\)
若给定的\(x_i\)是连续的数,即:\(x_i=i\),我们来看这个东西有什么更好的性质
\(fi(k)\)可以变为\(\prod_{i\neq j} \frac{k-j}{i-j}\)
我们把整个柿子抄一遍

\[f(k)=\sum_{i=0}^n y_i \prod_{i\neq j} \frac{k-j}{i-j} \]

\(h(i)=\prod_{j=0}^i (k-j),r(i)=\prod_{j=i}^n (k-j), fac(i)=i!\),那么:

\[f(k)=\sum_{i=0}^n y_i \frac{h(i-1)r(i+1)}{fac(i)fac(n-i)(-1)^{n-i}} \]

那么我们预处理好\(h,r,fac\)\(f(k)\)就可以\(O(n)\)的算出来了
但是若要求表达式仍是O(n^3)的
\(\;\)

优化

\(\;\)
其实也不算是优化,是为了解决另一种更繁琐的情况,若有时候要减少一个或加入一个插值点,即:\((x_i,y_i)\)
按原来的式子还必须重新算一遍,如何优化呢?
观察上面的式子:
\(f(k)=\sum_{i=0}^n y_i \prod_{i\neq j} \frac{k-x_j}{x_i-x_j}\)
我们发现\(k-x_j\)这里是与\(i\)无关的,提到前面。
\(g(k)=\prod_{i=0}^n (k-x_i)\)
于是原式就变成了:
\(f(k)=g(k) \sum_{i=0}^n \frac{y_i}{k-x_i} \prod_{i\neq j} \frac{1}{x_i-x_j}\)
\(t(i)= \prod_{i\neq j} \frac{1}{x_i-x_j}\)
\(f(k)=g(k) \sum_{i=0}^n \frac{y_it_i}{k-x_i}\)
我们发现,\(t(1),t(2),\cdots,t(n)\)是可以\(O(n^2)\)预处理的,其余用\(O(n)\)时间即可解决
那么如果我们要加入一个插值点\((x_{n+1},y_{n+1})\),显然只需要把所有的\(t(i)\)除以\(x_i-x_{n+1}\)
这样修改的复杂度是\(O(n)\)的,然后我们再用\(O(n)\)的时间求值即可

Code

\(\;\)
\(f(k)\)的值。
代码用的是那个支持加入插值点方法(其实第一种也可以做)

#include <bits/stdc++.h>

const int N = 2010, mod = 998244353;
int n, k, g = 1, t[N], x[N], y[N];
int power(int a, int b) {
	int ans = 1;
	while(b) {
		if(b & 1) ans = 1ll * ans * a % mod;
		a = 1ll * a * a % mod;
		b >>= 1; 
	}
	return ans;
} 
int main() {
	scanf("%d%d", &n, &k); 
	for(int i=0;i<n;i++) scanf("%d%d", &x[i], &y[i]);
	for(int i=0;i<n;i++) g = 1ll * g * (k - x[i] + mod) % mod;
	for(int i=0;i<n;i++) {
		int now = 1;
		for(int j=0;j<n;j++) {
			if(j == i) continue;
			now = 1ll * now * (x[i] - x[j] + mod) % mod;
		}
		t[i] = power(now, mod - 2);
	}
	int ans = 0;
	for(int i=0;i<n;i++) {
		int tmp = 1ll * y[i] * t[i] % mod * power(k - x[i] + mod, mod - 2) % mod;
		ans = (ans + tmp) % mod;
	}
	printf("%d", 1ll * ans * g % mod);
	return 0;
}
posted @ 2021-02-24 20:47  czytysnow  阅读(22)  评论(0编辑  收藏