斜率优化

斜率优化是非常常见的一种DP优化方式。其代码简短、维护方法众多。其主要思想是通过维护一个凸包来辅助转移,快速找到最优的决策点。
顾名思义,“斜率”肯定和线段有联系,因此斜率优化其中一种比较无脑的方法是写李超线段树来维护凸包。其不需要管单调性等类似的限制,直接无脑写即可。
李超线段树
注意斜率优化中的李超线段树的复杂度是 \(O(n\log n)\) 的。原因后面会说。

P3195 [HNOI2008] 玩具装箱

朴素的DP方程:

\[f_i=\min_{j\in[0,i-1]}(f_j+(i-(j+1)+pc_i-pc_j-L)^2) \]

其中 \(pc_k=\sum_{i=1}^kc_i\)
当然我们需要将这个平方内的式子变得简单一点,不然项太多了。我们让 \(L'=L+1\)\(pc_k'=pc_k+k\)。下面我们还是将 \(L',pc_k'\) 写为 \(L,pc_k\)
这样式子变为

\[f_i=\min_{j\in[0,i-1]}((pc_i-pc_j-L)^2) \]

然后暴力拆开有(这里将取 \(\min\) 省略)

\[\begin{aligned} f_i = \underbrace{-2pc_j}_{\text{k}} \underbrace{\times pc_i}_{\text{x}} + \underbrace{f_j + pc_j^2 + 2pc_j \times L}_{\text{b}} + \underbrace{pc_i^2 - 2pc_i \times L + L^2}_{\text{无关项}} \end{aligned} \]

为什么要化为如此奇怪的形式?因为用李超线段树维护的斜率优化的精髓就是将转移方程式化为

\[f_i=\text{只与j有关的项}\times \text{只与i有关的项}+\text{只与j有关的项}+\text{只与i有关或者与任意条件都无关的无关项} \]

的形式。
可以发现这个式子与一次函数式子高度类似。我们将其看作一个关于 \(j\) 的函数加上一些无关项的形式,而这个一次函数的组成部分就与上面的例子中标注的 \(k,x,b\) 一一对应。
这时我们发现对于一个想要求其 DP 值的 \(i\),如果我们要从一些已经求得其 DP 值的 \(j\) 中转移,我们只需要将对应的 \(x\) 值代入进去既可以找到最优的转移点。

当然一个一个的带入与暴力无疑。但是这个时候我们就可以将其理解为我们要求在某一个特定的 \(x\) 值上所有直线的最低点。(以这道题为例就是 \(x=pc_i\) 的时候这些关于 \(j\) 的直线的最小值加上无关项即为 \(f_i\)

那这个时候就可以用李超线段树维护了。求出一个位置的 DP 值后就可以直接将这个 DP 值所对应的直线插入进线段树来方便之后的点转移。
注意到这里的插入的所有所谓线段都是全局插入,并没有将区间分为 \(\log n\) 个小区间的过程,因此单次查询与插入都是 \(\log n\) 的。

因此总复杂度为 \(O(n\log n)\)

code

需要注意的是由于刚开始线段树内是没有线段的,因此我们需要将 \(seg_{0,b}\) 设为 \(inf\)。因为一开始线段树内的初值都为 0,这样就可以保证如果没有线段则不会访问到奇怪的地方去。

同时,这种写法会有一个问题:没有办法从 \(j=0\) 转移而来。我目前还没看到比较好的解决办法,因此暴力一点直接在计算每个点的 DP 值的时候特判一下从 0 转移优不优即可。(因为没有李超线段树的题解就因为这个和机房同学被硬控一个多小时)

同时的同时,由于 \(x\) 轴的范围很大,\(pc_k\) 可能达到 \(5\times 10^{11}\) 级别,因此我们需要将 \(pc_k\) 离散化一下。注意求函数值的时候需要用原始的值而不是序号。类似的是用序号还是初始值的问题需要特别注意一下。

最后有一个小问题。这个数据范围应该是会爆 \(long long\) 的吧,那为什么没有一个人写 \(\_\_int128\) 呢?

#include<bits/stdc++.h>
using namespace std;
#define int long long 
const int N=5e4+7;const __int128 inf=1e36;
int n,L,a[N],pc[N],f[N],tmp[N],loc[N],tr[N*4];
struct node{__int128 k,b;}seg[N];
__int128 get(int id,int x){return (__int128)seg[id].k*x+seg[id].b;}
bool cmp(int u,int v,int x){return get(u,x)<get(v,x);}
#define ls (u<<1)
#define rs (u<<1|1)
int query(int u,int l,int r,int x){
	if(l==r) return tr[u];
	int mid=(l+r)>>1,res=x<=mid?query(ls,l,mid,x):query(rs,mid+1,r,x);
	return cmp(tr[u],res,tmp[x])?tr[u]:res;
}
void update(int u,int l,int r,int x){
	int mid=(l+r)>>1;
	if(cmp(x,tr[u],tmp[mid])) swap(tr[u],x);if(l==r) return;
	if(cmp(x,tr[u],tmp[l])) update(ls,l,mid,x);
	if(cmp(x,tr[u],tmp[r])) update(rs,mid+1,r,x);
}
signed main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n>>L;L++;for(int i=1;i<=n;i++) cin>>pc[i],pc[i]+=pc[i-1]+1,tmp[i]=pc[i];
	sort(tmp+1,tmp+n+1);int len=unique(tmp+1,tmp+n+1)-(tmp+1);for(int i=1;i<=n;i++) loc[i]=lower_bound(tmp+1,tmp+len+1,pc[i])-tmp;
	seg[0].b=inf;
	for(int i=1;i<=n;i++){
		int j=query(1,1,len,loc[i]);
		f[i]=min((__int128)(get(j,pc[i])+(__int128)pc[i]*pc[i]-(__int128)2ll*pc[i]*L+L*L),(__int128)(pc[i]-L)*(pc[i]-L));
		seg[i]={(__int128)-2ll*pc[i],(__int128)f[i]+(__int128)pc[i]*pc[i]+(__int128)2ll*pc[i]*L};
		update(1,1,len,i);
	}
	cout<<f[n]<<'\n';return 0;
}

