P5076 【深基16.例7】普通二叉树(简化版)

只要坑不死,就往死里坑呗。。。
二叉搜索树的模板题,本题无删除操作

核心函数:

  1. query_by_value(int root, int val): 根据值来查找排名
  2. query_by_rank(int root, int rk): 根据排名找值
  3. insert(int root, int val): 向树中插入val
  4. query_pre(int root, int val): 找值val的前驱(最大的小于val的值)
  5. query_suc(int root, int val): 找值val的后继(最小的大于val的值)

坑点:

  1. 可能会插入多个相同的值:为每一个树的结点设置cnt
  2. 在根据value来找rank的时候,可能树中不存在value,但是同样需要输出它的rank
  3. 根据rank找value时,如果找不到对应rank应该输出INF

平均查询复杂度:\(O(logn)\)
平均插入复杂度:\(O(logn)\)
平均构造一个n个结点的BST的复杂度:\(O(nlogn)\)

最坏情况下的插入一个值复杂度:\(O(n)\)
最坏情况下的查询一个值复杂度:\(O(n)\)
最坏情况下产生一个n个结点的BST的复杂度:\(O(n^2)\), 比如通过一个降序的序列通过插入的方法来构造BST

代码

#include<iostream>

using namespace std;

const int N = 40010, INF = 2147483647;

struct Node{
    int val;
    int l, r;
    int cnt;
    int size;
}tr[N];

int ss = 1;

void insert(int x, int val){
    tr[x].size ++;
    
    if(ss == 1){
        tr[x].val = val;
        tr[x].cnt ++;
        ss ++;
        return;
    }
    if(val == tr[x].val) tr[x].cnt ++;
    else if(val > tr[x].val){
        if(tr[x].r) insert(tr[x].r, val);
        else{
            tr[x].r = ss;
            tr[ss].val = val;
            tr[ss].cnt ++;
            tr[ss].size ++;
            ss ++;
        }
    }else if(val < tr[x].val){
        if(tr[x].l) insert(tr[x].l, val);
        else{
            tr[x].l = ss;
            tr[ss].val = val;
            tr[ss].cnt ++;
            tr[ss].size ++;
            ss ++;
        }
    }
}

int query_by_value(int x, int val){
    if(val < tr[x].val){
        if(tr[x].l) return query_by_value(tr[x].l, val);
        return 1;
    }
    if(val > tr[x].val){
        if(tr[x].r) return query_by_value(tr[x].r, val) + tr[tr[x].l].size + tr[x].cnt;
        return tr[tr[x].l].size + tr[x].cnt + 1;
    }
    
    return tr[tr[x].l].size + 1;
}

int query_by_rank(int x, int rk){
    if(x == 0) return INF;
    if(rk <= tr[tr[x].l].size) return query_by_rank(tr[x].l, rk);
    if(rk > tr[tr[x].l].size + tr[x].cnt) return query_by_rank(tr[x].r, rk - tr[tr[x].l].size - tr[x].cnt);
    if(rk > tr[tr[x].l].size && rk <= tr[tr[x].l].size + tr[x].cnt) return tr[x].val;
}

int query_pre(int x, int val){
    if(tr[x].val < val){
        if(tr[x].r) return max(tr[x].val, query_pre(tr[x].r, val));
        return tr[x].val;
    }
    if(tr[x].val >= val){
        if(tr[x].l) return query_pre(tr[x].l, val);
        return -INF;
    }
}

int query_suc(int x, int val){
    if(tr[x].val > val){
        if(tr[x].l) return min(tr[x].val, query_suc(tr[x].l, val));
        return tr[x].val;
    }
    if(tr[x].val <= val){
        if(tr[x].r) return query_suc(tr[x].r, val);
        return INF;
    }
}

int main(){
    int q;
    
    cin >> q;
    
    while(q --){
        int k, x;
        
        cin >> k >> x;
        
        switch(k){
            case 1: cout << query_by_value(1, x) << endl; break;
            case 2: cout << query_by_rank(1, x) << endl; break;
            case 3: cout << query_pre(1, x) << endl; break;
            case 4: cout << query_suc(1, x) << endl; break;
            case 5: insert(1, x); break;
        }
    }
    
    return 0;
}

2020.12.17编辑: 增加删除操作

  1. 根据值删除:del_val(int u, int val)
  2. 根据排名删除:del_rk(int u, int rk),按排名删除可以通过query_by_rank(int u, int rk)转化为根据值删除。

需要注意的是删除需要分成三种情况:

  1. 删除的结点只有左子树
  2. 删除的结点只有右子树
  3. 删除的结点有左右子树

1、2两种直接将左/右子树上拉即可,第三种结点(设为u)的处理方法:找v(最大的比u的值小的结点 <=> 左子树的最右结点) 或(最小的比u大的结点 <=> 右子树的最左结点),将其值和cnt赋值给u,然后将v删除,将v删除必定属于前两种情况。

注意:方便起见,约定不存在可能导致根结点被删除的操作。如果根结点可能被删除,那么就得自己另外搞一个初始根结点(树为空的时候就存在,它的值是一个比所有插入的数都小的数),那样输出排名的时候需要-1,另外查询排名的时候要先+1。

#include<iostream>
using namespace std;

const int N = 40010, INF = 2147483647;

struct Node{
    int val;
    int cnt; // 当前结点存储的相同的val的个数
    int l, r;
    int size; // 以当前结点为根的树的大小
}tr[N];

int ss = 1;

