斜率优化

算法描述:

对于一种 DP:

\[f_{i}=\min / \max (f_{j}+cost(j)+cost(i)+F(i)F(j)) \]

那么可以使用斜率优化,将它改写成一个一次函数的形式 \(y=kx+b\),即:

\[f_{j}+cost(j)=F(i)F(j)+(f_{i}-cost(i)) \]

\(\min\) 为例,要让 \(f_{i}\) 最小,即让 \(f_{i}-cost(i)\) 最小,也就是让经过点 \((F(j),f_{j}+cost(j))\) 的斜率为 \(F(i)\) 的直线的截距最小。

考虑三个点 \(A,B,C\) 的坐标分别为 \((x_{1},y_{1}),(x_{2},y_{2}),(x_{3},y_{3})\) 满足 \(x_{1} < x_{2} < x_{3}\)

\(A \to B\) 的斜率 \(k_{1}\)\(B \to C\) 的斜率 \(k_{2}\)。如果 \(k_{1} \ge k_{2}\)\(B\) 这个点一定不优。

证明:

\(k \le k_{1}\) 时,\(A\) 不劣于 \(B\);当 \(k \ge k_{2}\) 时,\(C\) 不劣于 \(B\)

由于 \(k_{1} \ge k_{2}\),故 \(C\) 一定不优。

剔除掉这样的情况,所有点形成了一个下凸包(斜率单调递增)。

假设我们现在得到了这个下凸包,考虑如何算答案。

对于凸包中相邻的两个点,前一个与后一个相比不劣,当且仅当两个点的斜率大于等于 \(k\)。由于下凸包中斜率单调递增,可以直接二分。

算法实现与细节:

  • \(x\) 单增,\(k\) 有单调性:使用单调队列维护即可。时间复杂度 \(O(n)\)

  • \(x\) 单增:单调栈维护凸包,二分求答案。时间复杂度 \(O(n \log n)\)

  • 其余情况:截距最小,也就是 \(y_{i}-kx_{i}\) 最小,再看成一个直线 \(y=-kx_{i}+y_{i}\)。李超线段树即可。

斜率优化有两个代码的小细节:

  • 二分找大于等于还是大于?实际上都可以,因为此时两个点截距相同。
  • 维护凸包时,踢不踢斜率相同的?要踢。考虑三个点的 \(x\) 坐标相同的情况。如果不踢的话很可能维护出一个不是凸包的东西。
  • 计算斜率时,把除法换成乘法。这样既避免的精度误差的问题,又不会导致分母为 \(0\)

当然,如果 DP 的转移方程是 \(\max\),即要求截距大,维护上凸包即可。二分的时候就是找第一个小于等于 \(k\) 的了。

例题:

[NOI2014] 购票

题意:

今年夏天,NOI 在 SZ 市迎来了她三十周岁的生日。来自全国 \(n\) 个城市的 OIer 们都会从各地出发,到 SZ 市参加这次盛会。

全国的城市构成了一棵以 SZ 市为根的有根树,每个城市与它的父亲用道路连接。为了方便起见,我们将全国的 \(n\) 个城市用 \(1\sim n\) 的整数编号。其中 SZ 市的编号为 \(1\)。对于除 SZ 市之外的任意一个城市 \(v\),我们给出了它在这棵树上的父亲城市 \(f_v\) 以及到父亲城市道路的长度 \(s_v\)

从城市 \(v\) 前往 SZ 市的方法为:选择城市 \(v\) 的一个祖先 \(a\),支付购票的费用,乘坐交通工具到达 \(a\)。再选择城市 \(a\) 的一个祖先 \(b\),支付费用并到达 \(b\)。以此类推,直至到达 SZ 市。

对于任意一个城市 \(v\),我们会给出一个交通工具的距离限制 \(l_v\)。对于城市 \(v\) 的祖先 A,只有当它们之间所有道路的总长度不超过 \(l_v\) 时,从城市 \(v\) 才可以通过一次购票到达城市 A,否则不能通过一次购票到达。

对于每个城市 \(v\),我们还会给出两个非负整数 \(p_v,q_v\) 作为票价参数。若城市 \(v\) 到城市 A 所有道路的总长度为 \(d\),那么从城市 \(v\) 到城市 A 购买的票价为 \(dp_v+q_v\)

每个城市的 OIer 都希望自己到达 SZ 市时,用于购票的总资金最少。你的任务就是,告诉每个城市的 OIer 他们所花的最少资金是多少。

