P2042 [NOI2005]维护数列

$ \color{#0066ff}{ 题目描述 }$

请写一个程序,要求维护一个数列,支持以下 6 种操作:(请注意,格式栏 中的下划线‘ _ ’表示实际输入文件中的空格)

img

\(\color{#0066ff}{输入格式}\)

输入文件的第 1 行包含两个数 N 和 M,N 表示初始时数列中数的个数,M 表示要进行的操作数目。 第 2 行包含 N 个数字,描述初始时的数列。 以下 M 行,每行一条命令,格式参见问题描述中的表格

\(\color{#0066ff}{输出格式}\)

对于输入数据中的 GET-SUM 和 MAX-SUM 操作,向输出文件依次打印结 果,每个答案(数字)占一行。

\(\color{#0066ff}{输入样例}\)

9 8 
2 -6 3 5 1 -5 -3 6 3 
GET-SUM 5 4
MAX-SUM
INSERT 8 3 -5 7 2
DELETE 12 1
MAKE-SAME 3 3 2
REVERSE 3 6
GET-SUM 5 4
MAX-SUM

\(\color{#0066ff}{输出样例}\)

-1
10
1
10

\(\color{#0066ff}{数据范围与提示}\)

你可以认为在任何时刻,数列中至少有 1 个数。

输入数据一定是正确的,即指定位置的数在数列中一定存在。

50%的数据中,任何时刻数列中最多含有 30 000 个数;

100%的数据中,任何时刻数列中最多含有 500 000 个数。

100%的数据中,任何时刻数列中任何一个数字均在[-1 000, 1 000]内。

100%的数据中,M ≤20 000,插入的数字总数不超过 4 000 000 。

\(\color{#0066ff}{题解}\)

这是一道Splay维护序列的经典大毒瘤

每次把一段区间琛到固定位置,然后进行操作

维护一个翻转标记,变成标记,只在kth的时候下放即可

注意最大子段和的维护,老套路,记录l,r,还有总共的答案

细节还是蛮多的

#include<bits/stdc++.h>
#define LL long long
LL in() {
    char ch; LL x = 0, f = 1;
    while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
    for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
    return x * f;
}
const int maxn = 1e6 + 10;
const int inf = 0x3f3f3f3f;
struct Splay {
protected:
    struct olinr {
        int lsum, rsum, ans, sum;
        olinr(int lsum = 0, int rsum = 0, int ans = -inf, int sum = 0): lsum(lsum), rsum(rsum), ans(ans), sum(sum) {}
        olinr &operator = (int x) {
            if(x > 0) lsum = rsum =x;
            else lsum = rsum = x;
            ans = sum = x; 
			return *this;
        }
        friend olinr operator + (const olinr &a, const olinr &b){
            return olinr(std::max(a.lsum, a.sum + b.lsum),
                         std::max(b.rsum,a.rsum + b.sum),
                         std::max(std::max(a.ans, b.ans), a.rsum + b.lsum),
                         a.sum + b.sum);
        }
        void r() { std::swap(lsum, rsum); }
    };
    struct node {
        node *ch[2], *fa;
        int val, bc, rev, siz;
        olinr fuck;
        node(node *fa = NULL, int val = 0, int bc = inf, int rev = 0, int siz = 1):
            fa(fa), val(val), bc(bc), rev(rev), siz(siz) {
                fuck = -inf;
                ch[0] = ch[1] = NULL;
        }
        bool isr() { return this == fa->ch[1]; }
        void trn_rv() { std::swap(ch[0], ch[1]), fuck.r(), rev ^= 1; }
        void trn_bc(int v) {
            val = bc = v;
			if(v >= 0) fuck = v * siz;
			else fuck.sum = v * siz, fuck.lsum = fuck.rsum = fuck.ans = v;

        }
        void dwn() {
            if(bc != inf) {
                rev = 0;
                if(ch[0]) ch[0]->trn_bc(bc);
                if(ch[1]) ch[1]->trn_bc(bc);
                bc = inf;
            }
            if(rev) {
                if(ch[0]) ch[0]->trn_rv();
                if(ch[1]) ch[1]->trn_rv();
                rev = 0;
            }
        }
        void upd() {
            siz = 1;
            fuck=val;
            if(ch[0]) siz+=ch[0]->siz,fuck=ch[0]->fuck+fuck;
            if(ch[1]) siz+=ch[1]->siz,fuck=fuck+ch[1]->fuck;
        }
        int rk() { return ch[0]? ch[0]->siz + 1 : 1; }
    }pool[maxn], *root, *tail;
    std::queue<node *> st;
    int cnt[maxn];
    void rot(node *x) {
        node *y = x->fa, *z = y->fa;
        bool k = x->isr(); node *w = x->ch[!k];
        if(y != root) z->ch[y->isr()] = x;
        else root = x;
        (x->ch[!k] = y)->ch[k] = w;
        (y->fa = x)->fa = z;
        if(w) w->fa = y;
        y->upd(), x->upd();
    }
    void splay(node *o, node *p) {
        while(o->fa != p) {
            if(o->fa->fa != p) rot(o->isr() ^ o->fa->isr()? o : o->fa);
            rot(o);
        }
    }
    void dfs(node *o) {
        if(o->ch[0]) dfs(o->ch[0]);
        if(o->ch[1]) dfs(o->ch[1]);
        st.push(o);
    }
    node *kth(int k) {
        node *o = root;
        while(o->dwn(), o->rk() != k) {
            if(o->rk() < k) k -= o->rk(), o = o->ch[1];
            else o = o->ch[0];
        }
        return o;
    }
    void split(int l, int r) { splay(kth(l), NULL), splay(kth(r + 2), root); }
    void travel(node *o, std::vector<int> &v) {
        if(!o)return;
        o->dwn();
        travel(o->ch[0], v);
        v.push_back(o->val);
        travel(o->ch[1], v);
    }
    node *build(node *fa, int l, int r, int *a) {
        if(l > r) return NULL;
        int mid = (l + r) >> 1;
        node *o;
        if(st.empty()) o = new(tail++) node(fa, a[mid]);
        else o = st.front(), st.pop(), *o = node(fa, a[mid]);
        o->ch[0] = build(o, l, mid - 1, a);
        o->ch[1] = build(o, mid + 1, r, a);
        return o->upd(), o;
    }
public:
    Splay() { tail = pool, root = NULL; }
    int querymax() { return root->fuck.ans; }	
    void build(int l, int r, int *a) { root = build(NULL, l, r, a); }
    void ins(int l, int r, int num) {
        for(int i = 1; i <= num; i++) cnt[i] = in();
        split(l, r); 
        node *o = build(NULL, 1, num, cnt);
        (o->fa = root->ch[1])->ch[0] = o;
        root->ch[1]->upd(), root->upd();
    }
    void del(int l, int r) {
        if(l > r) return;
        split(l, r);
        dfs(root->ch[1]->ch[0]);
        root->ch[1]->ch[0] = NULL;
        root->ch[1]->upd(), root->upd();
    }
    void change(int l, int r, int val) {
        if(l > r) return;
        split(l, r);
        root->ch[1]->ch[0]->trn_bc(val);
        root->ch[1]->upd(), root->upd();
    }
    void reverse(int l, int r) {
        if(l > r) return;
        split(l, r);
        root->ch[1]->ch[0]->trn_rv();
		root->ch[1]->upd(), root->upd();
    }
    int getsum(int l, int r) {
        if(l > r) return 0;
        split(l, r);
        return root->ch[1]->ch[0]->fuck.sum;
    }
    void outlist() {
        std::vector<int> v;
        travel(root, v);
        for(auto &i:v) printf("%d ", i);
        puts("");
    }
}s;
int n, m;
int a[maxn];
char v[100];
int main() {
    n = in(), m = in();
    for(int i = 1; i <= n; i++) a[i] = in();
    a[0] = a[n + 1] = -inf;
    s.build(0, n + 1, a);
    int x, y, z;
    while(m --> 0) {
        scanf("%s", v);
        if(v[0] == 'I') x = in(), y = in(), s.ins(x + 1, x, y);
        if(v[0] == 'D') x = in(), y = in(), s.del(x, x + y - 1);
        if(v[0] == 'M') {
            if(v[2] == 'K') x = in(), y = in(), z = in(), s.change(x, x + y - 1, z);
            else printf("%d\n", s.querymax());
        }
        if(v[0] == 'R') x = in(), y = in(), s.reverse(x, x + y - 1);
        if(v[0] == 'G') x = in(), y = in(), printf("%d\n", s.getsum(x, x + y - 1));
    }
    return 0;
}
/* 2 -6 3 5 1 -5 -3 6 -5 7 2*/

posted @ 2019-02-27 21:32  olinr  阅读(161)  评论(0编辑  收藏  举报