void insert(int u, int val){
    tr[u].size ++;
    
    if(ss == 1){
        tr[u].val = val;
        tr[u].cnt ++;
        ss ++;
        return;
    }
    
    if(val == tr[u].val) tr[u].cnt ++;
    else if(val < tr[u].val){
        if(tr[u].l) insert(tr[u].l, val);
        else{
            tr[u].l = ss;
            tr[ss].val = val;
            tr[ss].size ++;
            tr[ss].cnt ++;
            ss ++;
        }
    }else if(val > tr[u].val){
        if(tr[u].r) insert(tr[u].r, val);
        else{
            tr[u].r = ss;
            tr[ss].val = val;
            tr[ss].size ++;
            tr[ss].cnt ++;
            ss ++;
        }
    }
}

int del_rk(int u, int rk){ // 找rk在树上的位置
    if(u == 0) return 0;
    
    int l = tr[tr[u].l].size + 1;
    int r = l + tr[u].cnt - 1;
    
    // cout << tr[u].val << endl;
    // cout << l << ' ' << r << endl;
    
    if(rk >= l && rk <= r){
        tr[u].cnt --;
        tr[u].size --;
        if(tr[u].cnt) return 1;
        if(tr[u].l && tr[u].r){
            int p = tr[u].l, q = p;
            while(tr[p].r) p = tr[p].r;
            
            tr[u].val = tr[p].val;
            tr[u].cnt = tr[p].cnt;
            
            p = q;
            while(tr[p].r){
                q = p;
                tr[p].size -= tr[u].cnt;
                p = tr[p].r;
            }
            
            if(p == q) tr[u].l = 0;
            else tr[q].r = tr[p].l;
        }
        
        return 1;
    }
    
    if(rk < l && del_rk(tr[u].l, rk)){
        tr[u].size --;
        int l = tr[u].l;
        if(tr[l].cnt == 0) tr[u].l = tr[l].l ? tr[l].l : (tr[l].r ? tr[l].r : 0);
        return 1;
    }
    
    if(rk > r && del_rk(tr[u].r, rk - l)){
        tr[u].size --;
        int r = tr[u].r;
        if(tr[r].cnt == 0) tr[u].r = tr[r].l ? tr[r].l : (tr[r].r ? tr[r].r : 0);
        return 1;
    }
    
    
    return 0;
}

int del_val(int u, int val){ // 找val在树上的位置
    if(u == 0) return 0;
    
    if(val == tr[u].val){
        tr[u].cnt --;
        tr[u].size --;
        if(tr[u].cnt) return 1;
        if(tr[u].l && tr[u].r){
            int p = tr[u].l, q = p;
            while(tr[p].r) p = tr[p].r;
            
            tr[u].val = tr[p].val;
            tr[u].cnt = tr[p].cnt;
            
            p = q;
            while(tr[p].r){
                q = p;
                tr[p].size -= tr[u].cnt;
                p = tr[p].r;
            }
            
            if(p == q) tr[u].l = 0;
            else tr[q].r = tr[p].l;
        }
        
        return 1;
    }
    
    if(val > tr[u].val && del_val(tr[u].r, val)){
        tr[u].size --;
        int r = tr[u].r;
        if(tr[r].cnt == 0) tr[u].r = tr[r].l ? tr[r].l : (tr[r].r ? tr[r].r : 0);
        return 1;
    }
    
    if(val < tr[u].val && del_val(tr[u].l, val)){
        tr[u].size --;
        int l = tr[u].l;
        if(tr[l].cnt == 0) tr[u].l = tr[l].l ? tr[l].l : (tr[l].r ? tr[l].r : 0);
        return 1;
    }
    
    return 0;
}

int query_rk(int u, int val){
    if(u == 0) return 1;
    if(val == tr[u].val) return tr[tr[u].l].size + 1;
    if(val < tr[u].val) return query_rk(tr[u].l, val);
    return tr[tr[u].l].size + tr[u].cnt + query_rk(tr[u].r, val);
}

int query_val(int u, int rk){
    if(u == 0) return INF;
    
    int l = tr[tr[u].l].size + 1;
    int r = l + tr[u].cnt - 1;
    
    if(rk < l) return query_val(tr[u].l, rk);
    if(rk > r) return query_val(tr[u].r, rk - r);
    return tr[u].val;
}

int query_suc(int u, int val){
    if(u == 0) return INF;
    if(val >= tr[u].val) return query_suc(tr[u].r, val);
    return min(tr[u].val, query_suc(tr[u].l, val));
}

int query_pre(int u, int val){
    if(u == 0) return -INF;
    if(val <= tr[u].val) return query_pre(tr[u].l, val);
    return max(tr[u].val, query_pre(tr[u].r, val));
}

int main(){
    int q;
    
    cin >> q;
    
    while(q --){
        int k, x;
        
        cin >> k >> x;
        
        switch(k){
            case 1: cout << query_rk(1, x) << endl; break;
            case 2: cout << query_val(1, x) << endl; break;
            case 3: cout << query_pre(1, x) << endl; break;
            case 4: cout << query_suc(1, x) << endl; break;
            case 5: insert(1, x); break;
            case 6: 
                if(del_val(1, x)) printf("delete succeeded, size: %d\n", tr[1].size);
                else puts("delete failed");    
                break;
            case 7: 
                if(del_rk(1, x)) printf("delete succeeded, size: %d\n", tr[1].size);
                else puts("delete failed");
                break;
        }
    }
    
    return 0;
}


/*
    10
    5 5
    5 3
    5 6
    5 2
    5 4
    1 5
    7 1
    7 2
    6 5
    1 6
*/

/*
    13
    5 5
    5 3
    5 6
    5 2
    5 4
    1 5
    2 1
    7 1
    2 2
    7 1
    7 1
    6 4
    1 -100
*/

上面的代码:口胡最坏单次删除复杂度为\(O(n)\), 最好情况下删除一个叶子复杂度\(O(logn)\)

posted @ 2020-10-22 19:38  yys_c  阅读(198)  评论(0编辑  收藏  举报