Loading

拉格朗日插值如何插出系数

好久之前在 cmd's blog 看到过,这次做题遇上了,就学了一下,其实挺 easy 的。

众所周知其实是我不会证 \(n\) 个点 \((x_i,y_i)\) 可以唯一确定一个次数为 \(n-1\) 的多项式,拉格朗日插值给出了一种构造:

\[f(z)=\sum_{i=1}^{n} \dfrac{y_i\prod_{j\not=i}(z-x_j)}{\prod_{j\not=i}(x_i-x_j)} \]

首先提出常数部分:

\[a_i=\dfrac{y_i}{\prod_{j\not=i}(x_i-x_j)} \]

可以 \(O(n^2)\) 搞出每一个 \(a_i\)

然后求一个多项式 \(g(z)=\prod_{i=1}^{n} (z-x_i)\)

可以发现

\[f(z)=\sum_{i=1}^{n}a_i\dfrac{g(z)}{z-x_i} \]

考虑如何快速搞出后面那个 \(\dfrac{g(z)}{z-x_i}\)

\(h(z)=\dfrac{g(z)}{z-c}\)

可以得到 \((z-c)h(z)=g(z)\)。两边提取系数得到

\[[z^{i-1}]h-c[z^i]h=[z^i]g\\ [z^i]h=\dfrac{[z^i]g-[z^{i-1}]h}{-c} \]

递推即可。

给出 模板题 通过代码:

#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return f?x:-x;
}
#define mod 998244353
inline int qpow(int n, int k) {
	int res = 1;
	for(; k; k >>= 1, n = 1ll * n * n % mod)
		if(k & 1) res = 1ll * n * res % mod;
	return res;
}
vector <int> lagrange(const vector <int> &x, const vector <int> &y) {
	assert(x.size() == y.size());
	int n = x.size();
	vector <int> a(n, 0), b(n + 1, 0), c(n + 1, 0), f(n, 0);
	for(int i = 0; i < n; ++i) {
		int A = 1;
		for(int j = 0; j < n; ++j) if(i != j)
			A = 1ll * A * (x[i] - x[j] + mod) % mod;
		a[i] = 1ll * qpow(A, mod - 2) * y[i] % mod;
	}
	b[0] = 1;
	for(int i = 0; i < n; ++i) {
		for(int j = i + 1; j >= 1; --j)
			b[j] = (1ll * b[j] * (mod - x[i]) + b[j - 1]) % mod;
		b[0] = 1ll * b[0] * (mod - x[i]) % mod;
	}
	for(int i = 0; i < n; ++i) {
		int iv = qpow(mod - x[i], mod - 2);
		if(!iv) {
			for(int j = 0; j < n; ++j) c[j] = b[j + 1];
		} else {
			c[0] = 1ll * b[0] * iv % mod;
			for(int j = 1; j <= n; ++j)
				c[j] = 1ll * (b[j] + mod - c[j - 1]) * iv % mod;
		}
		for(int j = 0; j < n; ++j)
			f[j] = (f[j] + 1ll * a[i] * c[j] % mod) % mod;
	}
	return f;
}
inline int calc(const vector <int> &f, int x) {
	int res = 0;
	for(int i = f.size() - 1; i >= 0; --i) res = (1ll * res * x + f[i]) % mod;
	return res;
}
signed main() {
	int n = read(), k = read();
	vector <int> x(n), y(n);
	for(int i = 0; i < n; ++i) x[i] = read(), y[i] = read();
	vector <int> f = lagrange(x, y);
	cout << calc(f, k) << '\n';
}

posted @ 2021-03-31 20:57  zzctommy  阅读(1122)  评论(0编辑  收藏  举报