拉格朗日插值

这个东西应该在很久之前就要学的结果被鸽到了现在。

我是鸽德

拉格朗日插值

拉格朗日插值解决的是一类给定多项式的点值表示让你求另一个点的函数值的问题。

先来思考这个引子:给定 \(n\) 个点对 \((x_i,y_i)\)\(k\),保证 \(\forall i\not=j,x_i\not=x_j\),求 \(f(k)\)

有一个显然的线性代数写法,建立一个 \(n\times n\) 的矩阵,然后解方程即可。

时间复杂度 \(\mathcal O(n^3)\)。这个时间复杂度是比较差的。

接下来进入正题。

我们直接给出拉格朗日多项式的式子

\[f(x)=\sum_{i=1}^{n}y_i\prod_{j\not=i}\frac{x-x_j}{x_i-x_j} \]

这里只给出直观解释:考虑一个点 \((x_p,y_p)\),计算 \(f(x_p)\),当 \(i\not=p\) 时,\(j\) 可以等于 \(p\),于是出现了 \(x-x_p\),答案为 \(0\);当 \(i=p\) 时,分子和分母相等,因此有 \(f(x_p)=y_p\)

据此,我们只要能找到一个多项式的点值表示,我们便可以求出这个多项式上任意一点的值。

代入计算,时间复杂度 \(\mathcal O(n^2)\)

接下来是一个经典应用:求 \(\sum\limits_{i=1}^{n}i^k\),对 \(10^9+7\) 取模。

如果要用拉格朗日插值,就要明确我们要求的东西是多项式,问题转化为如何证明 \(f(x)=\sum\limits_{i=1}^{x}i^k\) 是一个项数较少的多项式。

\(\Delta f(x)\)\(f(x)\) 的差分数组,\(\Delta^2f(x)\)\(f(x)\) 的二阶差分数组,即 \(\Delta^2f(x) = \Delta(\Delta f(x))\)

如果我们令 \(\Delta^0f(x)=f(x)\),我们可以递归地得到如下定义 \(p\in \mathbb N^*,\Delta^pf(x)=\Delta(\Delta^{p-1}f(x))\),称其为 \(f(x)\)\(p\) 阶差分。

如果 \(\Delta^pf(x)\) 是一个常数序列,则 \(f(x)\)\(p\) 阶等差数列。

引理:如果数列 \(\{a_n\}\) 是一个 \(p\) 阶等差数列,那么 \(a_n\) 是关于 \(n\) 的一个 \(p\) 次多项式,即 \(a_n=k_pn^p+k_{p-1}n^{p-1}+\dots+k_0\)

证明:假设 \(a\) 数列的通项是一个关于 \(n\) 的一个 \(p\) 次多项式,那么 \(a_x = \sum\limits_{i=0}^{p}k_ix^i\),据此可得到

\[\begin{align*} \Delta a(x) & = a(x+1)-a(x) \\ ~ & = \sum_{i=0}^{p}k_i(x+1)^i - \sum_{i=0}^{p}k_ix^i \end{align*} \]

注意到两项的 最高次项系数是相等的,相减后消去,因此,\(\Delta a(x)\) 的次数是 \(a(x)\) 的次数减一。\(p\) 次差分后,得到 \(0\) 次式,即常数序列。

证毕。

本题,可以定义数列 \(\{a_n\}\)

\[\sum_{i=1}^{1}i^k,\sum_{i=1}^{2}i^k,\sum_{i=1}^{3}i^k,\dots,\sum_{i=1}^{n}i^k \]

两两做差,发现序列变成了 \(1^k,2^k,\dots,n^k\)(这里在第一项前补了 \(0\))。

通项为 \(a_x=x^k\),是 \(k\) 次式,根据引理,\(\{a_n\}\) 应该是一个 \(k+1\) 次式。

朴素拉格朗日插值 \(\mathcal O(k^2)\),是否有更优的做法?

注意到我们可以代入的点可以自选,所以我们可以寻找一批有特殊性质的点。我们尝试用 \(1\)\(k+2\) 代入。

\(x_i=i,y_i=\sum\limits_{j=1}^{i}j^k\)

\[f(n)=\sum_{i=1}^{k+2}y_i\prod_{j\not=i}\frac{n-x_j}{x_i-x_j} \]

注意到,分母是 \(\prod_{j=1}^{i-1}\prod_{j=i-k-2}^{-1}\),可以预处理,分子是 \(\prod_{j=n-i-1}^{n-1}\prod_{j=n-k-2}^{n-i+1}\),可以预处理出前缀后缀乘积。

预处理 \(\mathcal O(k)\)

如果你认为上面这道题很简单,那就快来做一下例题吧!

[省选联考 2022] 填树

绝了,没看sol脑子一片空白。

感觉在赛场上10pts遗憾离场,只会一个不知道时间复杂度多少的暴力。

对于区间题,有一个想法是枚举赋值区间,但区间交容易算重,因此我们加入一定限制条件:至少有一个节点取最小值。

不加入这个限制非常好做,每个点的范围都被固定了,随便乘一下,加入限制后只需将求出的答案减 \([l+1,l+k]\) 的答案就行,这样似乎可以 dp,在 \([l,r]\) 范围内,设 \(f(x)\) 表示满足条件的树的个数,\(g(x)\) 表示从根节点到 \(x\) 点满足条件的路径权值之和,可以写出如下式子:

\[f_u=f_{fa}\cdot len_u \\ \]

\[g_u = f_u \cdot g_{fa} + \sum_{i=l}^{r}i\cdot f_{fa} \]

