【四边形不等式优化DP】luogu_P4072 [SDOI2016]征途
题意
给出\(N\)个数,将它们划分成\(M\)段,使每段之和的方差最小,求出这个方差\(*m^2\)
思路
化简一下答案式子,即\(m*\sum_{i=1}^{m}s_i^2-sum^2\),所以要求的就是每段之和的平方和最小。
设\(f_{i,j}\)为划分了\(i\)段用了\(j\)个数的最小平方和。
\(f_{i,j}=f_{i-1,k}+(s_j-s_k)^2\),发现右边的这个平方是满足四边形不等式的,所以可以维护决策。
具体地,维护当前和下一次转移的决策。
代码
#include <cstdio>
#include <algorithm>
struct node {
int l, r, p;
}s[2][50001];
int n, m, now;
int a[3001], top[2];
long long f[2][3001];
long long val(int p, int j, int i) {
return f[p][j] + (long long)(a[i] - a[j]) * (a[i] - a[j]);
}
int find(int p) {
int l = s[now ^ 1][top[now ^ 1]].l, r = n;
while (l < r) {
int mid = l + r >> 1;
if (val(now, p, mid) > val(now, s[now ^ 1][top[now ^ 1]].p, mid))
l = mid + 1;
else
r = mid;
}
return l;
}
int find1(int p) {
int l = 1, r = top[now];
while (l < r) {
int mid = l + r + 1 >> 1;
if (s[now][mid].l <= p)
l = mid;
else
r = mid - 1;
}
return l;
}
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++)
scanf("%d", &a[i]), a[i] += a[i - 1];
s[now][top[now] = 1] = (node){1, n, 0};//current
for (int i = 1; i <= m; i++, now ^= 1) {
s[now ^ 1][top[now ^ 1] = 1] = (node){1, n, i};//next
for (int j = i; j <= n; j++) {
f[now][j] = val(now ^ 1, s[now][find1(j)].p, j);
while (top[now ^ 1] && j < s[now ^ 1][top[now ^ 1]].l && val(now, j, s[now ^ 1][top[now ^ 1]].l) < val(now, s[now ^ 1][top[now ^ 1]].p, s[now ^ 1][top[now ^ 1]].l))
top[now ^ 1]--;
int u = find(j);
if (val(now, j, u) > val(now, s[now ^ 1][top[now ^ 1]].p, u)) continue;
s[now ^ 1][top[now ^ 1]].r = u - 1;
s[now ^ 1][++top[now ^ 1]] = (node){u, n, j};
}
}
printf("%lld", m * f[now ^ 1][n] - a[n] * a[n]);
}
再贴一个用分治法做的
#include <cstdio>
#include <algorithm>
int n, m;
int a[3001];
long long f[3001][3001];
void solve(int p, int l, int r, int L, int R) {
int mid = l + r >> 1, MID = 0;
f[p][mid] = 2147483647;
for (int i = L; i <= R && i < mid; i++) {
long long temp = f[p - 1][i] + (a[mid] - a[i]) * (a[mid] - a[i]);
if (temp < f[p][mid])
f[p][mid] = temp, MID = i;
}
if (l < mid)
solve(p, l, mid - 1, L, MID);
if (r > mid)
solve(p, mid + 1, r, MID, R);
}
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++)
scanf("%d", &a[i]), a[i] += a[i - 1], f[1][i] = a[i] * a[i];
for (int i = 2; i <= m; i++)
solve(i, 1, n, 1, n);
printf("%lld", m * f[m][n] - a[n] * a[n]);
}

浙公网安备 33010602011771号