Scape - Goat Tree

Scape - Goat Tree

主打的就是一个定期重构,惰性删除。

通过一个节点的平衡性从而判断是否重构整棵树,我想那一个节点就是 雪风大人 替罪羊吧。

该文的替罪羊树十分暴力,甚至不清楚时间复杂度对不对。

来,一键护航。

Basis

Definition

struct Yukikaze
{
    struct ScapeGoatTree
    {
        int son[2]; // 左右子节点编号。0表示左子节点,1表示右子节点。
        int val; // 该点权值
        int sz; // 子树大小(包含重复数字)
        int cnt; // 相同权值的数的个数
    }tr[N];
    int root, tot;
    int cur[N], curcnt; // 重构数组
}

Push up

void pushup(int p)
{
    tr[p].sz = tr[tr[p].son[0]].sz + tr[tr[p].son[1]].sz + tr[p].cnt;
}

Check

任意一边的子树过大了,就暴力重构。

bool is_bad(int u)
{
    if(1.0 * tr[u].sz * Alpha < 1.0 * tr[tr[u].son[0]].sz)
        return true;
    if(1.0 * tr[u].sz * Alpha < 1.0 * tr[tr[u].son[1]].sz)
        return true;
    return false;
}

Mid-travel

void midtravel(int u) // 中序遍历,存需要重构的点
{
    if(!u) return;
    midtravel(tr[u].son[0]);
    if(tr[u].cnt) cur[++curcnt] = u; 
    // tr[u].cnt == 0 就不用被建到新树上了。这里是真正意义地被删除出平衡树。
    midtravel(tr[u].son[1]);
}

中序遍历还有个很实用的作用是 debug

void print(int u)
{
    if(!u) return;
    print(tr[u].son[0]);
    if(tr[u].cnt) printf("%d %d %d %d\n", u, tr[u].val, tr[u].son[0], tr[u].son[1]);
    print(tr[u].son[1]);
}

build & rebuild

int build(int l, int r)
{
    if(l > r) return 0; // 注意不能写等号,否则叶节点无法被加入到新树中
    int mid = l + (r - l) / 2;
    // 以下注意 cur[mid] 的地方不要写成 mid
    tr[cur[mid]].son[0] = build(l ,mid - 1);
    tr[cur[mid]].son[1] = build(mid + 1, r);
    pushup(cur[mid]);
    return cur[mid];
}
void rebuild(int &u)
{
    curcnt = 0;
    midtravel(u);
    u = build(1, curcnt);
}

Insert

void insert(int &u, int x)
{
    if(!u)
    {
        u = ++tot;
        tr[u].val = x;
        tr[u].son[0] = tr[u].son[1] = 0;
        tr[u].sz = tr[u].cnt = 1;
        return;
    }
    if(x == tr[u].val) tr[u].cnt++;
    else if(x < tr[u].val) 
        insert(tr[u].son[0], x);
    else if(x > tr[u].val)
        insert(tr[u].son[1], x);
    pushup(u);
    if(is_bad(u))
        rebuild(u);
    return;
}

Remove

void remove(int &u, int x)
{
    if(!u) return;
    if(x == tr[u].val)
    {
        if(tr[u].cnt)
            tr[u].cnt--;
    }
    else if(x < tr[u].val)
        remove(tr[u].son[0], x);
    else if(x > tr[u].val)
        remove(tr[u].son[1], x);
    pushup(u);
    if(is_bad(u))
        rebuild(u);
    return;
}

Get rank by value

int get_rank_by_val(int u, int x)
{
    if(!u) return 1; // 注意返回 "+1"
    if(x == tr[u].val && tr[u].cnt) // 注意判断该点是否已经被惰性删除了
        return tr[tr[u].son[0]].sz + 1;
    if(x < tr[u].val)
        return get_rank_by_val(tr[u].son[0], x);
    return tr[tr[u].son[0]].sz + tr[u].cnt + get_rank_by_val(tr[u].son[1], x);
}

Get value by rank

int get_val_by_rank(int u, int x)
{
    if(!u) return 0;
    if(x > tr[tr[u].son[0]].sz && x <= tr[tr[u].son[0]].sz + tr[u].cnt) // && tr[u].cnt
        return tr[u].val; // 如果 tr[u].cnt == 0,则这个 if 一定不成立,因此可以不用判断
    if(x <= tr[tr[u].son[0]].sz)
        return get_val_by_rank(tr[u].son[0], x);
    return get_val_by_rank(tr[u].son[1], x - tr[tr[u].son[0]].sz - tr[u].cnt);
}

Get last & Get next

int get_lstval(int x)
{
    int rank = get_rank_by_val(root, x) - 1; 
    // 这里并不是排名的真正定义,只是为了找前驱的值而非前驱的排名
    return get_val_by_rank(root, rank);
}
int get_nxtval(int x)
{
    int rank = get_rank_by_val(root, x + 1);
    // 这个是排名的真正定义,且通过排名找到了后继的值
    return get_val_by_rank(root, rank);
}

注意有这样一种写法是 错误 的:

    int get_lstval(int x)
    {
        int u = root, res = -INF;
        while(u)
        {
            if(tr[u].val < x && tr[u].cnt)
                res = tr[u].val;
            u = tr[u].son[tr[u].val < x];
        }
        return res;
    }
    int get_nxtval(int x)
    {
        int u = root, res = INF;
        while(u)
        {
            if(tr[u].val > x && tr[u].cnt)
                res = tr[u].val;
            u = tr[u].son[tr[u].val <= x];
        }
        return res;
    }

