cdq实现斜率优化代码(带注释)
因为博主蒟蒻的原因, 在第一次写cdq实现斜率优化用了快一天的时间才把代码改对, 所以写一篇博客记录一下写法以及算法的基本实现思路。
#include <cstdio>
#include <algorithm>
using namespace std;
#define INF 0x3f3f3f3f3f3f3f3f
#define MAXN 100000
struct cdq_node {//斜率优化的基本形式:y = kx + b
int loc;
long long x, k;
long long constant;//即是上式中的b
long long y;
}s_cdq[MAXN + 5], tem[MAXN + 5];
long long dp[MAXN + 5];//转移的dp值
int q[MAXN + 5];
bool cmp (cdq_node a, cdq_node b) {//最开始按照k排序, 以达到O(nlogn)时间
return a.k < b.k;
}
double slope (cdq_node a, cdq_node b) {//斜率, 由于转为乘法会出现因正负变不等号方向的问题, 而c++的高精还是挺准的, 所以还是使用除法
return 1.0 * (a.y - b.y) / (a.x - b.x);
}
long long level (cdq_node a, cdq_node b) {//算出a对b的贡献与a有关的部分
return a.y - a.x * b.k;
}
void transfer (int l, int mid, int r) {//将左区间的状态转移到右区间中
int h = 1, t = 0;
//以左区间构建凸包
for (int e = l; e <= mid; e ++) {
if (e != l && s_cdq[e].x == tem[e - 1].x) {
continue;
}
while (t > h && slope (s_cdq[q[t]], s_cdq[e]) <= slope (s_cdq[q[t - 1]], s_cdq[q[t]])) {
t --;
}
q[++ t] = e;
}
//向右区间进行转移
for (int e = mid + 1; e <= r; e ++) {
while (h < t && level (s_cdq[q[h]], s_cdq[e]) >= level (s_cdq[q[h + 1]], s_cdq[e])) {
h ++;
}
if (level (s_cdq[q[h]], s_cdq[e]) + s_cdq[e].constant < dp[s_cdq[e].loc]) {
dp[s_cdq[e].loc] = level (s_cdq[q[h]], s_cdq[e]) + s_cdq[e].constant;
}
}
}
//将按照k排序的两个区间裂成位置在mid左和右的两部分, 因为要保证右区间按照k递增, 而左区间会先进行递归, 返回时归并排序就可以使其顺序正确, 所以用分裂的方法, 使未使用超过O(1)的数据结构时总时间为O(nlogn)
void split (int l, int r, int mid) {
int i = l, j = mid + 1;
for (int e = l; e <= r; e ++) {
if (s_cdq[e].loc <= mid) {
tem[i ++] = s_cdq[e];
}
else {
tem[j ++] = s_cdq[e];
}
}
for (int e = l; e <= r; e ++) {
s_cdq[e] = tem[e];
}
}
//cdq中的归并, 无需过多解释
void merge (int l, int mid, int r) {
int i = l, j = mid + 1;
int k = l;
while (i <= mid && j <= r) {
if (s_cdq[i].x < s_cdq[j].x || (s_cdq[i].x == s_cdq[j].x && s_cdq[i].y < s_cdq[j].y)) {
tem[k ++] = s_cdq[i ++];
}
else {
tem[k ++] = s_cdq[j ++];
}
}
while (i <= mid) {
tem[k ++] = s_cdq[i ++];
}
while (j <= r) {
tem[k ++] = s_cdq[j ++];
}
for (int e = l; e <= r; e ++) {
s_cdq[e] = tem[e];
}
}
//整体拼接起来
void cdq (int l, int r) {
if (l == r) {
...//这里需要计算x, y 中与dp有关的项的值,此时这两项并未参与到转移, 而dp值由于不会从右往左转移, 所以此时也是求出准确答案的了
return ;
}
int mid = (l + r) >> 1;
//按照顺序执行操作
split (l, r, mid);//先归并左边的区间, 因为右边的区间并不能向左转移, 所以先向下递归, 进行区间内的转移
cdq (l, mid);//由于左区间会影响右区间, 所以转移左区间到右区间后再进行右区间内的转移
transfer (l, mid, r);
cdq (mid + 1, r);
merge (l, mid, r);
}
int main () {
...//输入等一系列初始化。
dp[...] = ...;//将初始位置附上初值
for (int i = /*1或0, 视情况而定*/; i <= n; i ++) {
//将与dp无关的量先求出
s_cdq[i].loc = i;
s_cdq[i].constant = ...;
s_cdq[i].k = ...;
s_cdq[i].x = ...;
s_cdq[i].y = ...;
}
sort (s_cdq + 1, s_cdq + 1 + n, cmp);//按照k排序后在进行cdq
cdq (1, n);
...//之后对dp值的处理与输出
}