WQS二分讲解
首先思考这样一个题目:
给定一个序列 \(A\) ,包含 \(n\) 个整数,求刚好取 \(k\) 个互不相交的子段,使得总和最大,输出最大总和。
我们可以用 \(O(n^2)\) 的dp来做,\(f_{i,j}\) 表示前 \(i\) 个元素取了 \(j\) 段的最大总和。
WQS二分
太慢了?
WQS二分可以用 \(O(n \log n)\) 的时间复杂度完成这道题。
WQS二分的精髓在于通过给操作附一个”代价值“,使得”恰好取 \(k\) 个“的问题,转化为”选任意个“的问题。
引入一个代价 \(c\) 表示每取一段会让最后的总和减去 \(c\) 。
假如说前 \(i\) 个元素取了 \(j\) 段,那么就要减去 \(j \times c\) 。
为什么要这么做?因为我们想知道能不能找到一个代价 \(c\) ,使得取 \(k\) 段是最优秀的,这样少的那一维就没有影响了。
我们发现:
- \(c\) 很大的时候,我们要尽量少地选择段数。(大得离谱的时候我们肯定一段都不选)
- \(c\) 很小的时候(甚至是负代价),我们要尽量多地选择段数。(小得离谱的时候我们肯定取 \(n\) 段)
这个东西是单调的,我们就可以二分找到一个 \(c\) ,使得最优秀的是取 \(k\) 段。
那我们每次就check:在当前的 \(c\) 的情况下,我们跑出来的最大总和用了几段。
- 如果段数大于 \(k\) ,\(c\) 就往大的方向找。
- 如果段数小于 \(k\) ,\(c\) 就往小的方向找。
WQS二分就讲完了。。。。 (后边还有几何理解,可以不用看)
代码:
#include<bits/stdc++.h>
#define endl "\n"
#define int long long
using namespace std;
const int N = 2e5 + 10;
const int INF = 1e16;
int n, k, a[N];
struct node {
int val;
int cnt;
// 价值不同时取价值大的;价值相同时,强制取段数多的(保证单调性)
bool operator<(const node& o) const {
if (val != o.val) return val < o.val;
return cnt < o.cnt;
}
} dp[N][2];
// 无限制条件下的 DP,c 是代价
node check(int c) {
dp[0][0] = {0, 0};
dp[0][1] = {-INF, 0}; // 第 0 个数不可能存在于一段中
for (int i = 1; i <= n; i++) {
// 1. 第 i 个数不在段里,继承前面的最大值
dp[i][0] = max(dp[i-1][0], dp[i-1][1]);
// 2. 第 i 个数在段里
// 选项 A:作为新的一段的起点(基于前面的全局最优,减去惩罚值,段数+1)
node pre = max(dp[i-1][0], dp[i-1][1]);
node now = {pre.val + a[i] - c, pre.cnt + 1};
// 选项 B:紧接着上一段继续延长(不减惩罚值,段数不变)
node ex = {dp[i-1][1].val + a[i], dp[i-1][1].cnt};
// 取两者的最优解
dp[i][1] = max(now, ex);
}
// 返回全局的最优状态
return max(dp[n][0], dp[n][1]);
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> k;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
// 二分惩罚值 c
// 范围要根据数组元素的值域来定,这里开到 1e12 绝对够用
int l = -1e12, r = 1e12, ans = 0;
while (l <= r) {
int mid = (l+r) >>1;
node res = check(mid);
if (res.cnt >= k) {
// 选的段数太多(或刚好等于 k),说明惩罚值太小了!
// 记录答案(把扣掉的 k 次惩罚值补回来),并尝试加大惩罚值
ans = res.val + k * mid;
l = mid + 1;
} else {
// 选的段数不够 k 个,说明惩罚得太狠了,大家都不敢选
// 减轻惩罚值
r = mid - 1;
}
}
cout << ans << endl;
return 0;
}
几何理解
- 我们要最大化 \(f_x=g_x-x\times c\),这里 \(f_x\) 表示:取 \(x\) 段并把代价考虑在内的最大总和。
- 简单变形:\(g_x=x\times c+f_x\) ,这是一条斜率为 \(c\) 的直线,我们要最大化截距 \(f_x\) 。
- 问题就转换成了用斜率为 \(c\) 的直线去经过每一个坐标点 \((x,g_x)\) ,看截距最大是多少。
- 所以 WQS 二分的几何意义就是:拿一条斜率为 c 的直线去切这个凸包,截距最大时,切点对应的横坐标就是我们当前选的个数。二分斜率,直到切点的横坐标刚好落在 k 上。

浙公网安备 33010602011771号