Maximum Subarray
提供一个复杂度与 无关的算法。
容易发现的是,当 时,我们令 , ,与原题等价。
所以下文我们默认 。
我们先考虑最终选择所有长度 的子段,那么这个子段每一个数都可以加上 ,且应该加上 ,因为 。
所以这一部分我们相当于要求 。
容易发现我们只需要枚举所有长度 的区间,求一个最大子段和取 即可,可以线段树维护。
接着第二部分,长度 ,容易发现,对于一个区间,若长度 ,那么具体选哪些点 是没有意义的,只需要关心 的数有 个。
于是我们枚举所有长度为 的区间,看成这个区间里每个数都加 ,哪么两边每个数 的贡献是 ,只需要处理从这个区间左端点能往前走最多的和右端点往右走最多的,容易发现可以线性处理。
所以总体复杂度 。
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 2e5 + 5;
int t, n, k, x;
int a[N];
class SegmentTree
{
public:
struct Node
{
int l, r;
int sum, lmax, rmax, tmax;
Node()
{
l = r = sum = 0;
lmax = rmax = tmax = -1e11;
}
friend Node operator+(const Node& x, const Node& y)
{
Node p;
p.sum = x.sum + y.sum;
p.lmax = max(x.lmax, x.sum + y.lmax);
p.rmax = max(y.rmax, y.sum + x.rmax);
p.tmax = max({ x.tmax, y.tmax, x.rmax + y.lmax });
return p;
}
};
Node tree[N << 2];
void CLEAR(int n)
{
for (int i = 0; i <= 4 * n; i++)
{
tree[i].l = tree[i].r = 0;
tree[i].sum = 0;
tree[i].lmax = tree[i].rmax = tree[i].tmax = -1e11;
}
}
inline void push_up(int u)
{
int tl = tree[u].l, tr = tree[u].r;
tree[u] = tree[u << 1] + tree[u << 1 | 1];
tree[u].l = tl, tree[u].r = tr;
}
inline void build(int u, int l, int r, int *a)
{
tree[u].l = l, tree[u].r = r;
if (l == r)
{
tree[u].lmax = tree[u].rmax = tree[u].sum = tree[u].tmax = a[r];
}
else
{
int mid = (l + r) >> 1;
build(u << 1, l, mid, a);
build(u << 1 | 1, mid + 1, r, a);
push_up(u);
}
}
inline void modify(int u, int x, int v)
{
if (tree[u].l == x && tree[u].r == x)
{
tree[u].lmax = tree[u].rmax = tree[u].sum = tree[u].tmax = v;
}
else
{
int mid = (tree[u].l + tree[u].r) >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
push_up(u);
}
}
Node query(int u, int l, int r)
{
if (tree[u].l >= l && tree[u].r <= r) return tree[u];
int mid = (tree[u].l + tree[u].r) >> 1;
Node x, y;
if (l <= mid) x = query(u << 1, l, r);
if (r > mid) y = query(u << 1 | 1, l, r);
return x + y;
}
}sgt;
signed main()
{
scanf("%lld", &t);
while (t--)
{
scanf("%lld%lld%lld", &n, &k, &x);
sgt.CLEAR(n);
if (x < 0)
{
x = -x;
k = n - k;
}
int ans = -1e18;
for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
sgt.build(1, 1, n, a);
int i = k - 1;
if (i <= 0) goto E;
for (int j = 1; j <= i; j++)
{
sgt.modify(1, j, a[j] + x);
}
for (int j = 1; j + i - 1 <= n; j++)
{
ans = max(ans, sgt.query(1, j, j + i - 1).tmax);
sgt.modify(1, j, a[j]);
if (j + i <= n) sgt.modify(1, j + i, a[j + i] + x);
}
E:
long long f1 = 0LL, f2 = 0LL;
long long sum = 0LL, psum = 0LL;
for (int i = k + 1; i <= n; i++)
{
sum += (a[i] - x);
f2 = max(f2, sum);
}
for (int i = 1; i <= k; i++)
{
psum += a[i];
}
for (int i = 1; i + k - 1 <= n; i++)
{
ans = max(ans, psum + k * x + f1 + f2);
//printf("!: %lld %lld %lld\n", i, f1, f2);
f1 += (a[i] - x);
f1 = max(f1, 0LL);
if (f2 > 0) f2 -= (a[i + k] - x);
f2 = max(f2, 0LL);
psum -= a[i];
psum += a[i + k];
}
ans = max(ans, 0LL);
printf("%lld\n", ans);
}
return 0;
}

浙公网安备 33010602011771号