题解:ARC180D Division into 3

原题链接

简述题意

给定一个长度为 \(n\) 的整数序列 \(a\)。现有 \(q\) 组询问 \((l_i,r_i)\),你需要将 \(a_{l_i}\sim a_{r_i}\) 分成三个连续子段,使得每个子段的最大值之和最小。

题解

容易观察到最大值 \(a_p\) 必然会造成贡献,并且其所在的区间无论如何向外延伸,最大值都不变。所以我们考虑分类讨论最大值处于那个子段中。

有一个显然的性质,我们后面会用到,那就是一个区间向外延伸时最大值不减

最大值在最靠左的子段中

考虑枚举后两段的分界点,我们设它介于 \(a_i\)\(a_{i+1}\) 之间。当分界点固定时,显然第二个区间的长度为 1。于是这种情况的答案为

\[\min_{i=p+1}^{r-1}\left\{a_i+\max_{j=i+1}^{r}\{a_j\}\right\} \]

外层显然可以离线扫描线维护。对于内层,我们注意到 max 的部分是后缀最值,因此可以想到用单调栈维护。我们令 \(f_i=a_i+\max_{j=i+1}^{r}\{a_j\}\),利用离线扫描线配合单调栈,维护 \(r\) 固定时所有 \(f_i(1\leq i<r)\) 的值。

\[a_l\cdots a_{s_{top-1}}\underbrace{a_{s_{top-1}+1}\cdots a_{s_{top}}}_{\text{后缀最值取}a_{s_{top}}}\cdots a_{i-1}a_i \]

具体来说,在单调栈 \(s\) 的弹栈过程中,有 \(a_{s_{top}}<a_i\),此时如上图所示,\(a_{s_{top-1}+1}\sim a_{s_{top}}\) 的位置在计算后缀最值时取 \(a_{s_{top}}\),所以这部分所对应的 \(f\) 值减去 \(a_{s_{top}}\),再加上 \(a_i\),就被正确地更新了。形式化地,我们对 \(f\) 开一个支持区间加、区间查询最小值的线段树,每次弹栈时(注意 \(f\) 下标的更新范围,与前面的范围相比有所偏移)

\[\forall i\in[s_{top-1},s_{top}-1],\, f_i\leftarrow f_i-a_{s_{top}}+a_i \]

我们还要单独更新 \(f_{i-1}\),即令 \(f_{i-1}\leftarrow a_{i-1}+a_i\)。更新询问的答案时,直接在线段树上做区间查询即可。
维护单调栈和扫描线的总时间复杂度为 \(O((n+q)\log n)\)

最大值在中间的子段中

显然此时两边的子段长度都为 1,所以这种情况的答案为 \(a_l+a_r+a_p\)

最大值在最靠右的子段中

这种情况和第一种情况完全对称,翻转序列重新做一次情况一的计算,或者对左边界再次进行扫描线即可。


查询询问区间的最大值可以使用 ST 表或者再开一个线段树维护。整体时间复杂度 \(O((n+q)\log n)\)。具体实现时,注意一些边界问题,还有取询问区间的最大值时要对应取最靠左/右边的。

代码

#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>

using namespace std;

#define lowbit(x) ((x) & -(x))
#define add_mod(x, v) (x) = ((ll)(x) + (v)) % MOD
#define mul_mod(x, v) (x) = (1ll * (x) * (v)) % MOD
#define sub_mod(x, v) (x) = (((ll)(x) - (v)) % MOD + MOD) % MOD
typedef long long ll;
typedef pair<int, int> pii;
const int MAX_N = 2.5e5 + 5, MAX_Q = 2.5e5 + 5, MAX_LOGN = 18 + 5;

int n, q, a[MAX_N];
int top, stk[MAX_N];
ll ans[MAX_Q];

struct Query {
    int id, l, r, p1, p2;
} queries[MAX_Q];

vector<Query> ql[MAX_N], qr[MAX_N];

struct Node {
    ll min_val = 1e18, lz;
};

struct SegmentTree {
#define ls(p) (p << 1)
#define rs(p) ((p << 1) | 1)
    Node nodes[MAX_N << 2];

    void push_up(int p) { nodes[p].min_val = min(nodes[ls(p)].min_val, nodes[rs(p)].min_val); }

    void push_down(int p) {
        if (!nodes[p].lz) return;
        nodes[ls(p)].min_val += nodes[p].lz;
        nodes[ls(p)].lz += nodes[p].lz;
        nodes[rs(p)].min_val += nodes[p].lz;
        nodes[rs(p)].lz += nodes[p].lz;
        nodes[p].lz = 0;
    }

    void clear() {
        for (int i = 1; i <= (n << 2); ++i) {
            nodes[i].min_val = 1e18;
            nodes[i].lz = 0;
        }
    }

