splay

先贴个常数巨大的 splay 板子。使用前需要在外部先定义 `maxn`（节点池大小，0 号节点用作哨兵，删除的节点不回收）和 `inf`（查询失败时返回的哨兵值，例如没有前驱时返回 `-inf`）。

namespace splay {
struct node {
    int key, fa, size, cnt, ch[2];
};
node t[maxn];
int tot;

#define ls ch[0]
#define rs ch[1]
#define rt t[0].ch[1]

void push_up(int p) {
    t[p].size = t[t[p].ls].size + t[t[p].rs].size + t[p].cnt;
}
/**
 * @return 0/1 if p is lson/rson
 */
bool get(int p) { return t[t[p].fa].rs == p; }

int newnode(int v, int f) {
    tot++;
    t[tot].fa = f, t[tot].key = v;
    t[tot].cnt = t[tot].size = 1;
    t[tot].ls = t[tot].rs = 0;
    return tot;
}

void connect(int p, int f, int s) {
    t[p].fa = f, t[f].ch[s] = p;
}

/*******************************
 *     G                 G     *
 *     |                 |     *
 *     F     --zig->     P     *
 *    / \               / \    *
 *   P   Q   <-zag--   L   F   *
 *  / \                   / \  *
 * L   R                 R   Q *
 *******************************/

/**
 * @brief pos == 0 ? zig(->) : zag(<-)
 */
void rotate(int p) {
    int f = t[p].fa, g = t[f].fa;
    int pos = get(p), posf = get(f);
    int s = t[p].ch[pos ^ 1]; // node R
    connect(s, f, pos);
    connect(f, p, pos ^ 1);
    connect(p, g, posf);
    push_up(f), push_up(p);
}

void splay(int p, int tar) {
    tar = t[tar].fa;
    while (t[p].fa != tar) {
        int f = t[p].fa;
        if (t[f].fa == tar)
            rotate(p);
        else if (get(p) == get(f))
            rotate(f), rotate(p);
        else
            rotate(p), rotate(p);
    }
}

void insert(int v) {
    int cur = rt;
    if (!cur) {
        rt = newnode(v, 0);
        return;
    }
    while (1) {
        t[cur].size++;
        if (t[cur].key == v) {
            t[cur].cnt++;
            splay(cur, rt);
            return;
        }
        int nxt = v < t[cur].key ? 0 : 1;
        if (!t[cur].ch[nxt]) {
            int p = newnode(v, cur);
            t[cur].ch[nxt] = p;
            splay(p, rt);
            return;
        }
        cur = t[cur].ch[nxt];
    }
}

int find(int v) {
    int cur = rt;
    while (1) {
        if (t[cur].key == v) {
            splay(cur, rt);
            return cur;
        }
        int nxt = v < t[cur].key ? 0 : 1;
        if (!t[cur].ch[nxt])
            return 0;
        cur = t[cur].ch[nxt];
    }
}

void erase(int v) {
    int cur = find(v);
    if (!cur) return;
    if (t[cur].cnt > 1) {
        t[cur].cnt--, t[cur].size--;
        return;
    }
    if (!t[cur].ls && !t[cur].rs)
        rt = 0;
    else if (!t[cur].ls) {
        rt = t[cur].rs;
        t[rt].fa = 0;
    } else {
        int lson = t[cur].ls;
        while (t[lson].rs) lson = t[lson].rs;
        splay(lson, t[cur].ls);
        connect(t[cur].rs, lson, 1);
        connect(lson, 0, 1);
        push_up(lson);
    }
}

int pre(int v) {
    int cur = rt, ans = -inf;
    while (cur) {
        if (t[cur].key < v && t[cur].key > ans)
            ans = t[cur].key;
        int nxt = v > t[cur].key;
        cur = t[cur].ch[nxt];
    }
    return ans;
}

int suf(int v) {
    int cur = rt, ans = inf;
    while (cur) {
        if (t[cur].key > v && t[cur].key < ans)
            ans = t[cur].key;
        int nxt = v >= t[cur].key;
        cur = t[cur].ch[nxt];
    }
    return ans;
}

int rank(int v) {
    int cur = find(v);
    if (cur)
        return t[t[cur].ls].size + 1;
    else {
        int now = suf(v);
        if (now == inf)
            return t[rt].size + 1;
        else
            return rank(now);
    }
}

int kth(int v) {
    if (v > t[rt].size) return inf;
    int cur = rt;
    while (1) {
        int all = t[cur].size - t[t[cur].rs].size;
        if (v > t[t[cur].ls].size && v <= all) {
            splay(cur, rt);
            return t[cur].key;
        }
        if (v < all)
            cur = t[cur].ls;
        else
            v -= all, cur = t[cur].rs;
    }
}

#undef ls
#undef rs
#undef rt
}; // namespace splay
posted @ 2021-10-11 22:06  Theophania  阅读(178)  评论(0)    收藏  举报