洛谷P2042 [NOI2005] 维护数列 题解 splay tree

题目大意:支持如下几个操作:

  1. 插入:INSERT p tot c1 c2 ... ctot:在第 \(p\) 个数后面插入连续的 \(tot\) 个数字(\(p = 0\) 时插入到序列最前面)
  2. 删除:DELETE p tot:删除第 \(p\) 个数开始的连续 \(tot\) 个数字
  3. 修改:MAKE-SAME p tot c:将第 \(p\) 个数开始的连续 \(tot\) 个数字全部更新为 \(c\)
  4. 翻转:REVERSE p tot:将第 \(p\) 个数开始的连续 \(tot\) 个数的区间(即:\([p, p+tot)\))进行翻转
  5. 求和:GET-SUM p tot:求第 \(p\) 个数开始的连续 \(tot\) 个数之和
  6. 求最大连续子序列和:MAX-SUM:求(整个)序列的最大连续子序列和(子序列非空,即至少包含一个数)

示例程序:

#include <bits/stdc++.h>
using namespace std;
// maxn: capacity of the node pool `tr` and of the value buffer `w`. Index 0 of
//       `tr` is reserved as the null node, so indices 1..maxn-1 are allocatable.
// INF:  magnitude used for the two sentinel values (-INF) and for the null
//       node's `ms`, intended to keep sentinels out of any maximum-subarray answer.
// NOTE(review): a subtree that contains both sentinels (i.e. the whole tree)
// sums to -2*INF plus the real elements; with INF = 1e9 that can leave the int
// range when the real sum is very negative. Those whole-tree sum/ls/rs values
// are not what main() prints, but the signed overflow is still UB — confirm
// against the problem's limits before relying on it.
const int maxn = 5e5 + 5, INF = 1e9;

/// One splay-tree node; an in-order traversal of the tree yields the sequence.
struct Node {
    int s[2], p, v;            // children (s[0] left, s[1] right), parent, own value
    int rev, same;             // lazy tags: children's subtrees still need reversing /
                               // children still need to be assigned this node's v
    int sz, sum, ms, ls, rs;   // subtree size and sum; best non-empty segment (ms);
                               // best prefix / suffix, which may be empty (so >= 0)
    Node() {}
    /// Builds a leaf holding value `_v` whose parent is `_p`.
    Node(int _v, int _p)
        : s{0, 0}, p(_p), v(_v),
          rev(0), same(0),
          sz(1), sum(_v), ms(_v),
          ls(_v > 0 ? _v : 0), rs(_v > 0 ? _v : 0) {}
} tr[maxn];
// Index of the current root node of the splay tree (updated by splay(x, 0)).
int root;
// Scratch buffer of values read from input; build() creates nodes from w[l..r].
int w[maxn];
// Free list of node indices: filled by init(), consumed by build(), refilled by del().
queue<int> nodes;

void push_up(int x) {
    auto &u = tr[x], &l = tr[u.s[0]], &r = tr[u.s[1]];
    u.sz = l.sz + r.sz + 1;
    u.sum = l.sum + r.sum + u.v;
    u.ls = max(l.ls, l.sum + u.v + r.ls);
    u.rs = max(r.rs, l.rs + u.v + r.sum);
    u.ms = max(max(l.ms, r.ms), l.rs + u.v + r.ls);
}

void t_same(int x, int c) {
    if (x) {
        tr[x].same = 1;
        tr[x].rev = 0;
        tr[x].v = c;
        tr[x].sum = tr[x].sz * c;
        if (c > 0)
            tr[x].ms = tr[x].ls = tr[x].rs = tr[x].sum;
        else
            tr[x].ms = c, tr[x].ls = tr[x].rs = 0;
    }
}

void t_rev(int x) {
    if (x) {
        tr[x].rev ^= 1;
        swap(tr[x].s[0], tr[x].s[1]);
        swap(tr[x].ls, tr[x].rs);
    }
}

void push_down(int x) {
    if (tr[x].same) {
        tr[x].same = tr[x].rev = 0;
        t_same(tr[x].s[0], tr[x].v);
        t_same(tr[x].s[1], tr[x].v);
    }
    else if (tr[x].rev) {
        tr[x].rev = 0;
        t_rev(tr[x].s[0]);
        t_rev(tr[x].s[1]);
    }
}

// Links u as child k of p (k == 0: left, k == 1: right) and records p as u's
// parent. rot() also calls this with p == 0 or u == 0, in which case the write
// lands in the null node tr[0].
void f_s(int p, int u, bool k) {
    tr[u].p = p;
    tr[p].s[k] = u;
}

// Rotates x above its parent y while preserving in-order order.
// Let k be the side of y that x hangs on. Then:
//  - x takes y's former slot under z (z is 0, the null node, when y was the root);
//  - x's inner child (side k^1) becomes y's child on side k;
//  - y becomes x's child on side k^1.
// Aggregates are refreshed bottom-up: y first, since it is now x's child.
// No tags are pushed here; callers in this file call get_k (which pushes tags
// down along its search path) before splaying.
void rot(int x) {
    int y = tr[x].p, z = tr[y].p;
    bool k = tr[y].s[1] == x;   // true if x is y's right child
    f_s(z, x, tr[z].s[1]==y);   // x replaces y under z
    f_s(y, tr[x].s[k^1], k);    // x's inner subtree moves across to y
    f_s(x, y, k^1);             // y hangs below x on the opposite side
    push_up(y), push_up(x);
}

