斜率优化
当 dp
长成形如 \(f_i=\min\limits_{j<i}/\max\limits_{j<i}\{a_i\times b_j+c_i+d_j\}\) 的形式时,因为存在既和 \(i\) 有关也和 \(j\) 有关的部分,所以考虑进行斜率优化。
转化一下式子得 \(f_i-c_i-a_i\times b_j=d_j\),这个式子长得很像一次函数 \(y=kx+b\)。
这里引用 oi-wiki
中的一张图,
将 \((b_j,d_j)\) 看作一个点,整个 dp
的过程相当于有一条直线,在平面上找到一个点,使得直线经过该点时的截距是所有点中最小/最大的。很容易发现这样点只可能出现在凸包上,因此操作便转化为了:
- 用一条斜率为 \(-a_i\) 的直线去截一个凸包,找到最小的截距。
- 将一个点加入到凸包中并维护这个凸包。
[HNOI2008]玩具装箱
这道题算是一道比较简单的斜率优化典型题了。
令 \(s_i=\sum\limits_{j=1}^i c_j\),
拆开重新组合一下,
再令 \(sum_i=s_i+i,L'=L+1\)
拆掉式子,
将 \(\min\) 先丢开,将只和 \(j\) 有关系的放在右边,
此时 \(b=f_i-sum_i^2,k=2sum_i,x=sum_j+L',y=f_j+(sum_j+L')^2\)
根据这个可以知道,\(k\) 是递增的,相邻两点之间线段的斜率也是递增,因此可以直接用单调队列维护。
点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=5e4+10;
int n,l,q[maxn],a[maxn],b[maxn],x[maxn],y[maxn],s[maxn],f[maxn];
int read(){
int s=0,f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-'){
f=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9'){
s=s*10+(ch-'0');
ch=getchar();
}
return s*f;
}
double get_slope(int s,int t){
return 1.0*(y[s]-y[t])/(x[s]-x[t]);
}
signed main(){
n=read(),l=read();
l++;
x[0]=b[0]=l;
y[0]=b[0]*b[0];
for(int i=1,sum;i<=n;i++){
cin>>sum;
s[i]=s[i-1]+sum;
a[i]=s[i]+i;
x[i]=b[i]=a[i]+l;
}
int head=1,tail=0;
for(int i=1;i<=n;i++){
while(head<tail&&get_slope(q[head],q[head+1])<a[i]*2.0){
head++;
}
f[i]=f[q[head]]+(a[i]-b[q[head]])*(a[i]-b[q[head]]);
y[i]=f[i]+b[i]*b[i];
while(head<tail&&get_slope(i,q[tail-1])<get_slope(q[tail-1],q[tail])){
tail--;
}
q[++tail]=i;
}
printf("%lld",f[n]);
return 0;
}
[SDOI2016]征途
化简 \(v\times m^2\) 可以得到
可以发现后半部分为常数,只需要计算前半部分的最小值。
设 \(f_{i,k}\) 表示到第 \(i\) 段路已经走了 \(k\) 天,\(\sum\limits_{l=1}^k x_l^2\) 的最小值。令 \(s_i\) 表示第 \(1\) 段到第 \(i\) 段路长度的前缀和,可以写出:
将平方拆开,去掉 \(\min\),将只和 \(j\) 有关的部分放到右边,得到
此时 \(b=f_{i,k}-s_i^2,k=2 s_i,x=s_j,y=f_{j,k-1}+s_j^2\)
\(k\) 单调递增,相邻两点间的斜率单调递增,单调队列维护。
先枚举第二维,再枚举第一维即可做。
点击查看代码
#include<bits/stdc++.h>
#define ull unsigned long long
#define ll long long
#define pdi pair<double,int>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define eps 1e-9
using namespace std;
namespace IO{
template<typename T>
inline void read(T &x){
x=0;
int f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-'){
f=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=x*10+(ch-'0');
ch=getchar();
}
x=(f==1?x:-x);
}
template<typename T>
inline void write(T x){
if(x<0){
putchar('-');
x=-x;
}
if(x>=10){
write(x/10);
}
putchar(x%10+'0');
}
template<typename T>
inline void write_endl(T x){
write(x);
putchar('\n');
}
template<typename T>
inline void write_space(T x){
write(x);
putchar(' ');
}
}
using namespace IO;
const int N=3e3+10;
int n,m,f[N],g[N],x[N],y[N],q[N],s[N];
double k[N];
double get_slope(int a,int b){
return 1.0*(y[b]-y[a])/(x[b]-x[a]);
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
#endif
read(n),read(m);
for(int i=1;i<=n;i++){
read(s[i]);
s[i]+=s[i-1];
y[i]=2*s[i]*s[i];
k[i]=2.0*s[i];
x[i]=s[i];
g[i]=s[i]*s[i];
}
for(int i=1;i<m;i++){
int head=1,tail=1;
q[1]=i;
for(int j=i+1;j<=n;j++){
while(head<tail&&get_slope(q[head],q[head+1])<k[j]){
head++;
}
f[j]=g[q[head]]+(s[j]-s[q[head]])*(s[j]-s[q[head]]);
while(head<tail&&get_slope(q[tail-1],q[tail])>get_slope(q[tail],j)){
tail--;
}
q[++tail]=j;
}
for(int j=1;j<=n;j++){
g[j]=f[j];
y[j]=g[j]+s[j]*s[j];
}
}
write_endl(f[n]*m-s[n]*s[n]);
return 0;
}
[SDOI2012]任务安排
令 \(f_{i,j}\) 表示到 \(i\) 分为 \(j\) 批的最小费用,\(Sc_i\) 表示费用的前缀和,\(St_i\) 表示所费时间的前缀和。
但这个是 \(O(n^3)\),考虑优化状态,可以发现一个 \(s\) 是会影响区间 \(\left[j,n\right]\) 的,令 \(f_i\) 表示前 \(i\) 个任务分为若干批的最小费用。
考虑斜率优化。和前面一样的,拆开,去掉 \(\min\),把只和 \(j\) 有关的放在右边,得到:
在这个式子中 \(b=f_i-sSc_n-St_i Sc_i,k=s+St_i,x=Sc_j,y=f_j\)。因为要求的是 \(b\) 的最小值,所以维护下凸壳;与前面不同的是,虽然 \(x\) 单调递增,但是 \(k\) 不一定会单调递增,因此虽然能使用单调队列维护凸壳,但不能每次取队首更新答案,需要在凸壳上二分找点。
点击查看代码
#include<bits/stdc++.h>
#define ull unsigned long long
#define int long long
#define pdi pair<double,int>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define eps 1e-9
using namespace std;
namespace IO{
template<typename T>
inline void read(T &x){
x=0;
int f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-'){
f=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=x*10+(ch-'0');
ch=getchar();
}
x=(f==1?x:-x);
}
template<typename T>
inline void write(T x){
if(x<0){
putchar('-');
x=-x;
}
if(x>=10){
write(x/10);
}
putchar(x%10+'0');
}
template<typename T>
inline void write_endl(T x){
write(x);
putchar('\n');
}
template<typename T>
inline void write_space(T x){
write(x);
putchar(' ');
}
}
using namespace IO;
const int N=3e5+10;
int n,s,cnt_c[N],cnt_t[N],f[N],q[N];
int x[N],y[N],k[N];
int head=1,tail=1;
int find(int pos){
int l=head,r=tail;
while(l<r){
int mid=(l+r)>>1;
if(y[q[mid+1]]-y[q[mid]]<=k[pos]*(x[q[mid+1]]-x[q[mid]])){
l=mid+1;
}
else{
r=mid;
}
}
return q[r];
}
int cross(int p1,int p2,int p3){
return (x[p2]-x[p1])*(y[p3]-y[p2])-(y[p2]-y[p1])*(x[p3]-x[p2]);
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
#endif
read(n),read(s);
for(int i=1;i<=n;i++){
int c,t;
read(t),read(c);
cnt_c[i]=cnt_c[i-1]+c;
cnt_t[i]=cnt_t[i-1]+t;
x[i]=cnt_c[i];
k[i]=cnt_t[i]+s;
}
memset(f,0x3f,sizeof(f));
f[0]=0;
q[head]=0;
for(int i=1;i<=n;i++){
int pos=find(i);
f[i]=f[pos]+s*(cnt_c[n]-cnt_c[pos])+cnt_t[i]*(cnt_c[i]-cnt_c[pos]);
y[i]=f[i];
while(head<tail&&cross(q[tail-1],q[tail],i)<=0){
tail--;
}
q[++tail]=i;
}
write_endl(f[n]);
return 0;
}
Yet Another Partiton Problem
斜率优化好题
二维 dp
方程很容易写出来,令 \(f_{j,i}\) 表示到第 \(i\) 个数分成了 \(j\) 段,
但这个方程是 \(O(n^2k)\) 的,显然过不去此题。因为贡献函数不满足四边形不等式(可以自证),所以决策单调性分治是不行的,考虑斜率优化。
整个方程最格格不入的是后面的 \(\max\),很难处理,但转换一下思路。固定 \(i\),令 \(v_j=\max\limits_{l=j+1}^i a_l\),枚举第一维,方程变为:
其中 \(f\) 表示原方程中的 \(f_k\),\(g\) 表示原方程中的 \(f_{k-1}\)。
容易发现 \(v_j\) 维护的是后缀 \(\max\),是单调递减的。考虑 \(v\) 值相同的一段区间 \(\left[l,r\right]\),只需求出 \(g_j-v\times j\) 的最小值,这个式子中 \(b=g_j-v\times j,k=v,y=g_j,x=j\),通过切凸包得到最小的 \(b\)。现在得到的 \(f_i=b+v\times i\),长得和直线方程很像,考虑用李超树维护,每次询问在某个位置上的最小值。
现在再想一下一些细节,可以令第 \(i\) 根直线的 \(v_i=a_i\),因为前面每根直线都已经得到,所以每次只需要得到第 \(i\) 根直线即可。又因为每次会影响的 \(v\) 是一段后缀 \((x,i]\),\(x\) 是在 \(i\) 之前第一个满足 \(a_x>a_i\) 的数,可以用单调栈维护得到,然后合并 \((x,i]\) 所在凸包,在 \(x\) 的李超树上加上一条直线,用可持久化李超线段树维护。
点击查看代码
#include<bits/stdc++.h>
#define ull unsigned long long
#define int long long
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
using namespace std;
namespace IO{
template<typename T>
inline void read(T &x){
x=0;
int f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-'){
f=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=x*10+(ch-'0');
ch=getchar();
}
x=(f==1?x:-x);
}
template<typename T>
inline void write(T x){
if(x<0){
putchar('-');
x=-x;
}
if(x>=10){
write(x/10);
}
putchar(x%10+'0');
}
template<typename T>
inline void write_endl(T x){
write(x);
putchar('\n');
}
template<typename T>
inline void write_space(T x){
write(x);
putchar(' ');
}
}
using namespace IO;
const int N=2e4+10;
int n,m,b[N],a[N],k[N],g[N],f[N];
int stk[N],top;
deque<int>q[N];
int cross(int i,int j,int k){
return (j-i)*(g[k]-g[j])-(k-j)*(g[j]-g[i]);
}
void merge(int x,int y){
if(q[x].size()<q[y].size()){
while(q[x].size()){
while(q[y].size()>1&&cross(q[x].back(),q[y].front(),q[y][1])<=0){
q[y].pop_front();
}
q[y].push_front(q[x].back());
q[x].pop_back();
}
}
else{
while(q[y].size()){
while(q[x].size()>1&&cross(q[x][q[x].size()-2],q[x].back(),q[y].front())<=0){
q[x].pop_back();
}
q[x].push_back(q[y].front());
q[y].pop_front();
}
q[y].swap(q[x]);
}
}
int ask(int k,int id){
int l=0,r=q[id].size()-1;
while(l<r){
int mid=(l+r)>>1;
int x=q[id][mid],y=q[id][mid+1];
if((g[y]-g[x])<=(y-x)*k){
l=mid+1;
}
else{
r=mid;
}
}
return g[q[id][l]]-k*q[id][l];
}
int rt[N],cnt;
struct node{
int ch[2],mn;
}tr[N<<5];
int get(int x,int id){
return k[id]*x+b[id];
}
#define ls(p) tr[p].ch[0]
#define rs(p) tr[p].ch[1]
void update(int &p,int pre,int l,int r,int u){
p=++cnt;
tr[p]=tr[pre];
int mid=(l+r)>>1;
if(get(mid,tr[p].mn)>get(mid,u)){
swap(u,tr[p].mn);
}
if(get(l,tr[p].mn)>get(l,u)){
update(ls(p),ls(pre),l,mid,u);
}
else if(get(r,tr[p].mn)>get(r,u)){
update(rs(p),rs(pre),mid+1,r,u);
}
}
int query(int p,int l,int r,int pos){
if(!p||l==r){
return get(pos,tr[p].mn);
}
int mid=(l+r)>>1;
if(pos<=mid){
return min(query(ls(p),l,mid,pos),get(pos,tr[p].mn));
}
else{
return min(query(rs(p),mid+1,r,pos),get(pos,tr[p].mn));
}
}
void solve(){
read(n),read(m);
b[0]=1e18;
for(int i=1;i<=n;i++){
read(a[i]);
g[i]=1e12;
}
for(int i=1;i<=m;i++){
for(int j=1;j<=cnt;j++){
tr[j].ch[0]=tr[j].ch[1]=tr[j].mn=0;
}
for(int j=1;j<=n;j++){
deque<int>().swap(q[j]);
q[j].pb(j-1);
}
cnt=top=0;
for(int j=1;j<=n;j++){
while(top&&a[stk[top]]<=a[j]){
merge(stk[top],j);
top--;
}
k[j]=a[j];
b[j]=ask(k[j],j);
update(rt[j],rt[stk[top]],1,n,j);
f[j]=query(rt[j],1,n,j);
stk[++top]=j;
}
for(int j=1;j<=n;j++){
g[j]=f[j];
}
}
write_endl(f[n]);
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
#endif
int t=1;
while(t--){
solve();
}
return 0;
}