5.4.3 斜率优化dp
现在可能理解不是太深刻,但是我还是想把我不深刻的理解记录下来,以防忘记
概念
有一类DP状态方程:
\(\displaystyle dp[i] = min\{dp[j]-a[i]d[j]\},\ 0\leq j<i,d[j]\leq d[j+1],a[i]\leq a[i+1]\)
它的特征是存在一个既有\(\displaystyle i\)又有\(\displaystyle j\)的项 \(\displaystyle a[i],\ d[j]\) ,编程时,如果简单地对外层\(\displaystyle i\)和内层\(\displaystyle j\)循环,复杂度为\(\displaystyle O(n^{2})\)
这里能用单调队列优化吗?单调队列所处理的策略,要求只能与内层有关,与外层无关,但是这个状态方程无法简单地得到只与\(\displaystyle j\)有关的部分
用斜率(凸壳)模型,能够将方程转化,得到一个只与\(\displaystyle j\)有关的部分,即“斜率”从而能够使用单调队列优化。这个算法称为斜率优化/凸壳优化(Convex Hull Trick),总时间复杂度为(n),斜率优化的核心技术是斜率(凸壳)模型和单调队列。
例题
结合例题理解
例题:Problem - 3507
Zero 想要打印一篇有 N 个单词的文章,每个单词 i 的打印成本为 \(\displaystyle C_{i}\)。此外,Zero 知道在一行中打印 k 个单词的成本为
求最小成本,\(\displaystyle 0\leq n\leq 5e5,0\leq M\leq 1e 3\)
我们很容易能够得到转移方程:
这里\(\displaystyle s_{i}\)表示从开头到\(\displaystyle i\)位置的前缀和。
把它转化一下
因为我们枚举的是\(\displaystyle i\)所以\(\displaystyle f_{i}\)和\(\displaystyle s_{i}\)是已知的,所以再变型
因为我们要求的是最小的\(\displaystyle f_{i}\),所以如果用这个式子的话是求上凸包(我不太熟悉),我们给左右两边同乘-1
这里的有关 \(j\) 的项我们都是不知道的,所以我们可以将上面的式子看做 \(y = kx + b\) ,将决策点看做平面上坐标为 \((x,y)\) 的点
因为我们要使 \(f_i\) 最小,所以我们可以将绿线向上移动,因为斜率不变,所以第一个碰到的点可以使 \(b \ (f_i - s_i^2 - M)\) 最小,从而使 \(f_i\) 最小
现在我们的问题转化成了维护这些点,从而使得我们能够快速找到对应的点,这就用到了我们的单调队列
我们维护一个相邻两个点间斜率单调递增的队列,注意,队列中的元素个数要 >= 1
- 将 \(i - 1\) 加入队列,如果是上图中的的那种情况
slope(t - 1, t) >= slope(t, i - 1)
那么我们将队尾元素弹出 - 将队头过时元素弹出,
slope(h, h + 1) <= 2 * s[i]
,我们将队头弹出 - 更新当前的 dp 值
code
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
using namespace std;
typedef long long ll;
const int N = 5e5 + 10;
int n, m, q[N];
ll s[N], f[N];
//计算斜率,(y2 - y1) / (x2 - x1)
double slope (int i, int j) {
return (double)(f[j] + s[j] * s[j] - f[i] - s[i] * s[i])
/ (s[j] == s[i] ? 1e-9 : s[j] - s[i]);
}
int main () {
while (~scanf ("%d%d", &n, &m)) {
ll x;
for (int i = 1; i <= n; i++)
scanf ("%lld", &x), s[i] = s[i - 1] + x;
int h = 1, t = 0;
for (int i = 1; i <= n; i++) {
//将 i - 1 插入队尾,h < t 保证队列中至少有 1 个元素
while (h < t && slope (q[t], i - 1) <= slope(q[t - 1], q[t]))
t--;
q[++t] = i - 1;
//弹出过时队头
while (h < t && slope(q[h], q[h + 1]) <= 2 * s[i]) h++;
//更新当前 dp 值
int j = q[h];
f[i] = f[j] + (s[i] - s[j]) * (s[i] - s[j]) + m;
}
printf ("%lld\n", f[n]);
}
return 0;
}