【学习笔记】拉格朗日插值优化dp
30分钟时间熟练掌握一种dp优化技巧,稳赚不配呀!
前言:拉格朗日插值作为一种多项式快速求值技巧,在一类dp函数为多项式函数的情况下,往往能够起到很好的降低复杂度的作用,而大多数时候我们所需要做的就是证明dp函数为一个多项式。
- 拉格朗日插值
 
用于解决给出不同的 $n+1$ 个点,要求求出一个经过这些点的多项式函数的一类问题,设第 $i$ 个点为 $x_i,y_i$。
$$\sum \limits_{i=1}^{n+1} y_i \prod \limits_{j \neq i} \frac{x-x_j}{x_i-x_j}$$
simple proof:
考虑将该多项式拆成 $n+1$ 个多项式的和,其中第 $i$ 个多项式在 $x_i$ 处为 $1$ ,而在其它给定的点处为 $0$ 。那么第 $i$ 个多项式 $f_i(x)=\prod \limits_{j \neq i} \frac{x-x_j}{x_i-x_j}$ ,那么显然 $f(x)=\sum \limits_{i=1}^{n+1} y_i\times f_i(x)$ ,于是可以得出上文的式子。
模板: P4781 【模板】拉格朗日插值
我们一般用数学归纳法证明一个函数是多项式,而这需要用到以下三个浅显但重要的结论
- 
一个 $n$ 次的多项式的前缀和是一个 $n+1$ 次多项式。
 - 
一个 $n$ 次的多项式的差分是一个 $n-1$ 次多项式。
 - 
一个 $n$ 次的多项式与一个 $m$ 次多项式相乘可以得到一个 $n+m$ 次多项式。
 
