【整体二分】数列切割

\(1\le n\le 10^5,1\le m\le 20\)

首先列出 dp 方程:\(f_{i,j}\) 表示将前 \(j\) 个数分为 \(i\) 段最小代价。

转移 \(f_{i,j}=\min\{f_{i-1,k}+\text{cost}(k+1,j)\}\)

容易发现,当 \(i\) 固定时,\(f_{i,j}\) 具有决策单调性。即 \(i\) 固定时,如果对于 \(j_1\) 最优决策点在 \(k_1\),对于 \(j_2\) 最优决策点在 \(k_2\),若 \(j_1<j_2\),则 \(k_1\le k_2\)

因此可以整体二分。

函数 \(solve(l,r,x,y)\) 表示当前正在计算 \([l,r]\) 区间的 \(f\) 值,它们的最优决策点一定在 \([x,y]\) 之间。(注意此时 \(i\) 已经固定),暴力计算 \(f_{i,mid}\) 的值,然后继续递归。

实现上有一个小(da)问题:如何快速计算区间 \(\text{cost}\)

这玩意儿显然没法 \(O(1)\) 算,而时间和nkoj的速度也不允许 \(O(\log n)\) 计算。

但是注意到整体二分中查询的 \(\text{cost}\) 很有规律,于是可以使用一些鬼畜和巨难写的优化把计算 \(\text{cost}\) 的时间降到 \(O(1)\)

不会吧不会吧不会真的有人去写这么复杂的优化吧

因为整体二分中查询的 \(\text{cost}\) 很有规律,所以我们可以直接像莫队一样暴力移动左右端点啊!

复杂度为什么是 \(O(1)\) 感性理解就行了,严格证明证完了估计人可以送去急救了

这样暴力不仅巨好写还跑得更快,总时间复杂度 \(O(nm\log n)\)

#include <cstdio>
#include <algorithm>
#include <cstring>
#define int long long
#define gc (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 65536, stdin), p1 == p2) ? EOF : *p1 ++)

inline int min(const int x, const int y) {return x < y ? x : y;}
inline int max(const int x, const int y) {return x > y ? x : y;}
char buf[65536], *p1, *p2;
inline int read() {
	int x = 0;
	char ch;
	while ((ch = gc) < 48);
	do x = x * 10 + ch - 48; while ((ch = gc) >= 48);
	return x;
}

int a[100005], dp[21][100005], cnt[100005], now, L = 1, R = 0;

inline void add(int x) {now += (cnt[x] ++);}
inline void del(int x) {now -= (-- cnt[x]);}
inline int getval(int l, int r) {
	while (L > l) add(a[-- L]);
	while (R < r) add(a[++ R]);
	while (L < l) del(a[L ++]);
	while (R > r) del(a[R --]);
	return now;
}

void solve(int p, int l, int r, int x, int y) {
	if (l > r || x > y) return;
	int mid = l + r >> 1, optimal = 0;
	for (int i = min(mid - 1, y); i >= x; -- i) {
		int tmp = getval(i + 1, mid);
		if (!optimal || dp[p - 1][i] + tmp < dp[p][mid]) dp[p][mid] = dp[p - 1][optimal = i] + tmp;
	}
	solve(p, l, mid - 1, x, optimal);
	solve(p, mid + 1, r, optimal, y);
}

signed main() {
	int n = read(), m = read();
	for (int i = 1; i <= n; ++ i) add(a[i] = read()), dp[1][i] = now;
	for (int i = 1; i <= n; ++ i) -- cnt[a[i]];
	now = 0;
	for (int i = 2; i <= m; ++ i) solve(i, 1, n, 1, n);
	printf("%lld", dp[m][n]);
	return 0;
}
posted @ 2022-03-06 16:31  zqs2020  阅读(44)  评论(0编辑  收藏  举报