P5785 [SDOI2012] 任务安排

这是一个系列题目,可以去找一下都做了。
这种板子题基本都是比较套路的,化简出 DP 转移方程式就差不多了。比较难的可能会与其他东西结合。(或者像这道题一样初始的 \(n^2\) 的转移就比较难想?)
这里式子同样省略掉的取 \(\min\) 的符号。朴素的 DP 方程:

\[\begin{aligned} f_i = f_j + S \times (sc_n - sc_j) + st_i \times (sc_i - sc_j) \end{aligned} \]

转化一下:

\[\begin{aligned} f_i &= \underbrace{-sc_j }_{\text{只与j有关}}\times \underbrace{st_i}_{\text{只与i有关}} + \underbrace{f_j - S \times sc_j}_{\text{只与j有关的项}} + \underbrace{S \times sc_n+st_i \times sc_i}_{\text{无关项}} \end{aligned} \]

这里的 \(pc_k\)\(pt_k\) 就是正常的前缀了。然后就可以直接做了。

code

有点久远的代码,可能码风不太一样。

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=2e6+7,inf=1e18;
int n,s,pret[N],prec[N],idcnt=0,tmp[N],len,loc[N],f[N],tr[N];
struct node{int k,b,id;}seg[N];//cout<<"get "<<u.k<<' '<<u.b<<' '<<x<<'\n';
int get(node u,int x){return u.k*x+u.b;}
#define ls (u<<1)
#define rs (u<<1|1)
int query(int u,int l,int r,int x){
	if(l==r) return get(seg[tr[u]],tmp[x]);
	int mid=(l+r)>>1,t,res2;node res1=seg[tr[u]];
	res2=x<=mid?query(ls,l,mid,x):query(rs,mid+1,r,x);
	return min(get(res1,tmp[x]),res2);
}
void update(int u,int l,int r,int id){
	int mid=(l+r)>>1;
	if(get(seg[tr[u]],tmp[mid])>get(seg[id],tmp[mid])) swap(tr[u],id);
	if(l==r) {return;}
	if(get(seg[tr[u]],tmp[l])>get(seg[id],tmp[l])){update(ls,l,mid,id);}
	if(get(seg[tr[u]],tmp[r])>get(seg[id],tmp[r])){update(rs,mid+1,r,id);}
} 
signed main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n>>s;
	for(int i=1,t,c;i<=n;i++) cin>>t>>c,pret[i]=pret[i-1]+t,prec[i]=prec[i-1]+c,tmp[i]=pret[i];
	sort(tmp+1,tmp+n+1);len=unique(tmp+1,tmp+n+1)-(tmp+1);for(int i=1;i<=n;i++) loc[i]=lower_bound(tmp+1,tmp+len+1,pret[i])-tmp;
	seg[0]={0,inf,0};
	for(int i=1;i<=n;i++) {
		int k=query(1,1,len,loc[i]);
		f[i]=min(pret[i]*prec[i]+s*prec[n],k+pret[i]*prec[i]);
		seg[i]={-prec[i],f[i]+s*(prec[n]-prec[i]),i};
		update(1,1,len,i);
	}
	cout<<f[n]<<'\n';return 0;
}

