luoguP4647 [IOI2007] sails 船帆

https://www.luogu.org/problemnew/show/P4647

首先发现答案与顺序无关,令 $ x_i $ 表示高度为 $ i $ 的那一行帆的个数,第 $ i $ 行对答案的贡献为 $ \frac{x_i * (x_i - 1)}{2} $

先把旗杆按照高度从小到大排序,有一个显然的贪心是每次选择能放的地方帆最少的一行放一个帆,最少的一行放一个帆对答案的贡献一定最小,而后面的旗杆高度更高,能选择放帆的地方更多,这样可以保证答案最小(可以感性理解一下

那么我们的做法就出来了,对于旗杆 $ i $ 选择当前能放帆的最少的 $ k_i $ 行,放上帆,这个操作用平衡树实现,将原序列前 $ k $ 小拿出,加上 $ 1 $ 后放回去,这个时候可能会使平衡树的性质被打破,如 $ 0 $ $ 0 $ 将第一个数加 $ 1 $ 再放回去就变成了 $ 1 $ $ 0 $,所以对于分裂出来的前 $ k $ 小中的最大值还要再次分裂,如 $ 0 $ $ 1 $ $ 1 $ $ 1 $ $ 2 $ 前 $ 3 $ 小要加 $ 1 $,先分裂成 $ 0 $ $ 1 $ $ 1 $ 和 $ 1 $ $ 2 $,然后分裂成 $ 0 $ ,$ 1 $ $ 1 $ 和 $ 1 $ $ 2 $,区间加后放回原序列,先变成 $ 1 $ {这里放原来的 $ 1 $ $ 1 $ } $ 2 $,然后变成 {这里放原来的 $ 0 $} $ 1 $ {这里放原来的 $ 1 $ $ 1 $ } $ 2 $,即 $ 0 $ $ 1 $ $ 1 $ $ 1 $ $ 2 $,可以使用 splay 实现

#include <bits/stdc++.h>
#define Fast_cin ios::sync_with_stdio(false), cin.tie(0);
#define rep(i, a, b) for(register int i = a; i <= b; i++)
#define per(i, a, b) for(register int i = a; i >= b; i--)
#define DEBUG(x) cerr << "DEBUG" << x << " >>> " << endl;
using namespace std;

typedef unsigned long long ull;
typedef pair <int, int> pii;
typedef long long ll;

template <typename _T>
inline void read(_T &f) {
    f = 0; _T fu = 1; char c = getchar();
    while(c < '0' || c > '9') { if(c == '-') fu = -1; c = getchar(); }
    while(c >= '0' && c <= '9') { f = (f << 3) + (f << 1) + (c & 15); c = getchar(); }
    f *= fu;
}

template <typename T>
void print(T x) {
    if(x < 0) putchar('-'), x = -x;
    if(x < 10) putchar(x + 48);
    else print(x / 10), putchar(x % 10 + 48);
}

template <typename T>
void print(T x, char t) {
    print(x); putchar(t);
}

template <typename T>
struct hash_map_t {
    vector <T> v, val, nxt;
    vector <int> head;
    int mod, tot, lastv;
    T lastans;
    bool have_ans;

    hash_map_t (int md = 0) {
        head.clear(); v.clear(); val.clear(); nxt.clear(); tot = 0; mod = md;
        nxt.resize(1); v.resize(1); val.resize(1); head.resize(mod);
        have_ans = 0;
    }

    bool count(int x) {
        int u = x % mod;
        for(register int i = head[u]; i; i = nxt[i]) {
            if(v[i] == x) {
                have_ans = 1;
                lastv = x;
                lastans = val[i];
                return 1;
            }
        }
        return 0;
    }

    void ins(int x, int y) {
        int u = x % mod;
        nxt.push_back(head[u]); head[u] = ++tot;
        v.push_back(x); val.push_back(y);
    }

    int qry(int x) {
        if(have_ans && lastv == x) return lastans;
        count(x);
        return lastans;
    }
};

const int N = 1e5 + 5;

struct ele {
    int h, gs;
    bool operator < (const ele A) const { return h < A.h; }
} a[N];

int ch[N][2], val[N], tag[N], siz[N], n, root;

inline void update(int u) { siz[u] = siz[ch[u][0]] + siz[ch[u][1]] + 1; }

inline void add_tag(int u, int x) { if(u <= 100000) val[u] += x; tag[u] += x; }

inline void pushdown(int u) {
    if(tag[u]) {
        if(ch[u][0]) add_tag(ch[u][0], tag[u]);
        if(ch[u][1]) add_tag(ch[u][1], tag[u]);
        tag[u] = 0;
    }
}

inline void rotate(int &u, int d) {
    int tmp = ch[u][d];
    ch[u][d] = ch[tmp][d ^ 1];
    ch[tmp][d ^ 1] = u;
    update(u); update(tmp);
    u = tmp;
}

void splay(int &u, int k) {
    pushdown(u);
    int ltree = siz[ch[u][0]];
    if(ltree + 1 == k) return;
    int d = k > ltree;
    pushdown(ch[u][d]);
    int k2 = d ? k - ltree - 1 : k;
    int ltree2 = siz[ch[ch[u][d]][0]];
    if(ltree2 + 1 != k2) {
        int d2 = k2 > ltree2;
        splay(ch[ch[u][d]][d2], d2 ? k2 - ltree2 - 1 : k2);
        if(d == d2) rotate(u, d); else rotate(ch[u][d], d2);
    }
    rotate(u, d);
}

int find(int u, int x) {
    pushdown(u);
    if(!u) return 0;
    if(x > val[u]) return siz[ch[u][0]] + 1 + find(ch[u][1], x);
    return find(ch[u][0], x);
}

void insert(int &u, int x, int y) {
    splay(u, x + 1); splay(ch[u][0], x);
    ch[ch[u][0]][1] = y; update(ch[u][0]); update(u);
}

ll ans;
void dfs(int u) {
    if(!u) return;
    pushdown(u);
    dfs(ch[u][0]);
    if(u <= 100000) ans += 1ll * val[u] * (val[u] - 1) / 2;
    dfs(ch[u][1]);
}

int main() {
    read(n);
    for(register int i = 1; i <= n; i++) read(a[i].h), read(a[i].gs);
    sort(a + 1, a + n + 1); int now = 1; root = 1;
    ch[root][0] = 100001; ch[root][1] = 100002;
    val[100001] = -1; val[100002] = 100002;
    update(100001); update(100002); update(root);
    for(register int i = 1; i <= n; i++) {
        while(now < a[i].h) {
            ++now; update(now);
            insert(root, 1, now);
        }
        splay(root, a[i].gs + 2); splay(ch[root][0], a[i].gs + 1);
        int left = find(ch[root][0], val[ch[root][0]]), v = val[ch[root][0]], l = ch[root][0];
        ch[root][0] = 0; update(root);
        int right = find(root, v + 1);
        if(!right) {
            add_tag(l, 1);
            ch[root][0] = l;
            update(root);
        } else {
            // fprintf(stderr, "left = %d\n", left);
            splay(l, left + 1);
            int ll = ch[l][0]; ch[l][0] = 0; update(l);
            add_tag(ll, 1); add_tag(l, 1);
            insert(root, right, l);
            splay(root, 1); ch[root][0] = ll; update(root);
        }
    }
    dfs(root);
    cout << ans << endl;
    return 0;
}
posted @ 2019-03-02 20:29  LJC00118  阅读(257)  评论(0编辑  收藏  举报
/*
*/