// Rotates x upward until its parent is k. With k == 0, x becomes the tree root
// and the global `root` is updated. Uses zig-zig / zig-zag double rotations.
void splay(int x, int k) {
    for (int y = tr[x].p; y != k; y = tr[x].p) {
        int z = tr[y].p;
        if (z != k) {
            bool xIsRight = tr[y].s[1] == x;
            bool yIsRight = tr[z].s[1] == y;
            if (xIsRight != yIsRight)
                rot(x);   // zig-zag: rotate x itself first
            else
                rot(y);   // zig-zig: rotate the parent first
        }
        rot(x);
    }
    if (k == 0) root = x;
}

// Returns the node with in-order rank k (1-based; the sentinel nodes count as
// ordinary nodes). Lazy tags are pushed down along the search path.
// Returns -1 when k is outside [1, tree size].
int get_k(int k) {
    for (int u = root; u != 0; ) {
        push_down(u);
        int leftSize = tr[tr[u].s[0]].sz;
        if (k <= leftSize) {
            u = tr[u].s[0];
        } else if (k == leftSize + 1) {
            return u;
        } else {
            k -= leftSize + 1;
            u = tr[u].s[1];
        }
    }
    return -1;
}

// Builds a height-balanced subtree holding w[l..r] (inclusive) in order and
// returns its root. Node indices are taken from the free list `nodes`, and the
// new root's parent is set to p.
// An empty range (l > r) returns 0, the null node, and consumes no index.
// Without this guard, build(0, -1, p) (an INSERT of zero numbers) computes
// mid = (0 + -1) / 2 == 0, because integer division truncates toward zero. It
// would then create a node from the stale value w[0].
int build(int l, int r, int p) {
    if (l > r) return 0;
    int mid = (l + r) / 2;
    int u = nodes.front();
    nodes.pop();
    tr[u] = Node(w[mid], p);
    // Empty halves come back as 0, matching the null children Node() set.
    tr[u].s[0] = build(l, mid - 1, u);
    tr[u].s[1] = build(mid + 1, r, u);
    push_up(u);
    return u;
}

// Returns every node of the subtree rooted at u to the free list `nodes`, in
// pre-order. u == 0 (an empty subtree) is a no-op.
// The traversal is iterative with an explicit stack. A subtree cut out of a
// splay tree can be a long chain, and one recursive call per node may overflow
// the call stack. Node fields are left untouched; build() overwrites a node
// with Node(...) when it is reused.
void del(int u) {
    if (!u) return;
    vector<int> pending(1, u);
    while (!pending.empty()) {
        int x = pending.back();
        pending.pop_back();
        nodes.push(x);
        // Right child first, so the left subtree is recycled first (pre-order).
        if (tr[x].s[1]) pending.push_back(tr[x].s[1]);
        if (tr[x].s[0]) pending.push_back(tr[x].s[0]);
    }
}

void init() {
    // 初始化 tr[0]
    tr[0] = Node(0, 0);
    tr[0].sz = 0;
    tr[0].ms = -INF;
    // 把所有点放进回收站
    for (int i = 1; i < maxn; i++) nodes.push(i);
}

int n, m;
char op[22];

int main() {
    init();
    scanf("%d%d", &n, &m);
    w[0] = w[n+1] = -INF;   // 设立两个哨兵节点
    for (int i = 1; i <= n; i++) scanf("%d", w+i);
    root = build(0, n+1, 0);
    while (m--) {
        scanf("%s", op);
        if (!strcmp(op, "INSERT")) {
            int p, tot;
            scanf("%d%d", &p, &tot);
            for (int i = 0; i < tot; i++) scanf("%d", w+i);
            int l = get_k(p+1), r = get_k(p+2);
            splay(l, 0), splay(r, l);
            int u = build(0, tot-1, r);
            tr[r].s[0] = u;
            push_up(r), push_up(l);
        }
        else if (!strcmp(op, "DELETE")) {
            int p, tot;
            scanf("%d%d", &p, &tot);
            int l = get_k(p), r = get_k(p+tot+1);
            splay(l, 0), splay(r, l);
            del(tr[r].s[0]);
            tr[r].s[0] = 0;
            push_up(r), push_up(l);
        }
        else if (!strcmp(op, "MAKE-SAME")) {
            int p, tot, c;
            scanf("%d%d%d", &p, &tot, &c);
            int l = get_k(p), r = get_k(p+tot+1);
            splay(l, 0), splay(r, l);
            t_same(tr[r].s[0], c);
            push_up(r), push_up(l);
        }
        else if (!strcmp(op, "REVERSE")) {
            int p, tot;
            scanf("%d%d", &p, &tot);
            int l = get_k(p), r = get_k(p+tot+1);
            splay(l, 0), splay(r, l);
            t_rev(tr[r].s[0]);
            push_up(r), push_up(l);
        }
        else if (!strcmp(op, "GET-SUM")) {
            int p, tot;
            scanf("%d%d", &p, &tot);
            int l = get_k(p), r = get_k(p+tot+1);
            splay(l, 0), splay(r, l);
            printf("%d\n", tr[tr[r].s[0]].sum);
        }
        else {  // MAX-SUM
            printf("%d\n", tr[root].ms);
        }
    }
    return 0;
}
posted @ 2022-12-18 21:29  quanjun  阅读(31)  评论(0)    收藏  举报