P5504 [JSOI2011] 柠檬

这个东西主要是多了一个动态开点,实际上没什么难度。

省略掉取 \(\max\),有转移

\[f_i=f_{j-1}+s_j(t_i-t_j+1)^2 \]

其中 \(t_i\) 表示前 \(i\) 个位置有多少个数的值与 \(s_i\) 相同。而展开后的式子有点复杂,但是很直观,直接按上面的做即可。

但是发现有 \(10^4\) 种值,因此我们要开 \(10^4\) 棵李超线段树。直接开显然不太可能,因此我们要动态开点。用类似于主席树的写法直接写就可以了。

code

写这道题的时候有一段时间没写斜率优化了。一遍过有点感动到我了。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e6+7;
int n,s[N],t[N],buc[N],f[N],tmp[N],loc[N],idcnt=0,rt[N],ls[N],rs[N],tr[N];
struct node{int k,b;}seg[N];
int get(int u,int x){return seg[u].k*x+seg[u].b;}
bool cmp(int u,int v,int x){return get(u,x)>get(v,x);}
void update(int &u,int l,int r,int x){
	if(!u) u=++idcnt;int mid=(l+r)>>1;
	if(cmp(x,tr[u],tmp[mid])) swap(tr[u],x);
	if(cmp(x,tr[u],tmp[l])) update(ls[u],l,mid,x);
	if(cmp(x,tr[u],tmp[r])) update(rs[u],mid+1,r,x);
}
int query(int u,int l,int r,int x){
	if(l==r) return tr[u];
	int mid=(l+r)>>1,res=x<=mid?query(ls[u],l,mid,x):query(rs[u],mid+1,r,x);
	return cmp(tr[u],res,tmp[x])?tr[u]:res;
}
signed main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n;for(int i=1;i<=n;i++) cin>>s[i],t[i]=++buc[s[i]],tmp[i]=t[i];
	sort(tmp+1,tmp+n+1);int len=unique(tmp+1,tmp+n+1)-(tmp+1);for(int i=1;i<=n;i++) loc[i]=lower_bound(tmp+1,tmp+len+1,t[i])-tmp;
	seg[0]={0,(int)-1e18};
	for(int i=1;i<=n;i++){
		f[i]=f[i-1]+s[i];if(t[i]==1){seg[i]={-2*t[i]*s[i],-2*t[i]*s[i]+t[i]*t[i]*s[i]+f[i-1]};update(rt[s[i]],1,len,i);continue;}
		f[i]=max(f[i],get(query(rt[s[i]],1,len,loc[i]),t[i])+s[i]*t[i]*t[i]+2*s[i]*t[i]+s[i]);
		seg[i]={-2*t[i]*s[i],-2*t[i]*s[i]+t[i]*t[i]*s[i]+f[i-1]};update(rt[s[i]],1,len,i);
	}
	cout<<f[n];return 0;
}

P9020 [USACO23JAN] Mana Collection P

比较综合的一道题,代码难度相对不高思维链条其实也比较清晰,但是需要一些做题的套路不然想不到下一步如何去做。

发现我们并不知道起点,因此考虑从终点倒着来走。由于一个点的贡献只由其最后一次被走的时间决定,因此只要我们知道了我们倒着走每个点被经过的顺序,我们就可以唯一确定其贡献。
但是我们显然不能够去直接枚举这个顺序,否则就会是 \(\sum {n\choose i}i!\) 的巨额复杂度(还要枚举其子集的顺序)。考虑一下如何不显式地构造这个顺序去计算其贡献。

我们反着来考虑。我们去枚举一个子集,然后去计算其损失的贡献。每次询问的答案也就是 \(t\times \sum m_i-f\),其中 \(f\) 是最小的损失,其中 \(t\) 表示询问输入进来的时间。
于是我们直接状压 DP 设 \(f_{s,u}\) 表示已经走过的点集为 \(s\),此时在 \(u\) 结尾的最小损失。
于是有转移

\[f_{s,u}=\min_{v\in s} f_{s-u,v}+w_{s-u}\times dis_{v,u} \]

其中 \(s-u\) 表示除去 \(u\) 后的点集,\(w_{s-u}\) 表示除去 \(u\) 之后其他在点集中的点的 \(m\) 的和,\(dis_{v,u}\) 是用 floyd 算出来的最短路。
发现 \(t\times \sum m_i-f\) 在知道了点集的点以及终点的情况是一个一次函数的形式,于是对于每一个终点去建一棵李超线段树即可。

