题解:CF2183H Minimise Cost

感觉其实做法都能想到但是证明有点有点啊,在场上感觉不太好证这个东西。顺便说一句,一开始看题解用 \(g\) 去证明凸性看快了以为那个证明是完全对的,然后写了一个过了,具体到做法部分会说是啥,但是后面写题解的时候发现这个证明貌似很错,但是拍了随机数据很对,不会造数据 hack,求求佬解释一下或者 hack。

题意:给出一个序列 \(a\),要求原序列划分成 \(k\) 个子序列,定义一个长为 \(n\) 的序列 \(s\) 的权值为 \(n\sum s\),要最小化所有子序列的代价和。\(n\le 2\times 10^5,|a_i| \le 10^9\)

做法:

首先肯定是把序列排序后就是选连续的区间,然后就变成一个序列划分问题。有转移式:

\[dp_{i} = \min_{j=0}^{i-1}(s_i-s_j)(i-j)+dp_j \]

感觉一下这个权值函数很凸,所以答案也是凸的,直接上 wqs + 二分队列做到 \(O(n\log n\log V)\) 就做完了,吗?

让我们来细推一下这个四边形不等式合法条件:

\[w(l - 1,r) + w(l,r+1) \le w(l - 1,r+1)+w(l,r) \]

\[0\le a_{l-1}+a_{r+1} \]

然后让我们看看数据范围,欸,竟然 \(a\) 有负数!那我不是炸了,然后就可以火大地遗憾离场了。

但是这种问题感觉一下如果没有单调性完全做不了啊,并且这个东西对于正数是很对的,所以感觉一下这个凸包应该几乎是凸的只是有些例外,比如对于负数怎么特殊处理一下就是凸的了。

我们先不考虑单调性的事情,单独考虑这个问题,那么我们发现:

  • 对于正数要尽量多划段。

  • 对于负数尽量少划段。

欸,这两个东西是可以同时存在的,也就是说,我们应该先把负数划成一段,然后对于正数内部划段。同时我们发现,如果一定要在负数内部划段,我肯定是优先让最大的若干个负数单独一段,剩余的一大段。

讨论一下,如果段数 $\le $ 非负数个数,那么我猜有凸性,否则可以直接计算答案。

考虑怎么证明凸性,掏出 oi-wiki,翻一下怎么证明凸性,发现有一个证明方法:

如果 \(g(k-1)+g(k+1)\ge 2g(k)\),这里 \(g\) 是划分成 \(k\) 个段的代价。正确性显然。

然后官方题解就直接把 oi-wiki 里这个东西证明拍下来,大概是交换一下区间讨论贡献,但是依赖于四边形不等式,告诉我这个东西是对的就有凸性了???也就是说,对于 \(1\le k \le\) 正数个数的部分有凸性。有凸性就直接 wqs 就可以优化掉选择 \(k\) 的那一部分变成 \(\log\)

我一开始觉得这个东西是对的,但是后面我想了想用这个证明去对于这样一组数据去做:

n = 5, k = 2
a = [-100, 2, 3, 4, 5]

然后发现对于 \(g(2)\) 并不能按照原本的证明方式证明,原因是因为还是用到了负数段的和这种东西,但是拍了一下这东西就是没锅,不是很懂为什么,好像可以很魔怔的认为是前半段和后半段都是凸的,但是有个点不一定是凸的。

说说这东西的做法,考虑从后往前划分段,暴力枚举最后一个段在哪里,只要完全包含负数段即可。

但是接下来这个做法就是完全对的了。我们发现负数段的问题还是没咋解决,我们考虑一个贪:如果对于正数最小值,我们把它划到负数这边贡献更优,那么我们就划到负数这边合并起来成一个大数。形式化地,这个优的意思是这样一个柿子。

定义 \(s_k\) 代表我目前负数段的大小,长度为 \(k\),这个数为 \(a\),那么要满足:

\[(s_k+a)(k+1)\le s_kk+a \]

这样称为优。

解释一下,左边是划入左边段的贡献,右边是因为,假设不划入,那么 \(a\) 至少要产生单独一段也就是 \(a\) 的贡献。

发现一个很神奇的事情,如果我不把这些负数段合并后中间的位置当成合法的决策点,就满足四边形不等式了!可以自己推一推,这个我推了确实是没毛病的。

