【四边形不等式优化DP】luogu_P4072 [SDOI2016]征途

题意

给出\(N\)个数,将它们划分成\(M\)段,使每段之和的方差最小,求出这个方差\(*m^2\)

思路

化简一下答案式子,即\(m*\sum_{i=1}^{m}s_i^2-sum^2\),所以要求的就是每段之和的平方和最小。
\(f_{i,j}\)为划分了\(i\)段用了\(j\)个数的最小平方和。
\(f_{i,j}=f_{i-1,k}+(s_j-s_k)^2\),发现右边的这个平方是满足四边形不等式的,所以可以维护决策。
具体地,维护当前和下一次转移的决策。

代码

#include <cstdio>
#include <algorithm>

struct node {
	int l, r, p;
}s[2][50001];
int n, m, now;
int a[3001], top[2];
long long f[2][3001];

long long val(int p, int j, int i) {
	return f[p][j] + (long long)(a[i] - a[j]) * (a[i] - a[j]);
}

int find(int p) {
	int l = s[now ^ 1][top[now ^ 1]].l, r = n;
	while (l < r) {
		int mid = l + r >> 1;
		if (val(now, p, mid) > val(now, s[now ^ 1][top[now ^ 1]].p, mid))
			l = mid + 1;
		else
			r = mid;
	}
	return l;
}

int find1(int p) {
	int l = 1, r = top[now];
	while (l < r) {
		int mid = l + r + 1 >> 1;
		if (s[now][mid].l <= p)
			l = mid;
		else
			r = mid - 1;
	}
	return l;
}

int main() {
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; i++)
		scanf("%d", &a[i]), a[i] += a[i - 1];
	s[now][top[now] = 1] = (node){1, n, 0};//current
	for (int i = 1; i <= m; i++, now ^= 1) {
		s[now ^ 1][top[now ^ 1] = 1] = (node){1, n, i};//next
		for (int j = i; j <= n; j++) {
			f[now][j] = val(now ^ 1, s[now][find1(j)].p, j);
			while (top[now ^ 1] && j < s[now ^ 1][top[now ^ 1]].l && val(now, j, s[now ^ 1][top[now ^ 1]].l) < val(now, s[now ^ 1][top[now ^ 1]].p, s[now ^ 1][top[now ^ 1]].l))
				top[now ^ 1]--;
			int u = find(j);
			if (val(now, j, u) > val(now, s[now ^ 1][top[now ^ 1]].p, u)) continue;
			s[now ^ 1][top[now ^ 1]].r = u - 1;
	        s[now ^ 1][++top[now ^ 1]] = (node){u, n, j};
		}
	}
	printf("%lld", m * f[now ^ 1][n] - a[n] * a[n]);
}

再贴一个用分治法做的

#include <cstdio>
#include <algorithm>

int n, m;
int a[3001];
long long f[3001][3001];

void solve(int p, int l, int r, int L, int R) {
	int mid = l + r >> 1, MID = 0;
	f[p][mid] = 2147483647;
	for (int i = L; i <= R && i < mid; i++) {
		long long temp = f[p - 1][i] + (a[mid] - a[i]) * (a[mid] - a[i]);
		if (temp < f[p][mid])
			f[p][mid] = temp, MID = i;
	}
	if (l < mid)
		solve(p, l, mid - 1, L, MID);
	if (r > mid)
		solve(p, mid + 1, r, MID, R);
}

int main() {
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; i++)
		scanf("%d", &a[i]), a[i] += a[i - 1], f[1][i] = a[i] * a[i];
	for (int i = 2; i <= m; i++)
		solve(i, 1, n, 1, n);
	printf("%lld", m * f[m][n] - a[n] * a[n]);
}
posted @ 2021-04-03 11:23  nymph181  阅读(85)  评论(0)    收藏  举报