CF1787I Treasure Hunt 题解
考虑题目中的这个限制是有些困难的,但是细想,其实这个限制没有意义,要求的就是每个区间的最大可空前缀和加上最大可空子段和的和。
首先我们证明这东西是对的:
考虑反证。假如选的前缀是 ,子段是 。假如这一选择是最优的,且 。考虑区间 的和。如果这个和非负,那我可以令 ,那么值必定不劣并且满足要求。如果这一段和是负数,显然令 。于是这样的区间不可能取到最优值。
把这个事情分两部分考虑。第一个是每个区间的最大前缀和。考虑维护前缀和,那么区间 的贡献即为 。扫描线,从 扫描 ,并维护关于 的单调栈。显然我们可以二分和 ST 表求出最小的 使得存在前缀和非负,然后在单调栈上二分这个位置即可。复杂度 。
考虑第二部分,即每个区间的最大子段和。这东西看着有点困难,我们考虑一些可以合并的东西来维护。考虑分治,当前区间 ,。 和 的贡献递归处理。考虑求 的区间 的最大子段和的贡献。
考虑这个最大子段和是什么?令 表示,如果 , 的最大子段和,否则是 的最大子段和。这个容易用朴素的最大子段和算法求出。另一个是 ,如果 ,表示 最大后缀和,否则表示 最大前缀和。那么容易用 表示 最大子段和。类似线段树维护最大子段和的合并区间。
考虑扫描线, 扫描 。满足 为这三个最大值的 一定是一段前缀。这个很容易双指针线性维护或者二分多一个 维护。我们希望得到更优的复杂度,如果带一个 整体就是两只 了。不太优美。于是这里考虑双指针求出 为最大值的那段前缀。这个是容易的。
但是难点在于,后面两个的最大值,似乎没有什么联系。我们设 ,移项变成 。我们考虑 是什么?事实上,随着 增大, 单调不降。知道这个结论后显然可以再用一个指针维护这一维即可在 的总复杂度内求出答案。
现在我们考虑证明这个 的单调性。
还记得我们在初始时得到的那个结论吗?即最大前缀 和最大子段 不可能是 。那么 只有两种。一种是 ,另一种是 。
我们考虑现在在末尾添加一个数。先考虑第一种,。此时如果 ,即最大前缀和变化,那么必然是目前整体和大于之前的最大前缀和。即 。那么必然有 。那显然最大子段和也可以这样变化,因为 ,所以显然 增加量不小于 增加量。
第二部分是,如果 。前缀和还是只能通过整体和更新。还是有 。由于最大前缀在这次才被更新,那么 ,否则之前的最大前缀应该可以往后。同时也有 。证明是同理的。此时 。那么 变化量也不小于 变化量。
至此,我们得以证明 单调不降,用之前提到的两个双指针即可维护。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include <stack>
#include <map>
using namespace std;
#define int long long
const int N = 1e6 + 5, MOD = 998244353ll;
int t, n, a[N];
int ans = 0ll;
int LG2[N];
int s[N];
class ST
{
public:
int f[N][21];
void Init(int n, int* s)
{
for (int i = 1; i <= n; i++) f[i][0] = s[i];
for (int j = 1; j <= LG2[n]; j++)
{
for (int i = 1; i + (1ll << j) - 1ll <= n; i++) f[i][j] = max(f[i][j - 1], f[i + (1ll << (j - 1))][j - 1]);
}
}
inline int query(int l, int r) const
{
int p = LG2[r - l + 1];
return max(f[l][p], f[r - (1ll << p) + 1ll][p]);
}
}st;
int f1[N], f2[N]; // 最大子段和,最大前缀/后缀和
int sf1[N], sf2[N];
void solve(int l, int r)
{
if (l == r)
{
ans += max(0ll, a[r]);
return;
}
int mid = l + r >> 1;
solve(l, mid);
solve(mid + 1, r);
for (int i = l - 1; i <= r + 1; i++) f1[i] = f2[i] = (int)-4e18, sf1[i] = sf2[i] = 0ll;
// solve f2
f2[mid] = max(0ll, a[mid]);
for (int i = mid - 1; i >= l; i--) f2[i] = max(f2[i + 1], s[mid] - s[i - 1]);
f2[mid + 1] = max(0ll, a[mid + 1]);
for (int i = mid + 2; i <= r; i++) f2[i] = max(f2[i - 1], s[i] - s[mid]);
// solve f1
f1[mid] = a[mid];
for (int i = mid - 1; i >= l; i--)
{
f1[i] = a[i] + max(0ll, f1[i + 1]);
}
f1[mid] = max(0ll, f1[mid]);
for (int i = mid - 1; i >= l; i--)
{
f1[i] = max({ f1[i], f1[i + 1], 0ll });
}
f1[mid + 1] = a[mid + 1];
for (int i = mid + 2; i <= r; i++) f1[i] = a[i] + max(0ll, f1[i - 1]);
f1[mid + 1] = max(0ll, f1[mid + 1]);
for (int i = mid + 2; i <= r; i++) f1[i] = max({ f1[i], f1[i - 1], 0ll });
// conquer
for (int i = l; i <= r; i++) sf1[i] = sf1[i - 1] + f1[i], sf2[i] = sf2[i - 1] + f2[i];
int j1 = mid + 1, j2 = mid + 1;
//cerr << "begin conquer: " << l << " " << mid << " " << r << "\n";
//if (l == 1 && r == 4)
//{
// cerr << "debug: \n";
// for (int i = 1; i <= 4; i++) cerr << f1[i] << " " << f2[i] << "\n";
// cerr << "end debug.\n\n";
//}
for (int i = mid; i >= l; i--)
{
while (j1 <= r && max(f1[j1], f2[j1] + f2[i]) <= f1[i]) j1++;
// [mid+1,j1)
int sumval = 0ll;
int len = max(0ll, j1 - mid - 1);
sumval += len * f1[i] % MOD;
sumval %= MOD;
while (j2 <= r && f1[j2] - f2[j2] < f2[i]) j2++;
// [j1,j2): f2_{l]+f2_{r}
len = max(0ll, j2 - j1);
if (j1 < j2)
{
sumval += len * f2[i] + sf2[j2 - 1] - sf2[j1 - 1];
// [j2,r]: fr
if (j2 <= r) sumval += sf1[r] - sf1[j2 - 1];
}
else
{
// [j1,r]: f1_r
sumval += sf1[r] - sf1[j1 - 1];
}
ans += sumval;
ans %= MOD;
//cerr << "!!!: " << i << " " << j1 << " " << j2 << " " << sumval << "\n";
}
//cerr << "end conquer.\n\n";
//cout << "end.\n\n";
}
signed main()
{
ios::sync_with_stdio(0), cin.tie(0);
for (int i = 2; i < N; i++) LG2[i] = LG2[i >> 1] + 1;
cin >> t;
for (int tc = 1; tc <= t; tc++)
{
ans = 0ll;
cin >> n;
for (int i = 1; i <= n; i++)
{
cin >> a[i], s[i] = s[i - 1] + a[i];
}
st.Init(n, s);
vector<pair<int, int>> v;
for (int i = n; i >= 1; i--)
{
while (v.size() && s[i] > s[v.back().first])
{
v.pop_back();
}
int cnt = (v.size() ? v.back().first - i : n - i + 1);
int val = (v.size() ? v.back().second : 0ll) + cnt * s[i];
v.emplace_back(make_pair(i, val));
int l = i, r = n, pos = -1;
while (l <= r)
{
int mid = l + r >> 1ll;
if (st.query(i, mid) > s[i - 1]) pos = mid, r = mid - 1;
else l = mid + 1;
}
if (~pos)
{
if (!v.size() || v[0].first <= pos)
{
ans += (st.query(i, pos) - s[i - 1]) * (n - pos + 1);
}
else
{
int l = 0, r = v.size() - 1, pos2 = -1;
while (l <= r)
{
int mid = l + r >> 1;
if (v[mid].first >= pos) pos2 = mid, l = mid + 1;
else r = mid - 1;
}
ans += st.query(i, pos) * ((v[pos2]).first - pos) + (v[pos2]).second - (s[i - 1] * (n - pos + 1));
ans %= MOD;
}
}
}
solve(1, n);
// brute force:
/*
for (int i = 1; i <= n; i++)
{
for (int j = 0; j <= n; j++) f1[j] = 0;
int mxn = 0ll;
for (int j = i; j <= n; j++)
{
f1[j] = max(0ll, f1[j - 1]) + a[j];
mxn = max(mxn, f1[j]);
ans += mxn;
}
}*/
cout << ans % MOD << "\n";
}
return 0;
}
/*
1
4
2 -5 -1 3
*/

浙公网安备 33010602011771号