    void add(int p, int l, int r, int x, int y, ll v) {
        if (x <= l && y >= r) {
            if (nodes[p].min_val == 1e18) nodes[p].min_val = v;
            else nodes[p].min_val += v;
            nodes[p].lz += v;
            return;
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if (x <= mid) add(ls(p), l, mid, x, y, v);
        if (y > mid) add(rs(p), mid + 1, r, x, y, v);
        push_up(p);
    }

    ll query(int p, int l, int r, int x, int y) {
        if (x <= l && y >= r) return nodes[p].min_val;
        push_down(p);
        int mid = (l + r) >> 1;
        ll res = 1e18;
        if (x <= mid) res = query(ls(p), l, mid, x, y);
        if (y > mid) res = min(res, query(rs(p), mid + 1, r, x, y));
        return res;
    }
#undef ls
#undef rs
} sgt;

struct SparseTable {
    ll f[MAX_N][MAX_LOGN][2];

    void init() {
        for (int i = 1; i <= n; ++i) f[i][0][0] = f[i][0][1] = i;
        int l = log2(n);
        for (int i = 1; i <= l; ++i)
            for (int j = 1; j + (1 << i) - 1 <= n; ++j) {
                if (a[f[j][i - 1][0]] > a[f[j + (1 << (i - 1))][i - 1][0]]) {
                    f[j][i][0] = f[j][i - 1][0];
                    f[j][i][1] = f[j][i - 1][1];
                } else if (a[f[j][i - 1][0]] < a[f[j + (1 << (i - 1))][i - 1][0]]) {
                    f[j][i][0] = f[j + (1 << (i - 1))][i - 1][0];
                    f[j][i][1] = f[j + (1 << (i - 1))][i - 1][1];
                } else {
                    f[j][i][0] = min(f[j][i - 1][0], f[j + (1 << (i - 1))][i - 1][0]);
                    f[j][i][1] = max(f[j][i - 1][1], f[j + (1 << (i - 1))][i - 1][1]);
                }
            }
    }

    pii query(int l, int r) {
        int k = log2(r - l + 1);
        int x = f[l][k][0], y = f[r - (1 << k) + 1][k][0];
        if (a[x] > a[y]) return { x, f[l][k][1] };
        if (a[x] < a[y]) return { y, f[r - (1 << k) + 1][k][1] };
        return { min(x, y), max(f[l][k][1], f[r - (1 << k) + 1][k][1]) };
    }
} st;

int main() {
    ios::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n >> q;
    for (int i = 1; i <= n; ++i) cin >> a[i];
    st.init();
    for (int i = 1; i <= q; ++i) {
        cin >> queries[i].l >> queries[i].r;
        queries[i].id = i;
        pii p = st.query(queries[i].l, queries[i].r);
        queries[i].p1 = p.first;
        queries[i].p2 = p.second;
        ans[i] = 1e18;
        ql[queries[i].l].push_back(queries[i]);
        qr[queries[i].r].push_back(queries[i]);
    }
    st.init();
    for (int i = 1; i <= q; ++i) {
        int l = queries[i].l, r = queries[i].r;
        int p = st.query(l + 1, r - 1).first;
        ans[i] = a[l] + a[p] + a[r];
    }
    for (int i = 1; i <= n; ++i) {
        while (top && a[stk[top]] < a[i]) {
            if (top > 1) sgt.add(1, 1, n, stk[top - 1], stk[top] - 1, a[i] - a[stk[top]]);
            else if (stk[top] > 1) sgt.add(1, 1, n, 1, stk[top] - 1, a[i] - a[stk[top]]);
            --top;
        }
        stk[++top] = i;
        if (i > 1) sgt.add(1, 1, n, i - 1, i - 1, a[i - 1] + a[i]);
        for (auto query : qr[i]) {
            int l = query.l, r = query.r, id = query.id, p = query.p1;
            if (p + 1 <= r - 1) ans[id] = min(ans[id], a[p] + sgt.query(1, 1, n, p + 1, r - 1));
        }
    }
    top = 0;
    sgt.clear();
    for (int i = n; i; --i) {
        while (top && a[stk[top]] < a[i]) {
            if (top > 1) sgt.add(1, 1, n, stk[top] + 1, stk[top - 1], a[i] - a[stk[top]]);
            else if (stk[top] < n) sgt.add(1, 1, n, stk[top] + 1, n, a[i] - a[stk[top]]);
            --top;
        }
        stk[++top] = i;
        if (i < n) sgt.add(1, 1, n, i + 1, i + 1, a[i + 1] + a[i]);
        for (auto query : ql[i]) {
            int l = query.l, r = query.r, id = query.id, p = query.p2;
            if (p - 1 >= l + 1) ans[id] = min(ans[id], a[p] + sgt.query(1, 1, n, l + 1, p - 1));
        }
    }
    for (int i = 1; i <= q; ++i) cout << ans[i] << '\n';
    return 0;
}
posted @ 2024-11-22 22:33  P2441M  阅读(29)  评论(0)    收藏  举报