分析:

容易得到一个转移:

\[f_{x}=f_{y}+(dis_{x}-dis_{y})p_{x}+q_{x} \]

其中 \(y\) 表示 \(x\) 的祖先。但 \(y\) 有个限制条件,树剖+线段树套李超树可以做到 \(O(n \log^3 n)\) 的时间复杂度和 \(O(n \log n)\) 的空间复杂度。

这里有个小套路,记 \(dfn_{x}\) 表示 \(x\) 离开 DFS 的顺序。记 \(y\) 表示最远的祖先满足这个限制,直接在 \([dfn_{x},dfn_{y}]\) 的树套李超树上查询即可,不在链上的一定没有进入过。时间复杂度 \(O(n \log^2 n)\)

#include<bits/stdc++.h>
#define int long long
#define N 4000005
using namespace std;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;
int read() {
    int p = 0, flg = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') flg = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        p = p * 10 + c - '0';
        c = getchar();
    }
    return p * flg;
}
void print(int x) {
	if(x < 0) putchar('-'), x = -x;
	if(x > 9) print(x / 10);
	putchar('0' + x % 10);
}
int n, tt, tot, top;
int f[N], dis[N], stk[N], rt[N], k[N], b[N], s[N], p[N], q[N], l[N], dfn[N], cnt; //dfn[i]表示i的出栈序
vector<int>G[N];
void dfs1(int x, int fa) { dis[x] = dis[fa] + s[x]; for(auto y : G[x]) dfs1(y, x); dfn[x] = ++cnt; }
int Get(int id, int x) { return k[id] * x + b[id]; }
struct node { int ls, rs, v; }t[N];
int ask(int u, int L, int R, int Num) {
    int Min = 1e18, mid = (L + R) / 2;
    if(t[u].v) Min = Get(t[u].v, Num);
    if(L == R) return Min;
    if(Num <= mid) Min = min(Min, ask(t[u].ls, L, mid, Num));
    else Min = min(Min, ask(t[u].rs, mid + 1, R, Num));
    return Min;
}
void add(int &u, int L, int R, int id) {
    if(!u) u = ++tot; 
    if(!t[u].v) {
        t[u].v = id;
        return;
    }
    if(L == R) return; int mid = (L + R) / 2;
    if(Get(id, mid) < Get(t[u].v, mid)) swap(id, t[u].v);
    if(Get(id, L) < Get(t[u].v, L)) add(t[u].ls, L, mid, id);
    if(Get(id, R) < Get(t[u].v, R)) add(t[u].rs, mid + 1, R, id);
}
void update(int u, int L, int R, int x) {
    add(rt[u], 0, 1e6, x);
    if(L == R) return;
    int mid = (L + R) / 2;
    if(x <= mid) update(u * 2, L, mid, x);
    else update(u * 2 + 1, mid + 1, R, x);
}
int query(int u, int L, int R, int l, int r, int Num) {
    if(l <= L && R <= r) return ask(rt[u], 0, 1e6, Num);
    if(R < l || r < L) return 1e18;
    int mid = (L + R) / 2;
    return min(query(u * 2, L, mid, l, r, Num), query(u * 2 + 1, mid + 1, R, l, r, Num));
}
void dfs2(int x) {
    stk[++top] = x;
    if(x != 1) {
        int L = 1, R = top - 1, mid, Get = -1;
        while(L <= R) {
            mid = (L + R) / 2;
            if(dis[x] - dis[mid] <= l[x]) {
                Get = mid;
                R = mid - 1;
            }
            else L = mid + 1;
        }
        Get = stk[Get];
        f[x] = query(1, 1, n, dfn[x], dfn[Get], p[x]) + dis[x] * p[x] + q[x];
    }
    k[dfn[x]] = -dis[x]; b[dfn[x]] = f[x];
    update(1, 1, n, dfn[x]);
    for(auto y : G[x]) dfs2(y); top--;
}
signed main() {
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    n = read(), tt = read();
    for(int i = 2, fa; i <= n; i++) {
        fa = read(), s[i] = read(), p[i] = read(), q[i] = read(), l[i] = read();
        G[fa].push_back(i);
    } 
    dfs1(1, 0); dfs2(1);
    for(int i = 2; i <= n; i++) {
        print(f[i]); puts("");
    }
    return 0;
}
posted @ 2024-09-18 19:48  小超手123  阅读(46)  评论(0)    收藏  举报