Loading

WQS二分讲解

首先思考这样一个题目:

给定一个序列 \(A\) ,包含 \(n\) 个整数,求刚好取 \(k\) 个互不相交的子段,使得总和最大,输出最大总和。

我们可以用 \(O(n^2)\)​ 的dp来做,\(f_{i,j}\)​ 表示前 \(i\)​ 个元素取了 \(j\)​ 段的最大总和。


WQS二分

太慢了?

WQS二分可以用 \(O(n \log n)\) 的时间复杂度完成这道题。

WQS二分的精髓在于通过给操作附一个”代价值“,使得”恰好取 \(k\) 个“的问题,转化为”选任意个“的问题

引入一个代价 \(c\) 表示每取一段会让最后的总和减去 \(c\)

假如说前 \(i\) 个元素取了 \(j\) 段,那么就要减去 \(j \times c\)

为什么要这么做?因为我们想知道能不能找到一个代价 \(c\) ,使得取 \(k\)​ 段是最优秀的,这样少的那一维就没有影响了。

我们发现:

  • \(c\) 很大的时候,我们要尽量少地选择段数。(大得离谱的时候我们肯定一段都不选)
  • \(c\)​ 很小的时候(甚至是负代价),我们要尽量多地选择段数。(小得离谱的时候我们肯定取 \(n\) 段)

这个东西是单调的,我们就可以二分找到一个 \(c\) ,使得最优秀的是取 \(k\) 段。

那我们每次就check:在当前的 \(c\) 的情况下,我们跑出来的最大总和用了几段。

  • 如果段数大于 \(k\)\(c\) 就往大的方向找。
  • 如果段数小于 \(k\)\(c\) 就往小的方向找。

WQS二分就讲完了。。。。 (后边还有几何理解,可以不用看)

代码:

#include<bits/stdc++.h>
#define endl "\n"
#define int long long
using namespace std;

const int N = 2e5 + 10;
const int INF = 1e16; 

int n, k, a[N];

struct node {
    int val;
    int cnt;
    // 价值不同时取价值大的;价值相同时,强制取段数多的(保证单调性)
    bool operator<(const node& o) const {
        if (val != o.val) return val < o.val;
        return cnt < o.cnt; 
    }
} dp[N][2];

// 无限制条件下的 DP,c 是代价
node check(int c) {
    dp[0][0] = {0, 0};
    dp[0][1] = {-INF, 0}; // 第 0 个数不可能存在于一段中
    
    for (int i = 1; i <= n; i++) {
        // 1. 第 i 个数不在段里,继承前面的最大值
        dp[i][0] = max(dp[i-1][0], dp[i-1][1]);
        
        // 2. 第 i 个数在段里
        // 选项 A:作为新的一段的起点(基于前面的全局最优,减去惩罚值,段数+1)
        node pre = max(dp[i-1][0], dp[i-1][1]);
        node now = {pre.val + a[i] - c, pre.cnt + 1};
        
        // 选项 B:紧接着上一段继续延长(不减惩罚值,段数不变)
        node ex = {dp[i-1][1].val + a[i], dp[i-1][1].cnt};
        
        // 取两者的最优解
        dp[i][1] = max(now, ex);
    }
    
    // 返回全局的最优状态
    return max(dp[n][0], dp[n][1]);
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    
    cin >> n >> k;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    
    // 二分惩罚值 c
    // 范围要根据数组元素的值域来定,这里开到 1e12 绝对够用
    int l = -1e12, r = 1e12, ans = 0;
    
    while (l <= r) {
        int mid = (l+r) >>1;
        node res = check(mid);
        
        if (res.cnt >= k) {
            // 选的段数太多(或刚好等于 k),说明惩罚值太小了!
            // 记录答案(把扣掉的 k 次惩罚值补回来),并尝试加大惩罚值
            ans = res.val + k * mid; 
            l = mid + 1;
        } else {
            // 选的段数不够 k 个,说明惩罚得太狠了,大家都不敢选
            // 减轻惩罚值
            r = mid - 1;
        }
    }
    
    cout << ans << endl;
    
    return 0;
}

几何理解

  • 我们要最大化 \(f_x=g_x-x\times c\),这里 \(f_x\) 表示:取 \(x\) 段并把代价考虑在内的最大总和。
  • 简单变形:\(g_x=x\times c+f_x\) ,这是一条斜率为 \(c\) 的直线,我们要最大化截距 \(f_x\)
  • 问题就转换成了用斜率为 \(c\) 的直线去经过每一个坐标点 \((x,g_x)\) ,看截距最大是多少。
  • 所以 WQS 二分的几何意义就是:拿一条斜率为 c 的直线去切这个凸包,截距最大时,切点对应的横坐标就是我们当前选的个数。二分斜率,直到切点的横坐标刚好落在 k 上
posted @ 2026-03-12 11:54  TommyJin  阅读(6)  评论(0)    收藏  举报