其中,\(len_u\) 表示 \(u\) 这个点可选的值的个数。考虑 \(u\) 点的贡献分为两部分:

  • 从父亲部分继承的贡献,这一部分是 \(f_{u}\cdot g_{fa}\)
  • \(u\) 点做出的贡献,每个值能做出 \(f_{fa}\) 的贡献。

枚举取值范围 \(\mathcal O(w)\),dfs 一次 \(\mathcal O(n)\),但由于需要枚举根节点,还需要 \(\mathcal O(n)\)。当然,你聪明的发现了这只是个在树上找路径的 dp,所以你消去了一个 \(n\),但是 \(w \le 10^9\) 断绝了一切念想。

优化做法 这里是关于 $\mathcal O(nw)$ 做法的介绍。

降低复杂度的关键是不去枚举根节点,这意味着我们需要用某种方式快速合并子树信息。

在枚举节点 \(u\) 时,合法的路径会有一段在祖先上,一段在子树内。dp 时要考虑合并的顺序。

假设 \(f'_x\) 表示 \(x\) 已经遍历的子树内的路径数,\(g'_x\) 表示 \(x\) 已经遍历的子树内的权值和。

可以改写 dp 式子:

\[f_u=f'_v \cdot f'_u \]

\[g_u=g'_v \cdot f'_u+f'_v \cdot g'_u \]

其中,\(f'_u\) 初始化为 \(len_u\)\(g'_u\) 初始化 \((l+r)(r-l+1)/2\)

合并子树信息

\[f'_u=f'_u + f'_v \cdot len_u \]

\[g'_u=g'_u + g'_v \cdot len_u + f'_v \cdot \big((l+r)(r-l+1)/2\big) \]

注意到应该先统计答案再合并子树,不然会算重。


我们对每一个区间进行 dp 转移的时候,我们发现区间的交和 \(l\) 是有关的。\(len_u\) 转换成了一个和 \(l\) 有关的一次式子。因此 \(f(u),g(u)\) 就是关于 \(l\) 的多项式。

我们要求的是前缀和,即 \(f(1,k + 1) + f(2, k + 2) + \ldots + f(mx - K,mx),mx=\max r_i\),同时减去容斥部分即可。

const int P = 1e9 + 7, N = 205, M = 1000;
int n, k, l[N], r[N], lsh[M], tot, L, R, x[M], Z[M], C[M];
vector<int> e[N];
ll ans1, ans2, f[N], g[N], fc[N], gc[N]; // f, g, f', g'

il void dfs(int u, int fa) {
	int x = max(l[u], L), y = min(r[u], R);
	if (x > y) x = 1, y = 0;
	int len = y - x + 1, sum = (1ll * (x + y) * (y - x + 1) / 2) % P;
	f[u] = fc[u] = len, g[u] = gc[u] = sum;
	for (int v : e[u]) {
		if (v == fa) continue;
		dfs(v, u);
		f[u] += fc[u] * fc[v] % P;
		g[u] += (fc[u] * gc[v] % P + fc[v] * gc[u] % P) % P;
		fc[u] = (fc[u] + fc[v] * len % P) % P;
		gc[u] = (gc[u] + gc[v] * len % P + fc[v] * sum % P) % P;
	}
}

il void calc(int opt) {
	dfs(1, 0);
	for (int i = 1; i <= n; ++i) ans1 += f[i] * opt, ans2 += g[i] * opt;
	ans1 = (ans1 % P + P) % P, ans2 = (ans2 % P + P) % P;
}

il int qpow(int x, int y) {
	int ret = 1;
	for (; y; y >>= 1, x = 1ll * x * x % P) if (y & 1) ret = 1ll * ret * x % P;
	return ret;
}

il int lagrange(int p, int x, int* X, int* Y) { // p 次多项式
	int ret = 0;
	for (int i = 0; i < p; ++i) {
		int fz = 1, fm = 1;
		for (int j = 0; j < p; ++j) 
			if (i != j) fz = 1ll * fz * (x - X[j]) % P, fm = 1ll * fm * (X[i] - X[j]) % P;
		ret = (ret + 1ll * fz * qpow(fm, P - 2) % P * Y[i] % P) % P;
	}
	return ret;
}

int main() {
	// freopen("tree.in", "r", stdin);
	// freopen("tree.out", "w", stdout);
	read(n), read(k);
	for (int i = 1; i <= n; ++i) { 
		read(l[i]), read(r[i]);
		lsh[++tot] = l[i], lsh[++tot] = max(0, l[i] - k);
		lsh[++tot] = r[i], lsh[++tot] = max(0, r[i] - k);
		lsh[0] = max(lsh[0], r[i] + 1);
	}
	sort(lsh, lsh + tot + 1); tot = unique(lsh, lsh + tot + 1) - lsh;
	for (int i = 1; i < n; ++i) {
		int u = read(), v = read();
		// assert(u <= n && v <= n);
		e[u].eb(v), e[v].eb(u);
	}
	for (int i = 0, j; i < tot; ++i) {
		L = lsh[i], R = lsh[i] + k;
		for (j = 0; j < n + 2; ++j, ++R) {
			if (lsh[i] + j == lsh[i + 1]) break;
			calc(1); ++L; calc(-1);
			x[j] = lsh[i] + j, Z[j] = ans1, C[j] = ans2;
		}
		if (lsh[i] + j < lsh[i + 1]){
			ans1 = lagrange(j, lsh[i + 1] - 1, x, Z);
			ans2 = lagrange(j, lsh[i + 1] - 1, x, C);
		}
	}
	write(ans1), write(ans2);
	return 0;
}
posted @ 2023-03-28 10:46  MisterRabbit  阅读(64)  评论(0)    收藏  举报