Codeforces Round #783 D. Optimal Partition
题目大意
给定一个长度为 \(n(1\leq n\leq5\times10^5)\) 的数列,每个数字 \(a_i(-10^9\leq a_i\leq10^9)\)。可以将数列划分为若干非空的连续子段 \([l,r]\) ,设 \(s\) 为 \(\sum_{i=l}^r a_i\),则该子段的价值为 \((r-l+1)sgn(s)\),求能够划分出的字段价值的和的最大值。
思路
考虑 \(dp\) ,设 \(f_i\) 为考虑前 \(i\) 个数是的最大价值,前缀和为 \(S\) ,初值设 \(f_0=0\) ,很容易得到转移方程
\[f_i=max_{0\leq j<i}\{f_j+sgn(S_i-S_j)(i-j)\}
\]
我们可以把与 \(j\) 有关的项都提出来,即
\[f_i=max\{max_{0\leq j<i,S_i>S_j}\{(f_j-j)+i\},max_{0\leq j<i,S_i<S_j}\{(f_j+j)-i\},max_{0\leq j<i,S_i=S_j}\{f_j\}\}
\]
于是我们可以建立两个维护最大值的树状数组,首先对前缀和的值域进行离散化,分别维护\(S_i>S_j\) 时的 \(f_j-j\) 和 \(S_i<S_j\) 时的 \(f_j+j\) ,然后每次转移的时候分别查询取最大值,之后再处理一下\(S_i=S_j\) 的情况,直接用一个数组维护一下每个 \(S_i\) 下最大的 \(f\) 即可。最后 \(f_n\) 即为答案,复杂度 \(O(nlogn)\)。
代码
#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using PII = pair<int, int>;
using TP = tuple<int, int, int>;
#define all(x) x.begin(),x.end()
#define pb push_back
//#define int LL
//#define lc p*2
//#define rc p*2+1
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
//#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-8;
const LL MOD = 1000000007;
const LL mod = 998244353;
const int maxn = 500010;
LL T, N, A[maxn], S[maxn], B[maxn], f[maxn], mx[maxn];
LL n, dat[2][maxn];
void init(int x)
{
n = x;
for (int i = 0; i <= x; i++)
dat[0][i] = dat[1][i] = -INF;
for (int i = 0; i <= N + 1; i++)
f[i] = mx[i] = -INF;
}
void modify(LL i, LL x, LL t)
{
while (i <= n)
{
dat[t][i] = max(dat[t][i], x);
i += i & (-i);
}
}
LL query(LL i, LL t)
{
LL ans = -INF;
while (i)
{
ans = max(ans, dat[t][i]);
i -= i & (-i);
}
return ans;
}
int compress(LL* ar)
{
vector<LL>xs;
for (int i = 0; i <= N; i++)
xs.pb(ar[i]);
sort(all(xs));
xs.erase(unique(all(xs)), xs.end());
for (int i = 0; i <= N; i++)
B[i] = upper_bound(all(xs), ar[i]) - xs.begin();
return xs.size();
}
void solve()
{
for (int i = 1; i <= N; i++)
S[i] = S[i - 1] + A[i];
int M = compress(S);
init(M);
for (int i = 0; i <= N; i++)
{
if (i == 0)
f[i] = 0;
else
f[i] = max(mx[B[i]], max(query(B[i] - 1, 0) + i, query(M - B[i], 1) - i));
modify(B[i], f[i] - i, 0), modify(M - B[i] + 1, f[i] + i, 1);
mx[B[i]] = max(f[i], mx[B[i]]);
}
cout << f[N] << endl;
}
int main()
{
IOS;
cin >> T;
while (T--)
{
cin >> N;
for (int i = 1; i <= N; i++)
cin >> A[i];
solve();
}
return 0;
}