当然就如同有人有能肉眼看出函数凸性的异能,如果你有能肉眼看出函数是个多项式的异能,也可以直接不用证明。
感应出函数的多项式之后其实可以不用证次数,直接放尽可能多的点进去插就好了。
例题一
给出一棵树,要求给每个点赋一个不超过 $D$ 的权值,使得每个结点的权值不能超过它父亲的权值(如果它有的话) $n\le 3000,D\le 10^9$。
考虑 $f_{i,d}$ 表示点 $i$ 的权值为 $d$ ,以 $i$ 为根的子树的方案数,容易列出dp柿子。
$$f_{u,d} = \prod \limits_{v\in son_u} \sum \limits_{k=1}^d f_{v,k}$$
我们现在来证明 $f_{i,d}$ 是关于 $d$ 的 $siz_i-1$ 次多项式。
显然对于叶子节点 $f_{i,d}=1$ 符合条件。
观察柿子,最右边的部分是 $f_{v,d}$ 的一个前缀和,因此是一个 $siz_v-1+1=siz_v$ 次函数,之后将所有 $siz_v$ 次函数乘在一起,得到的 $f_{u,d}$ 就是一个 $siz_u-1$ 次函数。
最后的答案 $ans=\sum \limits_{i=0}^D f_{1,i}$ 就是一个 $n$ 次函数,因此我们对于 $d\in[0,n]$ 求出对应的dp值,之后再用这 $n+1$ 个点拉格朗日插值求出函数在 $D$ 处的值即可。
#include<bits/stdc++.h>
#define N 3010
#define mod 1000000007
using namespace std;
int read() {
	int res=0,f=1;char ch=getchar();
	while(!isdigit(ch)) f=ch=='-'?-1:1,ch=getchar();
	while(isdigit(ch)) res=res*10+ch-'0',ch=getchar();
	return f*res;
}
int n,d,a[N],f[N][N];
int cnt,head[N],to[N],nxt[N];
void insert(int u,int v) {
	cnt++;
	to[cnt]=v;
	nxt[cnt]=head[u];
	head[u]=cnt;
}
void dfs(int now) {
	for(int i=1; i<=n; i++) f[now][i]=1; 
	for(int i=head[now]; i; i=nxt[i])  {
		dfs(to[i]);
		for(int j=1; j<=n; j++) f[now][j]=1ll*f[now][j]*f[to[i]][j]%mod;
	}
	for(int i=1; i<=n; i++) f[now][i]=(f[now][i]+f[now][i-1])%mod;
}
int qpow(int a1,int a2) {
	int res=1;
	while(a2) {
		if(a2&1) res=1ll*res*a1%mod;
		a1=1ll*a1*a1%mod;
		a2>>=1;
	} return res;
}
int main() {
	n=read(),d=read();
	for(int i=2; i<=n; i++) {
		int fa=read();
		insert(fa,i); 
	} dfs(1);
//	printf("%d\n",f[1][2]);
	int ans=0;
	for(int i=0; i<=n; i++) {
		int s1=f[1][i],s2=1;
		for(int j=0; j<=n; j++) if(i!=j) s1=1ll*s1*(d-j+mod)%mod,s2=1ll*s2*(i-j+mod)%mod;
		ans=(ans+1ll*s1%mod*qpow(s2,mod-2)%mod)%mod;
	}
	printf("%d\n",ans); 
} 
第一道例题十分简单,帮助我们快速体验了证明dp函数是多项式函数的步骤。
同样简单的推导多项式性质的练习题:
显然如果题目的dp式子中有前缀和,差分,累乘,且只在邻近层中转移。那么有较大概率是一道拉插优化题。事实上,这题本身也是拉插优化的一个常见类型:带有一些偏序限制的计数问题。
下面我们将讲述两道与例题一类型相似的题,不过难度有较大的提升。
例题二
为序列上的每个位置 $i$ 赋一个为 $0$ 或在 $[a_i,b_i]$ 之间的权值,使得每个权值不为 $0$ 的点的权值都大于其前面所有点的权值。$n\le 500,1\le a_i \le b_i \le 10^9 $
虽然说用组合数做又短又快,但是拉插作为一种无脑的做法考场上应该更容易想到(?)。
首先普及一个小技巧,当拉格朗日插值所用到的点值是连续的时候,我们可以做到 $O(n)$ 插值。
假设 $n+1$ 个点的横坐标分别是 $0,1,2,...,n$ ,拉格朗日插值的式子变成这样:
$$\sum \limits_{i=0}^{n} \prod \limits_{j \neq i} y_i \frac{x-j}{i-j}$$
假设我们想要求得点值横坐标为 $k$ ,那么我们首先求出 $pre_i=\prod \limits_{j=0}^i (k-j),suf_i=\prod \limits_{j=i}^n (k-j)$ ,再预处理出阶乘 $fac_i$,原式变为:
$$\sum \limits_{i=1}^{n+1} (-1)^{n-i} y_i \frac{pre_{i-1}suf_{i+1}}{fac_i fac_{n-i}}$$
预处理与最后的计算都是 $O(n)$ 的。
线性插值&证多项式性质简单练习:
练习2: The Sum of the k-th Powers
回到正题,显然如果没有下头的区间限制,我们容易写出dp方程:设 $f_{i,j}$ 表示目前dp到了 $i$ ,当前最大值等于 $j$ 的方案数。显然
$$
f_{i,j}=F_{i-1,j}+\sum \limits_{k=0}^{j-1} f_{i-1,k}
$$
显然我们有 $f_{1,j}=1$ 这是一个零次多项式,而转移的方式是做前缀和,因此 $f_{i,j}$ 的次数是 $f_{i-1,j}$ 的次数 $+1$ ,即 $f_{i,j}$ 是个关于 $j$ 的 $i-1$ 次多项式。
预处理前缀和我们就能在 $O(n^2)$ 时间内解决这道题没有区间限制的版本。
现在我们考虑将区间改成左闭右开之后离散化,这样数轴就被我们划分成了 $O(n)$ 个段,而对于同一段内的数,一个点要么都能选,要么都不能选。那么在每一段中 $f_{i,j}$ 都是一个多项式,于是我们分段插值。
我们改设 $f_{i,j,k}$ 表示dp到了 $i$ ,最大值等于第 $j$ 段的第 $k$ 个数。设 $g_i$ 为第 $i$ 段的 $f$ 之和,$s_{i,j}$ 表示第 $i$ 段的前 $j$ 个 $f$ 之和,$ss_i$ 表示前 $i$ 段之和。为了方便转移,我们插值的时候插的是 $f$ 的前缀和 $s$ 。
时间复杂度 $O(n^3)$
#include<bits/stdc++.h>
#define N 2010 
#define mod 1000000007
using namespace std;
int read() {
	int res=0,f=1;char ch=getchar();
	while(!isdigit(ch)) f=ch=='-'?-1:1,ch=getchar();
	while(isdigit(ch)) res=(res<<3)+(res<<1)+ch-'0',ch=getchar();
	return f*res;
}
int qpow(int a1,int a2) {
	int res=1;
	while(a2) {
		if(a2&1) res=1ll*res*a1%mod;
		a1=1ll*a1*a1%mod,a2>>=1;
	}return res;
}
int n,l[N],r[N],f[N][N],mx,t[N],tn,ss[N],s[N][N],g[N],pre[N],suf[N],fac[N],inv[N];
int md(int a1) {return a1>=mod?a1-mod:a1;}
void add(int& a1,int a2) {a1=md(a1+a2);}
int Lag(int n,int k,int *y) {
	if(k<=n) return y[k];
	pre[0]=k,suf[n]=k-n;
	int res=0;
	for(int i=1; i<=n; i++) pre[i]=1ll*pre[i-1]*(k-i)%mod;
	for(int i=n-1; i; i--) suf[i]=1ll*suf[i+1]*(k-i)%mod;
	for(int i=0; i<=n; i++) 
		add(res,1ll*y[i]*inv[i]%mod*inv[n-i]%mod*(i>0?pre[i-1]:1)%mod*(i<n?suf[i+1]:1)%mod*((n-i)%2?mod-1:1)%mod);
	return res;
}
int main() {
	n=read(),fac[0]=inv[0]=1;
	for(int i=1; i<=n; i++) fac[i]=1ll*fac[i-1]*i%mod;
	inv[n]=qpow(fac[n],mod-2);
	for(int i=n-1; i; i--) inv[i]=1ll*inv[i+1]*(i+1)%mod;
	for(int i=1; i<=n; i++) 
		l[i]=read(),r[i]=read()+1,mx=max(mx,r[i]-1),
		t[++tn]=l[i],t[++tn]=r[i];
	sort(t+1,t+tn+1);
	tn=unique(t+1,t+tn+1)-t-1;
	for(int i=1; i<=n; i++)
		l[i]=lower_bound(t+1,t+tn+1,l[i])-t,
		r[i]=lower_bound(t+1,t+tn+1,r[i])-t;
	for(int i=0; i<tn; i++) ss[i]=1;
	for(int i=1; i<=n; i++) {
		for(int j=l[i]; j<r[i]; j++) {
			for(int k=0; k<=n; k++) add(f[j][k],md(ss[j-1]+(k>0?s[j][k-1]:0)));
			s[j][0]=f[j][0];
			for(int k=1; k<=n; k++) s[j][k]=md(s[j][k-1]+f[j][k]);
			g[j]=Lag(n,t[j+1]-t[j]-1,s[j]);
		} 
		for(int i=1; i<tn; i++) ss[i]=md(ss[i-1]+g[i]);
	}
	printf("%d",ss[tn-1]-1);
}
例题三
给出一棵树,求有多少方案在树上选出一条链,给链上每个点赋一个 $[l_i,r_i]$ 之间的权值,使得其最大值和最小值之差 $\le K$ ,此外再输出所有合法方案的权值和(一种方案的权值定义为选出的链的权值和)。
同样的,首先考虑没有区间的限制,我们枚举最小值 $L$ ,只要满足所有值都在 $[L,L+K]$ 之内就好了。
但是枚举最小值很麻烦,因此我们考虑差分。对于 $[L,R]$ 求出每个点权值在 $[L,R]$ 内的方案(实际是 $[L,R] \cap [l_i,r_i]$),之后减去每个点权值在 $[L+1,R]$ 之内的方案,就是最小值为 $L$ 的方案数了。
那么我们可以路径的LCA处统计路径,设四个数组 $mul_i,pmul_i,tot_i,ptot_i$ 分别表示以 $i$ 为LCA的路径条数,子树内以 $i$ 为一端的路径条数,以 $i$ 为LCA的路径的权值和,子树内以 $i$ 为一端的路径权值和,就可以简单dp了,复杂度为 $O(nV)$。
同样的,我们考虑将值域分成许多段,使得每一段内每个点的状态是固定的(都能选/都不能选)或是匀速变化的(+1/-1)。跟上一题不同的是,上一题中我们移动的是一个点,而这一次我们移动的是一个区间 $[L,L+K]$。因此我们将 $l_i,l_i-K,r_i,r_i-K$ 作为段的分界点,之后同样对每一段的前缀和插值即可。
分段插值时如果搞不清楚分多少段,那么在复杂度允许的范围下能分多少就分多少,保证每个段为多项式即可。
复杂度 $O(n^3)$
#include<bits/stdc++.h>
#define N 2010
#define mod 1000000007
#define int long long
int n,m=0,K,l[N],r[N];
using namespace std;
int read() {
	int res=0,f=1;char ch=getchar();
	while(!isdigit(ch)) f=ch=='-'?-1:1,ch=getchar();
	while(isdigit(ch)) res=res*10+ch-'0',ch=getchar();
	return f*res;
}
int cnt,head[N],to[N<<1],nxt[N<<1],L,R,a[N],b[N],c[N];
void insert(int u,int v) {
	cnt++;
	to[cnt]=v;
	nxt[cnt]=head[u];
	head[u]=cnt;
}
void add(int& a1,int a2) {a1=(a1+a2>=mod)?a1+a2-mod:a1+a2;}
int mul[N],pmul[N],tot[N],ptot[N],cnt1,cnt2;
int S(int x) {return 1ll*x*(x+1)>>1%mod;}
void dfs(int now,int fa) {
	int ll=max(L,l[now]),rr=min(R,r[now]);
	int len=ll>rr?0:rr-ll+1,sum=ll>rr?0:(S(rr)-S(ll-1)+mod)%mod;
	pmul[now]=mul[now]=len,ptot[now]=tot[now]=sum;
	for(int i=head[now]; i; i=nxt[i]) if(to[i]!=fa) {
		dfs(to[i],now);
		add(mul[now],1ll*pmul[now]*pmul[to[i]]%mod);
		add(tot[now],(1ll*ptot[to[i]]*pmul[now]%mod+1ll*pmul[to[i]]*ptot[now])%mod);
		add(pmul[now],1ll*pmul[to[i]]*len%mod);
		add(ptot[now],(1ll*ptot[to[i]]*len%mod+1ll*pmul[to[i]]*sum%mod)%mod);
	}
}
int pos[N];
void work(int f) {
	dfs(1,0);
	for(int i=1; i<=n; i++) cnt1+=mul[i]*f,cnt2+=tot[i]*f;
	cnt1=(cnt1+mod)%mod,cnt2=(cnt2+mod)%mod;
}
int qpow(int a1,int a2) {
	int res=1;
	while(a2) {
		if(a2&1) res=1ll*res*a1%mod;
		a1=1ll*a1*a1%mod;
		a2>>=1;
	} return res;
}
int Lag(int len,int x,int *a1,int *a2) {
	int res=0;
	for(int i=0; i<len; i++) {
		int s1=a2[i]%mod,s2=1;
		for(int j=0; j<len; j++) if(i!=j) s1=1ll*s1*(x-a1[j])%mod,s2=1ll*s2*(a1[i]-a1[j])%mod;
		add(res,1ll*s1*qpow(s2,mod-2)%mod);
	} return res;
}
signed main()  {
	n=read(),K=read();
	for(int i=1; i<=n; i++) {
		l[i]=read(),r[i]=read();
		pos[++m]=l[i],pos[++m]=max(l[i]-K,0ll),pos[0]=max(pos[0],r[i]+1);
		pos[++m]=r[i],pos[++m]=max(r[i]-K,0ll);
	}
	sort(pos,pos+m+1),m=unique(pos,pos+m+1)-pos-1;
	for(int i=1; i<n; i++) {
		int u=read(),v=read();
		insert(u,v);
		insert(v,u);
	}
	for(int i=0,j; i<m; i++) {
		L=pos[i],R=pos[i]+K;
		for(j=0; j<n+2; j++,R++) {
			if(pos[i]+j==pos[i+1]) break;
			work(1),++L,work(-1);
			a[j]=pos[i]+j,b[j]=cnt1,c[j]=cnt2;
		}
		if(pos[i]+j<pos[i+1]) cnt1=Lag(j,pos[i+1]-1,a,b),cnt2=Lag(j,pos[i+1]-1,a,c);
	}
	printf("%d\n%d",cnt1,cnt2);
}
稍有不同的分段插值:
进阶题
练习4: P5469 [NOI2019] 机器人
这题的主体其实已经是dp不是拉插了,可以采用拉插之外的方法维护多项式,拉插的话依旧是经典的分段插值。
到这里就结束了,完成上面的所有题目,那么之后你在考场上见到拉插优化dp的时候应该有大概率可以做出来。
$\color{#00ffff}\it\huge\{&}$
_$\color{#ff0000}\it\huge\{Thanks\ for\ reading {\color{#dddddd}\it\small\{}}}$
                    
                
                
            
        
浙公网安备 33010602011771号