那么我们只要挖掉这些决策点,并且对于剩下的没有被合并的点的个数 \(x\)\(k\le x\) 时决策单调性,否则直接计算,这样就对了,感觉这个方式还挺深刻的。

代码:

#include <bits/stdc++.h>
using namespace std;
#define int __int128
const int maxn = 2e5 + 5;
int read() {
	int sum = 0; char c = getchar(), f = 1;
	while(!isdigit(c)) {
		if(c == '-')
			f = -1;
		c = getchar();
	}
	while(isdigit(c))
		sum = sum * 10 + c - '0', c = getchar();
	return sum * f;
}
void write(int x) {
	if(x < 0) {
		putchar('-');
		x = -x;
	}
	if(x <= 9) {
		putchar('0' + x);
		return ;
	}
	write(x / 10);
	putchar(x % 10 + '0');
}
struct node {
	int val, cnt;
	friend node operator+(node x, node y) {
		return node {x.val + y.val, x.cnt + y.cnt};
	}
	friend bool operator<(node x, node y) {
		return (x.val != y.val ? x.val < y.val : x.cnt < y.cnt);
	}
} dp[maxn];
int n, a[maxn], s[maxn], cnt, k;
struct Seg {
	int l, r, p;
} q[maxn];
int fa, tl;
int cal(int l, int r) {
	return (s[r] - s[l]) * (r - l);
}
node ans;
bool chk(int x) {
	dp[0] = {0, 0}; 
	fa = tl = 1;
	q[tl++] = Seg{1, cnt, 0};
	for (int i = 1; i <= cnt; i++) {
		while(q[fa].r < i)
			fa++;
		int j = q[fa].p;
		dp[i] = dp[j] + node{cal(j, i) + x, 1};
	//	cout << j << " " << i << " " << cnt << " " << dp[i].val << endl;
		while(fa < tl) {
	//		if(n == 3)
	//			cout << cnt << endl;
			Seg t = q[tl - 1];
			if(node{cal(t.p, t.r), 1} + dp[t.p] < node{cal(i, t.r), 1} + dp[i]) {
				q[tl++] = Seg{t.r + 1, cnt, i};
				break;
			}
			else if(node{cal(i, t.l), 1} + dp[i] < node{cal(t.p, t.l), 1} + dp[t.p])
				tl--;
			else {
				int l = t.l - 1, r = t.r;
				while(l + 1 < r) {
					int mid = l + r >> 1;
					if(node{cal(i, mid), 1} + dp[i] < node{cal(t.p, mid), 1} + dp[t.p])
						r = mid;
					else
						l = mid;
				}
				q[tl - 1].r = l; q[tl++] = Seg{r, cnt, i};
				break;
			}
		}
		if(fa == tl)
			q[tl++] = Seg{i + 1, cnt, i};
	}
//	cout << cnt << endl;
	ans.val = 1e36;
	for (int i = 0; i <= cnt; i++)
		ans = min(ans, dp[i] + node{(s[n] - s[i]) * (n - i) + x, 1});
	return ans.cnt <= k;
}
void solve() {
	n = read(), k = read(); cnt = 0;
	for (int i = 1; i <= n; i++)
		a[i] = read(), cnt += (a[i] >= 0);
	//cout << "adsfsalkj" << cnt << endl;
	sort(a + 1, a + n + 1); reverse(a + 1, a + n + 1);
	for (int i = 1; i <= n; i++)
		s[i] = s[i - 1] + a[i];
	if(k <= cnt + 1) {
		int l = -1e20, r = 1e20;
	//	chk(2);
		while(l + 1 < r) {
			int mid = l + r >> 1;
			if(chk(mid))
				r = mid;
			else
				l = mid;
		}
		chk(r);
		//cout << ans.val << " laskgjsagd" << ans.cnt << endl;
		write(ans.val - k * r); putchar('\n');
	}
	else {
		int ans = 0;
		for (int i = 1; i <= k - 1; i++)
			ans += a[i];
		ans += cal(k - 1, n);
		write(ans); putchar('\n');
	}
}
signed main() {
	int T; T = read();
	while(T--)
		solve();
	return 0;
}
posted @ 2026-01-12 21:55  LUlululu1616  阅读(1)  评论(0)    收藏  举报