题解 CF631E Product Sum
题意
给定一个长度为 $n$ 的序列 $a$,定义该序列的权值为 $\sum\limits_{i=1}^n i \cdot a_i$,现有一次将序列中某个数移动到任意位置的机会,可以移回原位,求操作后序列 $a$ 最大权值。
数据范围:$2 \leq n \leq 2\times10^{5},|a_i|\le10^6$
题解
考虑先算出未操作时的权值,再去算操作后能使答案增加的最大值。
设 $f_i$ 表示移动 $a_i$ 对答案的最大贡献值,若 $a_i$ 移到了位置 $j$,我们分类讨论:
- 当 $i< j$ 时,$f_i=\max\{a_i\times(j-i)-\sum\limits_{k=i+1}^ja_k\}$
- 当 $i\ge j$ 时,$f_i=\max\{\sum\limits_{k=j}^{i-1}a_k-a_i\times(i-j)\}$
后面那个求和可以用前缀和弄掉,设 $sum_i=\sum\limits_{j=1}^ia_j$,则:
$$f_i=\begin{cases}\max\{a_i\times(j-i)-sum_j+sum_i\}&j\in\left(i,n\right]\\\max\{a_i\times(j-i)-sum_{j-1}+sum_{i-1}\}& j\in\left[1,i\right]\end{cases}$$
于是我们得到一个 $O(n^2)$ 的做法。
其实化成这样也可以做,但比较麻烦。
观察到上下两个形式很相近,考虑当 $j\in\left[1,i\right]$,令 $j\gets j-1$ 转化成和上面一样的形式。
$$\begin{aligned}f_i&=\max\{a_i\times(j-i)-sum_{j-1}+sum_{i-1}\}\ \ \ \ \ \ \ j\in\left[1,i\right]\\&=\max\{a_i\times(j+1-i)-sum_{j}+sum_{i-1}\}\ \ \ \ \ \ \ j\in\left[0,i\right)\\&=\max\{a_i\times(j+1-i)-sum_{j}+sum_{i}-a_i\}\ \ \ \ \ \ \ j\in\left[0,i\right)\\&=\max\{a_i\times(j-i)-sum_{j}+sum_{i}\}\ \ \ \ \ \ \ j\in\left[0,i\right)\end{aligned}$$
发现 $j\in\left[0,i\right)$ 和 $j\in\left(i,n\right]$ 时的计算方式一样,我们便大大简化了过程。
于是 $f_i=\max\{a_i\times(j-i)-sum_j+sum_i\}$。
拆开后 $f_i=\max\{a_i\times j-sum_j\}+sum_i-a_i\times i$。
这个形式显然可以斜率优化,若 $j_1<j_2$ 且 $i$ 移到 $j_2$ 时比移到 $j_1$ 时贡献要大,有:$$a_i\times j_1-sum_{j_1}<a_i\times{j_2}-sum_{j_2}$$$$\therefore a_i\times (j_1-j_2)<sum_{j_1}-sum_{j_2}$$$$\therefore a_i>\dfrac{sum_{j_1}-sum_{j_2}}{j_1-j_2}$$
维护 $(j,sum_{j})$ 这些点组成的下凸壳即可,因为点坐标与 $f_j$ 无关,可以先把这些点加进来,注意 $a_i$ 不一定是单增的,可以有两种处理方式。
-
二分下凸壳,找到比 $a_i$ 小的最大斜率。
-
因为答案与计算顺序无关,先将 $a$ 排序后再上单调队列。
这两种我都写了,二分复杂度 $O(n\log n)$,排序后单调队列可以用基排做到 $O(n)$,不过我懒得写了,貌似单调队列的做法会快(?
只贴单调队列的代码了。
#include <bits/stdc++.h>
#define ll long long
#define ld long double
using namespace std;
const int N = 1e6 + 10;
int n;
int q[N], head, tail;
ll sum[N], ans, base;
struct node {
int v, id;
friend bool operator < (const node &qwq, const node &awa) {
return qwq.v < awa.v;
}
} a[N];
ld slope(int i, int j) {
return (1. * sum[i] - sum[j]) / (i - j);
}
int main() {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i].v);
sum[i] = sum[i - 1] + a[i].v, base += 1ll * i * a[i].v;
a[i].id = i;
}
head = tail = 1;
for(int i = 1; i <= n; i++) {
while(tail > 1 && slope(q[tail - 1], i) < slope(q[tail - 1], q[tail])) tail--;
q[++tail] = i;
}
sort(a + 1, a + 1 + n);
for(int i = 1; i <= n; i++) {
while(head < tail && slope(q[head], q[head + 1]) < a[i].v) head++;
ans = max(ans, 1ll * (q[head] - a[i].id) * a[i].v + sum[a[i].id] - sum[q[head]]);
}
printf("%lld\n", ans + base);
}

浙公网安备 33010602011771号