code

实现的时候将 \(m_i\) 写作 \(a_i\)

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=20,M=5e6+7,O=2e5+7,inf=1e18+7;
int mp[N][N],n,m,a[N],f[1<<18][N],tmp[N][O],len[N];
void floyd(){
	for(int k=1;k<=n;k++){
		mp[k][k]=0;
		for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++) mp[i][j]=min(mp[i][j],mp[i][k]+mp[k][j]);
	}
}
struct node{int k,b;}seg[M],que[O];
#define ls (u<<1)
#define rs (u<<1|1)
#define mid ((l+r)>>1)
int tr[N][M],segcnt=0;
int get(int id,int x,int s){return seg[id].k*tmp[s][x]+seg[id].b;}
int cmp(int u,int v,int x,int s){return get(u,x,s)>get(v,x,s);}
void update(int u,int l,int r,int x,int s){
	if(cmp(x,tr[s][u],mid,s))swap(tr[s][u],x);
	if(cmp(x,tr[s][u],l,s))update(ls,l,mid,x,s);
	if(cmp(x,tr[s][u],r,s))update(rs,mid+1,r,x,s);
}
int query(int u,int l,int r,int x,int s){
	if(l==r)return tr[s][u];
	int res=x<=mid?query(ls,l,mid,x,s):query(rs,mid+1,r,x,s);
	return cmp(res,tr[s][u],x,s)?res:tr[s][u];
}
signed main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n>>m;for(int i=1;i<=n;i++)cin>>a[i];
	memset(mp,0x3f3f,sizeof(mp));
	for(int i=1,u,v,w;i<=m;i++)cin>>u>>v>>w,mp[u][v]=min(mp[u][v],w);
	floyd();int S=(1<<n)-1;
	memset(f,0x3f3f,sizeof(f));for(int i=1;i<=n;i++)f[1<<(i-1)][i]=0;
	for(int s=1;s<=S;s++){
		int w=0;for(int i=0;i<n;i++)if(s&(1<<i))w+=a[i+1];
		for(int i=1;i<=n;i++){
			for(int j=1;j<=n;j++){
				if(i==j||(!(s&(1<<(j-1))))||(!(s&(1<<(i-1))))||(mp[j][i]>inf))continue;
				int w1=w-a[i];f[s][i]=min(f[s][i],f[s^(1<<(i-1))][j]+w1*mp[j][i]);
			}
		}
	}
	int Q;cin>>Q;
	for(int i=1;i<=Q;i++)cin>>que[i].k>>que[i].b,tmp[que[i].b][++len[que[i].b]]=que[i].k;
	for(int i=1;i<=n;i++)sort(tmp[i]+1,tmp[i]+len[i]+1),len[i]=unique(tmp[i]+1,tmp[i]+len[i]+1)-(tmp[i]+1);
	for(int i=1;i<=Q;i++)que[i].k=lower_bound(tmp[que[i].b]+1,tmp[que[i].b]+len[que[i].b]+1,que[i].k)-tmp[que[i].b];
	seg[0]={0ll,-inf};
	for(int s=1;s<=S;s++){
		int k=0;for(int i=0;i<n;i++)if(s&(1<<i))k+=a[i+1];
		for(int i=1;i<=n;i++)seg[++segcnt]={k,-f[s][i]},update(1,1,len[i],segcnt,i);
	}
	for(int i=1;i<=Q;i++) cout<<get(query(1,1,len[que[i].b],que[i].k,que[i].b),que[i].k,que[i].b)<<'\n';
	return 0;
}

单调队列实现

单调队列实现本身的复杂度是均摊 \(O(n)\)。其更优秀的复杂度是由更严格的限制得到的。下面我们来描述一下限制:
一般我们得到的式子长这样:

\[f_i=f_j+w(j,i) \]

我们现在变成这样:

\[f_j=f_i-w(j,i) \]

这个过程相当于李超线段树的反过程。李超线段树是正着直接考虑转移求值,而单调队列则将 \(f_j\) 当作纵坐标来做,而将 \(i\) 看作直线。
而其所谓“更加严格的限制”指代的是:新增的 \(i\) 的横坐标必定单增同时斜率也必定单增。(显然李超线段树不用考虑这么多)
其实现就是维护转移点。由于横坐标与斜率都单增,因此如果横坐标较小的点如果某个时刻不优那其就不可能优了,直接扔掉即可。

posted @ 2025-03-30 15:52  all_for_god  阅读(37)  评论(0)    收藏  举报