斜率优化
算法描述:
对于一种 DP:
那么可以使用斜率优化,将它改写成一个一次函数的形式 \(y=kx+b\),即:
以 \(\min\) 为例,要让 \(f_{i}\) 最小,即让 \(f_{i}-cost(i)\) 最小,也就是让经过点 \((F(j),f_{j}+cost(j))\) 的斜率为 \(F(i)\) 的直线的截距最小。
考虑三个点 \(A,B,C\) 的坐标分别为 \((x_{1},y_{1}),(x_{2},y_{2}),(x_{3},y_{3})\) 满足 \(x_{1} < x_{2} < x_{3}\)。
记 \(A \to B\) 的斜率 \(k_{1}\),\(B \to C\) 的斜率 \(k_{2}\)。如果 \(k_{1} \ge k_{2}\),\(B\) 这个点一定不优。
证明:
当 \(k \le k_{1}\) 时,\(A\) 不劣于 \(B\);当 \(k \ge k_{2}\) 时,\(C\) 不劣于 \(B\)。
由于 \(k_{1} \ge k_{2}\),故 \(C\) 一定不优。
剔除掉这样的情况,所有点形成了一个下凸包(斜率单调递增)。
假设我们现在得到了这个下凸包,考虑如何算答案。
对于凸包中相邻的两个点,前一个与后一个相比不劣,当且仅当两个点的斜率大于等于 \(k\)。由于下凸包中斜率单调递增,可以直接二分。
算法实现与细节:
-
\(x\) 单增,\(k\) 有单调性:使用单调队列维护即可。时间复杂度 \(O(n)\)。
-
\(x\) 单增:单调栈维护凸包,二分求答案。时间复杂度 \(O(n \log n)\)。
-
其余情况:截距最小,也就是 \(y_{i}-kx_{i}\) 最小,再看成一个直线 \(y=-kx_{i}+y_{i}\)。李超线段树即可。
斜率优化有两个代码的小细节:
- 二分找大于等于还是大于?实际上都可以,因为此时两个点截距相同。
- 维护凸包时,踢不踢斜率相同的?要踢。考虑三个点的 \(x\) 坐标相同的情况。如果不踢的话很可能维护出一个不是凸包的东西。
- 计算斜率时,把除法换成乘法。这样既避免的精度误差的问题,又不会导致分母为 \(0\)。
当然,如果 DP 的转移方程是 \(\max\),即要求截距大,维护上凸包即可。二分的时候就是找第一个小于等于 \(k\) 的了。
例题:
[NOI2014] 购票
题意:
今年夏天,NOI 在 SZ 市迎来了她三十周岁的生日。来自全国 \(n\) 个城市的 OIer 们都会从各地出发,到 SZ 市参加这次盛会。
全国的城市构成了一棵以 SZ 市为根的有根树,每个城市与它的父亲用道路连接。为了方便起见,我们将全国的 \(n\) 个城市用 \(1\sim n\) 的整数编号。其中 SZ 市的编号为 \(1\)。对于除 SZ 市之外的任意一个城市 \(v\),我们给出了它在这棵树上的父亲城市 \(f_v\) 以及到父亲城市道路的长度 \(s_v\)。
从城市 \(v\) 前往 SZ 市的方法为:选择城市 \(v\) 的一个祖先 \(a\),支付购票的费用,乘坐交通工具到达 \(a\)。再选择城市 \(a\) 的一个祖先 \(b\),支付费用并到达 \(b\)。以此类推,直至到达 SZ 市。
对于任意一个城市 \(v\),我们会给出一个交通工具的距离限制 \(l_v\)。对于城市 \(v\) 的祖先 A,只有当它们之间所有道路的总长度不超过 \(l_v\) 时,从城市 \(v\) 才可以通过一次购票到达城市 A,否则不能通过一次购票到达。
对于每个城市 \(v\),我们还会给出两个非负整数 \(p_v,q_v\) 作为票价参数。若城市 \(v\) 到城市 A 所有道路的总长度为 \(d\),那么从城市 \(v\) 到城市 A 购买的票价为 \(dp_v+q_v\)。
每个城市的 OIer 都希望自己到达 SZ 市时,用于购票的总资金最少。你的任务就是,告诉每个城市的 OIer 他们所花的最少资金是多少。
分析:
容易得到一个转移:
其中 \(y\) 表示 \(x\) 的祖先。但 \(y\) 有个限制条件,树剖+线段树套李超树可以做到 \(O(n \log^3 n)\) 的时间复杂度和 \(O(n \log n)\) 的空间复杂度。
这里有个小套路,记 \(dfn_{x}\) 表示 \(x\) 离开 DFS 的顺序。记 \(y\) 表示最远的祖先满足这个限制,直接在 \([dfn_{x},dfn_{y}]\) 的树套李超树上查询即可,不在链上的一定没有进入过。时间复杂度 \(O(n \log^2 n)\)。
#include<bits/stdc++.h>
#define int long long
#define N 4000005
using namespace std;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;
int read() {
int p = 0, flg = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') flg = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
p = p * 10 + c - '0';
c = getchar();
}
return p * flg;
}
void print(int x) {
if(x < 0) putchar('-'), x = -x;
if(x > 9) print(x / 10);
putchar('0' + x % 10);
}
int n, tt, tot, top;
int f[N], dis[N], stk[N], rt[N], k[N], b[N], s[N], p[N], q[N], l[N], dfn[N], cnt; //dfn[i]表示i的出栈序
vector<int>G[N];
void dfs1(int x, int fa) { dis[x] = dis[fa] + s[x]; for(auto y : G[x]) dfs1(y, x); dfn[x] = ++cnt; }
int Get(int id, int x) { return k[id] * x + b[id]; }
struct node { int ls, rs, v; }t[N];
int ask(int u, int L, int R, int Num) {
int Min = 1e18, mid = (L + R) / 2;
if(t[u].v) Min = Get(t[u].v, Num);
if(L == R) return Min;
if(Num <= mid) Min = min(Min, ask(t[u].ls, L, mid, Num));
else Min = min(Min, ask(t[u].rs, mid + 1, R, Num));
return Min;
}
void add(int &u, int L, int R, int id) {
if(!u) u = ++tot;
if(!t[u].v) {
t[u].v = id;
return;
}
if(L == R) return; int mid = (L + R) / 2;
if(Get(id, mid) < Get(t[u].v, mid)) swap(id, t[u].v);
if(Get(id, L) < Get(t[u].v, L)) add(t[u].ls, L, mid, id);
if(Get(id, R) < Get(t[u].v, R)) add(t[u].rs, mid + 1, R, id);
}
void update(int u, int L, int R, int x) {
add(rt[u], 0, 1e6, x);
if(L == R) return;
int mid = (L + R) / 2;
if(x <= mid) update(u * 2, L, mid, x);
else update(u * 2 + 1, mid + 1, R, x);
}
int query(int u, int L, int R, int l, int r, int Num) {
if(l <= L && R <= r) return ask(rt[u], 0, 1e6, Num);
if(R < l || r < L) return 1e18;
int mid = (L + R) / 2;
return min(query(u * 2, L, mid, l, r, Num), query(u * 2 + 1, mid + 1, R, l, r, Num));
}
void dfs2(int x) {
stk[++top] = x;
if(x != 1) {
int L = 1, R = top - 1, mid, Get = -1;
while(L <= R) {
mid = (L + R) / 2;
if(dis[x] - dis[mid] <= l[x]) {
Get = mid;
R = mid - 1;
}
else L = mid + 1;
}
Get = stk[Get];
f[x] = query(1, 1, n, dfn[x], dfn[Get], p[x]) + dis[x] * p[x] + q[x];
}
k[dfn[x]] = -dis[x]; b[dfn[x]] = f[x];
update(1, 1, n, dfn[x]);
for(auto y : G[x]) dfs2(y); top--;
}
signed main() {
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
n = read(), tt = read();
for(int i = 2, fa; i <= n; i++) {
fa = read(), s[i] = read(), p[i] = read(), q[i] = read(), l[i] = read();
G[fa].push_back(i);
}
dfs1(1, 0); dfs2(1);
for(int i = 2; i <= n; i++) {
print(f[i]); puts("");
}
return 0;
}
浙公网安备 33010602011771号