CHT 另解

引入

CHT 又叫 凸包优化,是一种利用一次函数(斜率)来优化 Dp 的一种方法。
它的独特之处在于,传统斜率优化依靠的是一个一个的点,而凸包优化是利用一条条直线来优化,省去了一些码量。
我们用一道例题引入。

例1 HDU-3480

Dp 暴力

Link
题目是说,将 \(n\) 个数划分到 \(m\) 个集合中,使得 \(Cost\) 最小。

\[Cost = \sum_{i=1}^{m}(\max_{x\in s_i} x - \min_{x\in s_i} x)^2 \]

显然应该给输入的 \(n\) 个数排序。
暴力 Dp,令 \(f_{i,j}\) 表示将 \([1,j]\) 的数放进 \(i\) 个集合里面的最小代价。
转移显然,\(f_{i,j}=\min (f_{i-1,k-1}+(a_j-a_k)^2)\) ,时间复杂度 \(O(n^2m)\)
我们来把式子拆开,看看能不能发现什么?

\[(a_j-a_k)^2+f_{i-1,k-1}=a_j^2+a_i^2-2\times a_j \times a_i+f_{i-1,k-1} \]

将关于 \(j\) 的项放在一起,得到:

\[(a_i^2)+[(-2 \times a_j) \times a_i+(f_{i-1,k-1}+a_j^2)] \]

\(k_j=-2\times a_j\)\(b_j=f_{i-1,k-1}+a_j^2\),则:

\[f_{i,j}=a_i^2+\min(k_j \times a_i+b_j) \]

后面的那一坨式子就是一个一次函数的解析式,这里可以用李超线段树,或者使用今天的凸包优化。

维护凸包

我们可以沿用斜率优化的思想,将他们维护成一个上凸包,直线斜率单调递减,维护出来如下图所示。
image
那么在查询答案和插入时又怎么办呢?

队头维护

我们将目前的直线压入队列,因为 \(a_i\) 单调递增,所以查询值从队头开始,反之,插入从队尾插入。
在什么情况下,对头的直线或队尾直线被弹出?继续往下看。
image
\(x\) (也就是 \(a_i\))位于绿色直线的时候,红色直线提供最大值。image
但是当 \(x\) (也就是 \(a_i\)) 逐渐变大时,红色在某一时刻不提供最小值了,于是就可以弹出红色直线。

队尾维护

image
当出现这样的情况(紫色为新插入直线)时,蓝色直线可以弹出,因为紫色的函数在与红色的交点以右的地方都比蓝色小,蓝色不可能提供最小值了。所以应当弹出。
image
反之,三条直线都有可能提供最小值,故保留。
话不多说,看代码:

#include <bits/stdc++.h>
#define FASTIO ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
using ll = long long;
using pii = pair<int, int>;
const int N = 1e4 + 5;
const double eps = 1e-8;
struct line {
	ll k, b;
	ll f(ll x) {
		return k * x + b;
	}
	double X(const line &o) {
		return 1.0 * (b - o.b) / (o.k - k);
	}
};
deque<line> st;
ll dp[2][N], sum[N], a[N];
int n, m;
void solve(int ncnt) {
	cin >> n >> m;
	for (int i = 1; i <= n; ++i) {
		cin >> a[i]; 
	}
	sort(a + 1, a + n + 1);
	n = unique(a + 1, a + n + 1) - a - 1;
	for (int i = 1; i <= n; ++i) 
		dp[0][i] = (a[i] - a[1]) * (a[i] - a[1]);
	for (int i = 2; i <= m; ++i) {
		st.clear(); 
		st.push_back({0, 0});
		for (int j = 1; j <= n; ++j) {
			while (st.size() >= 2 && st[0].f(a[j]) > st[1].f(a[j])) st.pop_front();
			dp[1][j] = a[j] * a[j] + st.front().f(a[j]);
			line cur = {-2 * a[j + 1], a[j + 1] * a[j + 1] + dp[0][j]};
			while (st.size() >= 2 && st.back().X(st[st.size() - 2]) - st[st.size() - 2].X(cur) > eps) st.pop_back();
			st.push_back(cur);
		}
		for (int j = 1; j <= n; ++j) {
			dp[0][j] = dp[1][j];
		}
	}
	cout << "Case " << ncnt << ": " << dp[1][n] << '\n';
}
int main() {
	FASTIO;
	int t; cin >> t;
	int ncnt = 0;
	while (t --) 
		solve(++ncnt);
	return 0;
}
posted @ 2025-03-16 08:31  tanjiaqi  阅读(143)  评论(0)    收藏  举报