原因是从被惰性删除但还没有完全删除的节点往下跳,可能出现问题。

比如:

现在 get_lstval(x)1, 2 号节点被删除。此时如果 5 号节点的值大于 6 号节点的值,则查询结果就出错了。

上述算法会从 1 往左跳,查询到 6 号节点并返回 6 号节点的值,但答案应为 5 号节点的值。

查询后继同样会遇到类似的问题。

Final Code

#include<bits/stdc++.h>
using namespace std;
const double Alpha = 0.75;
const int N = 2e6 + 5;
const int INF = INT_MAX;
struct Yukikaze
{
    struct ScapeGoatTree
    {
        int son[2];
        int val;
        int sz, cnt;
    }tr[N];
    int root, tot;
    int cur[N], curcnt;
    void pushup(int p)
    {
        tr[p].sz = tr[tr[p].son[0]].sz + tr[tr[p].son[1]].sz + tr[p].cnt;
    }
    bool is_bad(int u)
    {
        if(1.0 * tr[u].sz * Alpha < 1.0 * tr[tr[u].son[0]].sz)
            return true;
        if(1.0 * tr[u].sz * Alpha < 1.0 * tr[tr[u].son[1]].sz)
            return true;
        return false;
    }
    void midtravel(int u)
    {
        if(!u) return;
        midtravel(tr[u].son[0]);
        if(tr[u].cnt) cur[++curcnt] = u;
        midtravel(tr[u].son[1]);
    }
    void print(int u)
    {
        if(!u) return;
        print(tr[u].son[0]);
        if(tr[u].cnt) printf("%d %d %d %d\n", u, tr[u].val, tr[u].son[0], tr[u].son[1]);
        print(tr[u].son[1]);
    }
    int build(int l, int r)
    {
        if(l > r) return 0;
        int mid = l + (r - l) / 2;
        tr[cur[mid]].son[0] = build(l ,mid - 1);
        tr[cur[mid]].son[1] = build(mid + 1, r);
        pushup(cur[mid]);
        return cur[mid];
    }
    void rebuild(int &u)
    {
        curcnt = 0;
        midtravel(u);
        u = build(1, curcnt);
    }
    void insert(int &u, int x)
    {
        if(!u)
        {
            u = ++tot;
            tr[u].val = x;
            tr[u].son[0] = tr[u].son[1] = 0;
            tr[u].sz = tr[u].cnt = 1;
            return;
        }
        if(x == tr[u].val) tr[u].cnt++;
        else if(x < tr[u].val) 
            insert(tr[u].son[0], x);
        else if(x > tr[u].val)
            insert(tr[u].son[1], x);
        pushup(u);
        if(is_bad(u))
            rebuild(u);
        return;
    }
    void remove(int &u, int x)
    {
        if(!u) return;
        if(x == tr[u].val)
        {
            if(tr[u].cnt)
                tr[u].cnt--;
        }
        else if(x < tr[u].val)
            remove(tr[u].son[0], x);
        else if(x > tr[u].val)
            remove(tr[u].son[1], x);
        pushup(u);
        if(is_bad(u))
            rebuild(u);
        return;
    }
    int get_val_by_rank(int u, int x)
    {
        if(!u) return 0;
        if(x > tr[tr[u].son[0]].sz && x <= tr[tr[u].son[0]].sz + tr[u].cnt) // && tr[u].cnt
            return tr[u].val;
        if(x <= tr[tr[u].son[0]].sz)
            return get_val_by_rank(tr[u].son[0], x);
        return get_val_by_rank(tr[u].son[1], x - tr[tr[u].son[0]].sz - tr[u].cnt);
    }
    int get_rank_by_val(int u, int x)
    {
        if(!u) return 1;
        if(x == tr[u].val && tr[u].cnt)
            return tr[tr[u].son[0]].sz + 1;
        if(x < tr[u].val)
            return get_rank_by_val(tr[u].son[0], x);
        return tr[tr[u].son[0]].sz + tr[u].cnt + get_rank_by_val(tr[u].son[1], x);
    }
    int get_lstval(int x)
    {
        int rank = get_rank_by_val(root, x) - 1;
        return get_val_by_rank(root, rank);
    }
    int get_nxtval(int x)
    {
        int rank = get_rank_by_val(root, x + 1);
        return get_val_by_rank(root, rank);
    }
}swd; // snow wind. - destroyer!!!
int n, Q, lastans, ans;
int main()
{
    scanf("%d %d", &n, &Q);
    for(int i = 1; i <= n; ++i)
    {
        int x; scanf("%d", &x);
        swd.insert(swd.root, x);
    }
    swd.insert(swd.root, INF);
    swd.insert(swd.root, -INF);
    while(Q--)
    {
        int op, x;
        scanf("%d %d", &op, &x);
        x ^= lastans;
        if(op == 1) swd.insert(swd.root, x);
        if(op == 2) swd.remove(swd.root, x);
        if(op == 3) lastans = swd.get_rank_by_val(swd.root, x) - 1;
        if(op == 4) lastans = swd.get_val_by_rank(swd.root, x + 1);
        if(op == 5) lastans = swd.get_lstval(x);
        if(op == 6) lastans = swd.get_nxtval(x);
        if(op > 2) ans ^= lastans;
        // swd.print(swd.root); puts("");
    }
    printf("%d\n", ans);
    return 0;
}
posted @ 2023-01-29 22:32  Schucking_Sattin  阅读(26)  评论(0)    收藏  举报