斜率优化DP
状态转移方程形如
其中最重要的特征是 \(i,j\) 交乘项,min和max同理,下文以min为例进行说明,我们把方程整理为如下形式
注意为了方便理解与上文字母不同,也就是所有 只 与 \(i\) 有关的项提到 min 前,与 \(i,j\) 都 有关的项放在一起,所有 只 与 \(j\) 有关的项放在一起。
考虑对于 \(j,k\) 两个决策点,不妨设 \(x_j<x_k\)
如果 \(k\) 优于 \(j\) ,那么有
到这里不难发现,如果我们把所有决策点的 \((x,y)\) 扔到坐标系上,那么真正有效的决策点之间斜率是递增的,也就是说我们要维护一个决策点的下凸壳,而要求的值相当于一个截距,对于 \(i\) 的决策,相当于拿一条斜率为 \(-k_i\) 的直线去切这个凸壳,切到的点即最优决策点。
考虑维护细节,对于最朴素的情况,我们需要支持在凸壳上任意插入删除以及在凸壳上查找第一个大于某斜率的点。对于前者,考虑用一棵平衡树进行维护,对于后者,可以直接进行二分。
进一步的,当坐标单调时,相当于我们只在端点处插入新元素,如 \(x\) 递增时,直接维护一个单调队列即可。当斜率单调时,相当于决策点也是单调移动的,做一个类似双指针的东西,如 \(-k\) 递增时,决策点单调右移,不断弹出队首元素,维护队首作为最优决策点即可。
(事实上大部分题应该都有单调性)
对于max而言同理,决策点之间斜率是递减的,我们维护一个上凸壳即可。
另外一种理解角度,我们把方程整理成如下形式
\(f_i = w_i+\min k_jx_i+b_j\)
相当于我们把每一个决策看作一个一次函数,那么我们要求的就是在 \(x\) 处若干一次函数的最值,显然可以用李超树进行维护。
该方法的一大限制是只能维护整数坐标,对于实数坐标或者大整数坐标,需要考虑能否进行离散化之后维护。
不难发现法二相较于法一细节更少一些,我们无脑把一次函数扔到李超树上就完事了,而法一则需要考虑一下 corner case,比如当 \(x\) 坐标相等时,斜率要设成极大值/极小值。
个人偏好是有单调性的时候用法一维护凸壳,没有单调性就直接李超树,不想写平衡树维护凸壳。
P3195 [HNOI2008] 玩具装箱
应该是最经典的例题
设 \(f_i\) 表示装好前 \(i\) 个且最后一组以 \(i\) 结尾的最小代价,则有 \(O(n^2)\) 的朴素转移方程
记 \(sum_k = \sum_{i=1}^k c_i+1\),\(L=L+1\)
则有
于是和上文的形式是完全相同的,我们把 \((sum_j , sum_j^2 + 2sum_jL +f_j)\) 看成坐标扔到坐标系上维护下凸壳,用斜率为 \(2sum_i\) 的直线去切这个凸壳即可,注意到此时坐标和斜率均单调,可以直接拿单调队列进行维护。
代码:
#include<bits/stdc++.h>
#define MAXN 50005
#define LL long long
#define look_memory cerr<<abs(&M2-&M1)/1024.0/1024<<'\n'
#define look_time cerr<<(clock()-Time)*1.0/CLOCKS_PER_SEC<<'\n'
using namespace std;
inline int read(){
int x=0;
int f=1;
char c=getchar();
while(c<'0' || c>'9'){
if(c=='-') f=-1;
c=getchar();
}
while(c>='0' && c<='9'){
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x*f;
}
bool M1;
int n,L;
int c[MAXN];
LL sum[MAXN];
LL f[MAXN];
struct node{
LL x,y;
};
deque<node> q;
bool M2;
double get_k(node a,node b){
return 1.0*(a.y-b.y)/(a.x-b.x);
}
int main(){
// freopen("","r",stdin);
// freopen("","w",stdout);
int Time=clock();
n=read();L=read();
L++;
for(int i=1;i<=n;i++){
c[i]=read();
sum[i]=sum[i-1]+c[i]+1;
}
memset(f,0x3f,sizeof(f));
f[0]=0;
q.push_back((node){0,0});
for(int i=1;i<=n;i++){
LL k=2*sum[i];
while(q.size()>=2 && k>get_k(q[0],q[1])) q.pop_front();
LL x=q.front().x,y=q.front().y;
f[i]=-2*sum[i]*x+y+(sum[i]-L)*(sum[i]-L);
node now=(node){sum[i],f[i]+(2*L+sum[i])*sum[i]};
while(q.size()>=2 && get_k(q[q.size()-2],now)<get_k(q[q.size()-2],q.back())) q.pop_back();
q.push_back(now);
}
printf("%lld\n",f[n]);
look_memory;
look_time;
return 0;
}
P5785 [SDOI2012] 任务安排
考虑每批任务启动前有一个准备时间s,由于我们不希望再加一维记录此时执行了几批任务,所以直接做一个费用提前计算,也就是每次分组相当于对于后面所有任务都增加了一个 \(s*c\) 的代价。
设 \(f_i\) 表示考虑了前 \(i\) 个任务,且最后一组任务以 \(i\) 结尾的最小代价。
记 \(st_k = \sum_{i=1}^k t_i\),\(sum_k = \sum_{i=1}^k c_i\)
则有
以 \((sum_i,f_i)\) 为坐标维护下凸壳,用斜率为 \(st_i+s\) 的直线去切。
注意到此题时间可以为负数,也就是说斜率不单调,需要在凸壳上二分找到决策点,这也是我把此题作为例题的意义所在。
代码:
#include<bits/stdc++.h>
#define MAXN 300005
#define int long long
#define look_memory cerr<<abs(&M2-&M1)/1024.0/1024<<'\n'
#define look_time cerr<<(clock()-Time)*1.0/CLOCKS_PER_SEC<<'\n'
using namespace std;
inline int read(){
int x=0;
int f=1;
char c=getchar();
while(c<'0' || c>'9'){
if(c=='-') f=-1;
c=getchar();
}
while(c>='0' && c<='9'){
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x*f;
}
bool M1;
int n,s;
int t[MAXN],c[MAXN],st[MAXN],sum[MAXN];
int f[MAXN];
struct node{
int x,y;
};
deque<node> q;
double get_k(node aa,node bb){
return 1.0*(bb.y-aa.y)/(bb.x-aa.x);
}
node find(int k){
int l=0,r=q.size()-1,mid,res=q.size()-1;
while(l<=r){
mid=(l+r)>>1;
if(k>=get_k(q[mid],q[mid+1])){
l=mid+1;
}else{
r=mid-1;
res=mid;
}
}
return q[res];
}
bool M2;
signed main(){
// freopen("","r",stdin);
// freopen("","w",stdout);
int Time=clock();
n=read();s=read();
for(int i=1;i<=n;i++){
t[i]=read();
c[i]=read();
st[i]=st[i-1]+t[i];
sum[i]=sum[i-1]+c[i];
}
q.push_back((node){0,0});
for(int i=1;i<=n;i++){
node res=find(st[i]+s);
int x=res.x,y=res.y;
f[i]=st[i]*sum[i]+s*sum[n]-(st[i]+s)*x+y;
node tmp={sum[i],f[i]};
while(q.size()>=2 && get_k(q[q.size()-2],tmp)<=get_k(q[q.size()-2],q.back())) q.pop_back();
q.push_back(tmp);
}
int ans=f[n];
printf("%lld\n",ans);
look_memory;
look_time;
return 0;
}
P4056 [JSOI2009] 火星藏宝图
第一思路是设 \(f_i\) 表示走到第 \(i\) 个岛上的最大收益,有转移方程
注意到方程中有两个 \(i,j\) 交乘项,无法直接进行斜率优化,但是不难发现坐标的数据范围很小,我们可以考虑枚举一维坐标将其变为常量,然后再用斜率优化去做另一维,这样复杂度是 \(O(nm)\) 的。
但其实我们有复杂度更加优秀的做法,需要观察一些性质。
设 \(f_{i,j}\) 表示 走到坐标为 \((i,j)\) 的最大收益。
朴素转移方程是 \(O(m^4)\) 的,枚举当前坐标,枚举决策点坐标,然后进行转移。
考虑对于同一列上的两个转移点 \(j,k\) 。

1号路线的代价为 \((x_i-x_k)^2+(y_i-y_k)^2\)
2号路线的代价为 \((x_j-x_k)^2 + (x_i-x_j)^2+(y_i-y_j)^2\)
由于 \((x_j-x_k)+(x_i-x_j) = x_i-x_k\) 不难发现有2号路线优于1号路线。
即每一列上最优的转移点是行坐标最大的点,于是确定了决策点的列也就确定了行,复杂度优化为 \(O(m^3)\)。
我们记第 i 列的最优转移点为 \(pos_i,i\)。
则有转移方程
由于行是固定的,为了方便下文记在第 i 行 \(dis_j=(i-pos_j)^2\)。
以 \((k,f_{pos_k,k} - dis_k -k^2)\) 为坐标维护上凸壳,用斜率为 \(-2j\) 的直线去切,注意到坐标和斜率均单调。
复杂度优化为 \(O(m^2)\)
将此题作为例题之一的意义是,感受一下维护上凸壳的做法,以及本题的性质转化很有意思。
代码:
#include<bits/stdc++.h>
#define MAXN 200005
#define MAXM 1005
#define int long long
#define look_memory cerr<<abs(&M2-&M1)/1024.0/1024<<'\n'
#define look_time cerr<<(clock()-Time)*1.0/CLOCKS_PER_SEC<<'\n'
using namespace std;
inline int read(){
int x=0;
int f=1;
char c=getchar();
while(c<'0' || c>'9'){
if(c=='-') f=-1;
c=getchar();
}
while(c>='0' && c<='9'){
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x*f;
}
bool M1;
int n,m;
struct node{
int x,y;
};
int w[MAXM][MAXM];
deque<node> q;
int pos[MAXM],dis[MAXM],f[MAXM][MAXM];
const int inf=1e9;
const double eps=1e-6;
double get_k(node aa,node bb){
return 1.0*(bb.y-aa.y)/((aa.x==bb.x)?eps:(bb.x-aa.x));
}
bool M2;
signed main(){
// freopen("","r",stdin);
// freopen("","w",stdout);
int Time=clock();
n=read();m=read();
for(int i=1;i<=n;i++){
int x,y;
x=read();y=read();
w[x][y]=read();
}
f[1][1]=w[1][1];
pos[1]=1;
for(int i=1;i<=m;i++){
q.clear();
for(int j=1;j<=m;j++){
if(pos[j]) dis[j]=(i-pos[j])*(i-pos[j]);
else dis[j]=0;
}
for(int j=1;j<=m;j++){
if(pos[j]){
node tmp={j,f[pos[j]][j]-dis[j]-j*j};
while(q.size()>=2 && get_k(q[q.size()-2],tmp)>=get_k(q[q.size()-2],q.back())) q.pop_back();
q.push_back(tmp);
}
if(w[i][j] && (i>1||j>1)){
while(q.size()>=2 && -2*j<=get_k(q[0],q[1])) q.pop_front();
int x=q.front().x,y=q.front().y;
f[i][j]=w[i][j]-j*j+2*j*x+y;
pos[j]=i;
dis[j]=0;
node tmp={j,f[pos[j]][j]-dis[j]-j*j};
while(q.size()>=2 && get_k(q[q.size()-2],tmp)>=get_k(q[q.size()-2],q.back())) q.pop_back();
q.push_back(tmp);
}
}
}
int ans=f[m][m];
printf("%lld\n",ans);
look_memory;
look_time;
return 0;
}
P6302 [NOI2019] 回家路线 加强版
也是需要先进行一些转化再上斜率优化,比较有意思的题,由于作者懒得写了,本文暂时不进行详细解析。
在写这题的时候,注意到最后有一个枚举所有终点对于 \(ans\) 取 min,但事实上有些终点未必合法,可能从起点开始走是不可达的,由于我最开始没有给dp数组初始化成 \(inf\) 被硬控了0.5day
代码:
#include<bits/stdc++.h>
#define MAXN 100005
#define MAXM 1000005
#define int long long
#define look_memory cerr<<abs(&M2-&M1)/1024.0/1024<<'\n'
#define look_time cerr<<(clock()-Time)*1.0/CLOCKS_PER_SEC<<'\n'
using namespace std;
inline int read(){
int x=0;
int f=1;
char c=getchar();
while(c<'0' || c>'9'){
if(c=='-') f=-1;
c=getchar();
}
while(c>='0' && c<='9'){
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x*f;
}
bool M1;
int n,m,A,B,C;
struct node{
int x,y;
bool operator < (node tmp) const{
if(x==tmp.x) return tmp.y>y;
return tmp.x<x;
}
};
struct po{
int s,t,ta,tb;
}a[MAXM];
int f[MAXM];
priority_queue<node> pq[MAXN];
deque<node> q[MAXN];
const int inf=1e18;
const double eps=1e-9;
bool cmp(po aa,po bb){
return aa.ta<bb.ta;
}
long double get_k(node aa,node bb){
return 1.0*(bb.y-aa.y)/((aa.x==bb.x)?eps:(bb.x-aa.x));
}
bool M2;
signed main(){
// freopen("","r",stdin);
// freopen("","w",stdout);
int Time=clock();
n=read();m=read();
A=read();B=read();C=read();
for(int i=1;i<=m;i++){
a[i].s=read();a[i].t=read();
a[i].ta=read();a[i].tb=read();
}
sort(a+1,a+1+m,cmp);
memset(f,0x3f,sizeof(f));
f[0]=0;
pq[1].push((node){0,0});
for(int i=1;i<=m;i++){
int s=a[i].s,t=a[i].t,ta=a[i].ta,tb=a[i].tb;
while(!pq[s].empty() && pq[s].top().x<=ta){
node tmp=pq[s].top();
pq[s].pop();
while(q[s].size()>=2 && get_k(q[s][q[s].size()-2],tmp)<=get_k(q[s][q[s].size()-2],q[s].back())) q[s].pop_back();
q[s].push_back(tmp);
}
int k=2*A*ta;
while(q[s].size()>=2 && k>=get_k(q[s][0],q[s][1])) q[s].pop_front();
if(!q[s].empty()){
int x=q[s].front().x,y=q[s].front().y;
f[i]=A*ta*ta+B*ta+C-2*A*ta*x+y;
node tmp={tb,f[i]+A*tb*tb-B*tb};
pq[t].push(tmp);
}
}
int ans=inf;
for(int i=1;i<=m;i++){
if(a[i].t==n){
ans=min(ans,f[i]+a[i].tb);
}
}
printf("%lld\n",ans);
look_memory;
look_time;
return 0;
}
P2497 [SDOI2012] 基站建设
对于相邻的两个被使用的基站 \(j,i\) 随便画一画图运用初中 (小学) 几何知识可以得到 \(i\) 的接受范围 \(r'_i = \frac{(x_i-x_j)^2}{4r_j}\),不详细说了。
设 \(f_i\) 表示考虑了前 \(i\) 个基站,且 \(i\) 被使用的最小代价。
则有状态转移方程
发现坐标不单调,由于不想维护平衡树,所以把决策视作斜率为 \(\frac{1}{2\sqrt{r_j}}\) 截距为 \(f_j - \frac{x_j}{2\sqrt{r_j}}\) 的一次函数,扔到李超树上维护。
注意由于此题坐标的数据范围很大,需要对 \(x\) 进行离散化。
于是本题作为例题之一的意义是感受一下李超树版的斜优。
代码:
#include<bits/stdc++.h>
#define MAXN 500005
#define int long long
#define look_memory cerr<<abs(&M2-&M1)/1024.0/1024<<'\n'
#define look_time cerr<<(clock()-Time)*1.0/CLOCKS_PER_SEC<<'\n'
using namespace std;
inline int read(){
int x=0;
int f=1;
char c=getchar();
while(c<'0' || c>'9'){
if(c=='-') f=-1;
c=getchar();
}
while(c>='0' && c<='9'){
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x*f;
}
bool M1;
int n,m;
int p[MAXN],r[MAXN],v[MAXN];
int lsh[MAXN],nn;
double f[MAXN];
struct line{
double k,b;
}lin[MAXN];
const double inf=1e18;
struct lichao_tree{
#define ls k<<1
#define rs k<<1|1
int t[MAXN<<2];
double get_val(int i,int x){
if(!i) return inf;
return lin[i].k*lsh[x]+lin[i].b;
}
void update(int k,int l,int r,int x){
if(l==r){
if(get_val(x,l)<get_val(t[k],l)) t[k]=x;
return;
}
int mid=(l+r)>>1;
if(get_val(x,mid)<get_val(t[k],mid)) swap(x,t[k]);
if(get_val(x,l)<get_val(t[k],l)) update(ls,l,mid,x);
if(get_val(x,r)<get_val(t[k],r)) update(rs,mid+1,r,x);
}
double query(int k,int l,int r,int x){
if(l==r) return get_val(t[k],x);
int mid=(l+r)>>1;
if(mid>=x) return min(get_val(t[k],x),query(ls,l,mid,x));
else return min(get_val(t[k],x),query(rs,mid+1,r,x));
}
#undef ls
#undef rs
}ST;
bool M2;
signed main(){
// freopen("","r",stdin);
// freopen("","w",stdout);
int Time=clock();
n=read();m=read();
for(int i=1;i<=n;i++){
p[i]=read();
r[i]=read();
v[i]=read();
lsh[i]=p[i];
}
sort(lsh+1,lsh+1+n);
nn=unique(lsh+1,lsh+1+n)-lsh-1;
for(int i=1;i<=n;i++){
p[i]=lower_bound(lsh+1,lsh+1+nn,p[i])-lsh;
}
f[1]=1.0*v[1];
lin[1]=(line){1.0/(2*sqrt(r[1])),f[1]-1.0*lsh[p[1]]/(2*sqrt(r[1]))};
ST.update(1,1,n,1);
for(int i=2;i<=n;i++){
double tmp=ST.query(1,1,n,p[i]);
f[i]=tmp+v[i];
lin[i]=(line){1.0/(2*sqrt(r[i])),f[i]-1.0*lsh[p[i]]/(2*sqrt(r[i]))};
ST.update(1,1,n,i);
}
double ans=inf;
for(int i=n;i>=1;i--){
if(m-lsh[p[i]]<=r[i]) ans=min(ans,f[i]);
}
printf("%.3lf\n",ans);
look_memory;
look_time;
return 0;
}

浙公网安备 33010602011771号