[数据结构]替罪羊树简介

  替罪羊树是不通过旋转、而是通过重构来维护节点平衡的一种平衡树。当某一棵子树的节点总数超过其父节点所在子树节点总数的一定比例(即平衡因子)时,就进行重构操作。


节点定义

  为了判断是否需要重构,所以需要加入cover(实际节点个数)域。这次直接加入可重操作,所以还需要增加一个size域。为了体现C++面向对象的思想(分明就是Java用多了),所以把判断一个子树是否需要重构写成成员函数bad()。(真开心,因为是重构,不需要那么多的father,终于可以轻松地删掉父节点指针)

 (更正一个错误于此,详见日志)

 1 template<typename T>
 2 class ScapegoatTreeNode{
 3     private:
 4         static const float factora = 0.75;     //平衡因子 
 5     public:
 6         T data;
 7         int size, cover;    //数的个数, 实际节点个数
 8         int count;
 9         ScapegoatTreeNode* next[2];
10         
11         ScapegoatTreeNode():size(1), cover(1), count(1){
12             memset(next, 0, sizeof(next));
13         }
14         
15         ScapegoatTreeNode(T data):data(data), size(1), cover(1), count(1){
16             memset(next, 0, sizeof(next));
17         }
18         
19         void maintain(){    //维护大小 
20             cover = 1, size = count;
21             for(int i = 0; i < 2; i++)
22                 if(next[i] != NULL)
23                     cover += next[i]->cover, size += next[i]->size; 
24         }
25         
26         int cmp(T data){
27             if(this->data == data)    return -1;
28             return (data < this->data) ? (0) : (1);
29         }
30         
31         boolean bad(){
32             for(int i = 0; i < 2; i++)
33 //                if(next[i] != NULL && next[i]->cover > this->cover * factora)
34                 if(next[i] != NULL && next[i]->cover > (this->cover + 3) * factora)        //+1 384ms +3 376ms +5 448ms
35                     return true;
36             return false;        
37         }
38         
39         inline void addCount(int val){
40             size += val;
41             count += val;
42         }
43 };

 


重构操作

  首先将要重构的子树拍平(遍历一次得到一个数组),然后利用这个数组进行重构。每次的节点为这个区间的中点,然后递归调用去把[l, mid - 1]重构为左子树,[mid + 1, r]重构为右子树。记得在每层函数返回之前进行子树大小的维护。

  可能这段文字不好帮助理解,那就把上面那棵树重构了(虽然它很平衡,但是长得太丑了)。首先中序遍历得到一个有序的数组(由于我比较懒,所以用的vector,建议保存节点的地址或下标,不要只保存数据)

  找到mid,然后递归生成它的左子树和右子树:

  为了体现当l > r时直接return NULL的情况,特此注明键值为7的节点的左子树(为空)。

 1 //查找操作,因为重构后,有一些节点会消失,需要重新维护一下cover
 2 static ScapegoatTreeNode<T>* find(ScapegoatTreeNode<T>*& node, T data){
 3     if(node == NULL)    return NULL;
 4     int d = node->cmp(data);
 5     if(d == -1)        return node;
 6     ScapegoatTreeNode<T>* s = find(node->next[d], data);
 7     node->maintain();
 8     return s;
 9 }
10 
11 //中序遍历得到一个数组
12 void getLis(ScapegoatTreeNode<T>*& node){
13     if(node == NULL)    return;
14     getLis(node->next[0]);
15     if(node->count > 0)    lis.push_back(node);
16     getLis(node->next[1]);
17     if(node->count <= 0)    delete node;
18 }
19 
20 //重构的主要过程
21 ScapegoatTreeNode<T>* rebuild(int from, int end){
22      if(from > end)    return NULL;
23     if(from == end){ 
24         ScapegoatTreeNode<T>* node = lis[from];
25         node->next[0] = node->next[1] = NULL;
26         node->size = node->count;
27         node->cover = 1;
28         return node; 
29     }
30     int mid = (from + end) >> 1;
31     ScapegoatTreeNode<T>* node = lis[mid];
32     node->next[0] = rebuild(from, mid - 1);
33     node->next[1] = rebuild(mid + 1, end);
34     node->maintain();
35     return node;
36 }
37 
38 //调用
39 void rebuild(ScapegoatTreeNode<T>*& node, ScapegoatTreeNode<T>*& father){
40     lis.clear();
41     getLis(node);
42     ScapegoatTreeNode<T>* ret = rebuild(0, (unsigned)lis.size() - 1);
43     if(father == NULL)    root = ret;
44     else{
45         father->next[father->cmp(ret->data)] = ret;
46         find(root, ret->data);
47     }
48 }

插入操作

  插入操作还是按照BST的性质进行插入。只不过中途要找到最后一个极其不平衡的子树的根节点(试想一下,一遇到有问题的就重构,那么替罪羊树的复杂度会变为多少)。这个节点又叫『替罪羊节点』。其它就没有可以多说的内容了。插入的实现见下文完整代码中的insert函数。下面的代码是删除操作:采用惰性删除,只把对应节点的count减一,count降为0的节点要等到所在子树重构时才会被真正释放。

 1 static boolean remove(ScapegoatTreeNode<T>*& node, T data){
 2     if(node == NULL)    return false;
 3     int d = node->cmp(data);
 4     if(d == -1){
 5         node->addCount(-1);
 6         return true;
 7     }
 8     boolean res = remove(node->next[d], data);
 9     if(res)    node->maintain();
10     return res;
11 }
12 
13 boolean remove(T data){
14     boolean res = remove(root, data);
15     return res;
16 }

其它各种操作

·各种bound

  思路可以参照Splay中的思路,唯一需要注意的一点是:如果当前节点已经被惰性删除(count为0),并且在cmp指示的方向上也没有找到答案,那么还得到另外一个方向去找(之前被坑了好多好多次。例如要找比val大的数,在左子树中没找到,而当前节点又已被删除,就得再去右子树里找,理解了就好)。

  下面以upper_bound(查找严格大于val的最小值)为例:

1 static ScapegoatTreeNode<T>* upper_bound(ScapegoatTreeNode<T>*& node, T val){
2     if(node == NULL)    return node;
3     int to = node->cmp(val);
4     if(val == node->data)    to = 1;
5     ScapegoatTreeNode<T>* ret = upper_bound(node->next[to], val);
6     if(to == 0 && ret == NULL)    ret = upper_bound(node->next[1], val);
7     return ((ret == NULL || node->data < ret->data) && node->data > val && node->count > 0) ? (node) : (ret);
8 }

·名次操作(没有变动,参照Splay内的名次操作)


完整代码和总结

  替罪羊树的思路可以算是我见过的平衡树中最简单的,然而实现起来处处被坑,处处RE。另外可以通过多次实践来调整平衡因子和bad()函数,可以不通过改变主体过程就可以做到提高效率。下面是bzoj 3224的AC完整代码。

  1 /**
  2  * bzoj
  3  * Problem#3224
  4  * Accepted
  5  * Time:500ms / 368ms
  6  * Memory:2628k / 2628k
  7  */
  8 #include<iostream>
  9 #include<fstream>
 10 #include<sstream>
 11 #include<cstdio>
 12 #include<cstdlib>
 13 #include<cstring>
 14 #include<ctime>
 15 #include<cctype>
 16 #include<cmath>
 17 #include<algorithm>
 18 #include<stack>
 19 #include<queue>
 20 #include<set>
 21 #include<map>
 22 #include<vector>
 23 using namespace std;
 24 typedef bool boolean;
 25 #define smin(a, b)    (a) = min((a), (b))
 26 #define smax(as, b)    (a) = max((a), (b))
 27 template<typename T>
 28 inline boolean readInteger(T& u){
 29     char x;
 30     int aFlag = 1;
 31     while(!isdigit((x = getchar())) && x != '-' && x != -1);
 32     if(x == -1)    return false;
 33     if(x == '-'){
 34         x = getchar();
 35         aFlag = -1;
 36     }
 37     for(u = x - '0'; isdigit((x = getchar())); u = (u << 3) + (u << 1) + x - '0');
 38     ungetc(x, stdin);
 39     u *= aFlag;
 40     return true;
 41 }
 42 
 43 template<typename T>
 44 class ScapegoatTreeNode{
 45     private:
 46         static const float factora = 0.75;     //平衡因子 
 47     public:
 48         T data;
 49         int size, cover;    //数的个数, 实际节点个数
 50         int count;
 51         ScapegoatTreeNode* next[2];
 52         
 53         ScapegoatTreeNode():size(1), cover(1), count(1){
 54             memset(next, 0, sizeof(next));
 55         }
 56         
 57         ScapegoatTreeNode(T data):data(data), size(1), cover(1), count(1){
 58             memset(next, 0, sizeof(next));
 59         }
 60         
 61         void maintain(){    //维护大小 
 62             cover = 1, size = count;
 63             for(int i = 0; i < 2; i++)
 64                 if(next[i] != NULL)
 65                     cover += next[i]->cover, size += next[i]->size; 
 66         }
 67         
 68         int cmp(T data){
 69             if(this->data == data)    return -1;
 70             return (data < this->data) ? (0) : (1);
 71         }
 72         
 73         boolean bad(){
 74             for(int i = 0; i < 2; i++)
 75 //                if(next[i] != NULL && next[i]->cover > this->cover * factora)
 76                 if(next[i] != NULL && next[i]->cover > (this->cover + 3) * factora)        //+1 384ms +3 376ms +5 448ms
 77                     return true;
 78             return false;        
 79         }
 80         
 81         inline void addCount(int val){
 82             size += val;
 83             count += val;
 84         }
 85 };
 86 
 87 template<typename T>
 88 class ScapegoatTree{
 89     protected:
 90         static void insert(ScapegoatTreeNode<T>*& node, T data, ScapegoatTreeNode<T>*& last, ScapegoatTreeNode<T>*& father){
 91             if(node == NULL){
 92                 node = new ScapegoatTreeNode<T>(data);
 93                 return;
 94             }
 95             int d = node->cmp(data);
 96             if(d == -1){
 97                 node->addCount(1);
 98                 return;
 99             }
100             insert(node->next[d], data, last, father);
101             node->maintain();
102             if(father == NULL && last != NULL)    father = node;
103             if(node->bad())    last = node, father = NULL;
104         }
105         
106         static boolean remove(ScapegoatTreeNode<T>*& node, T data){
107             if(node == NULL)    return false;
108             int d = node->cmp(data);
109             if(d == -1){
110                 node->addCount(-1);
111                 return true;
112             }
113             boolean res = remove(node->next[d], data);
114             if(res)    node->maintain();
115             return res;
116         }
117         
118         static ScapegoatTreeNode<T>* less_bound(ScapegoatTreeNode<T>*& node, T val){
119             if(node == NULL)    return node;
120             int to = node->cmp(val);
121             if(val == node->data)    to = 0;
122             ScapegoatTreeNode<T>* ret = less_bound(node->next[to], val);
123             if(to == 1 && ret == NULL)    ret = less_bound(node->next[0], val);
124             return ((ret == NULL || node->data > ret->data) && node->data < val && node->count > 0) ? (node) : (ret);
125         }
126         
127         static ScapegoatTreeNode<T>* upper_bound(ScapegoatTreeNode<T>*& node, T val){
128             if(node == NULL)    return node;
129             int to = node->cmp(val);
130             if(val == node->data)    to = 1;
131             ScapegoatTreeNode<T>* ret = upper_bound(node->next[to], val);
132             if(to == 0 && ret == NULL)    ret = upper_bound(node->next[1], val);
133             return ((ret == NULL || node->data < ret->data) && node->data > val && node->count > 0) ? (node) : (ret);
134         }
135         
136         static ScapegoatTreeNode<T>* findKth(ScapegoatTreeNode<T>*& node, int k){
137             int ls = (node->next[0] == NULL) ? (0) : (node->next[0]->size);
138             int count = node->count;
139             if(k >= ls + 1 && k <= ls + count)    return node;
140             if(k <= ls)    return findKth(node->next[0], k);
141             return findKth(node->next[1], k - ls - count);
142         }
143         
144         static ScapegoatTreeNode<T>* find(ScapegoatTreeNode<T>*& node, T data){
145             if(node == NULL)    return NULL;
146             int d = node->cmp(data);
147             if(d == -1)        return node;
148             ScapegoatTreeNode<T>* s = find(node->next[d], data);
149             node->maintain();
150             return s;
151         }
152     public:
153         ScapegoatTreeNode<T>* root;
154         vector<ScapegoatTreeNode<T>*> lis;
155         
156         ScapegoatTree():root(NULL){    }
157         
158         void getLis(ScapegoatTreeNode<T>*& node){
159             if(node == NULL)    return;
160             getLis(node->next[0]);
161             if(node->count > 0)    lis.push_back(node);
162             getLis(node->next[1]);
163             if(node->count <= 0)    delete node;
164         }
165         
166         ScapegoatTreeNode<T>* rebuild(int from, int end){
167              if(from > end)    return NULL;
168             if(from == end){
169                 ScapegoatTreeNode<T>* node = lis[from];
170                 node->next[0] = node->next[1] = NULL;
171                 node->size = node->count;
172                 node->cover = 1;
173                 return node; 
174             }
175             int mid = (from + end) >> 1;
176             ScapegoatTreeNode<T>* node = lis[mid];
177             node->next[0] = rebuild(from, mid - 1);
178             node->next[1] = rebuild(mid + 1, end);
179             node->maintain();
180             return node;
181         }
182         
183         void rebuild(ScapegoatTreeNode<T>*& node, ScapegoatTreeNode<T>*& father){
184             lis.clear();
185             getLis(node);
186             ScapegoatTreeNode<T>* ret = rebuild(0, (unsigned)lis.size() - 1);
187             if(father == NULL)    root = ret;
188             else{
189                 father->next[father->cmp(ret->data)] = ret;
190                 find(root, ret->data);
191             }
192         }
193         
194         void insert(T data){
195             ScapegoatTreeNode<T>* node = NULL, *father = NULL;
196             insert(root, data, node, father);
197             if(node != NULL)    rebuild(node, father);
198         }
199         
200         boolean remove(T data){
201             boolean res = remove(root, data);
202             return res;
203         }
204         
205         void out(ScapegoatTreeNode<T>*& node){    //调试使用函数,打印树 
206             if(node == NULL)    return;
207             cout << node->data << "(" << node->size << "," << node->cover << "," << node->count << "){";
208             out(node->next[0]);
209             cout <<    ",";
210             out(node->next[1]);
211             cout << "}";
212         }
213         
214         ScapegoatTreeNode<T>* less_bound(T data){
215             return less_bound(root, data);
216         }
217         
218         ScapegoatTreeNode<T>* upper_bound(T data){
219             return upper_bound(root, data);
220         }
221         
222         ScapegoatTreeNode<T>* findKth(int k){
223             return findKth(root, k);
224         }
225         
226         inline int rank(T data){
227             ScapegoatTreeNode<T>* p = root;
228             int r = 0;
229             while(p != NULL){
230                 int ls = (p->next[0] != NULL) ? (p->next[0]->size) : (0);
231                 if(p->data == data)    return r + ls + 1;
232                 int d = p->cmp(data);
233                 if(d == 1)    r += ls + p->count;
234                 p = p->next[d];
235             }
236             return r + 1;
237         }
238 };
239 
240 int n;
241 ScapegoatTree<int> s;
242 
243 void printTree(){
244     s.out(s.root);
245     putchar('\n');
246 }
247 
248 inline void solve(){
249     int opt, x;    
250     readInteger(n);
251     while(n--){
252         readInteger(opt);
253         readInteger(x);
254         if(opt == 1)    s.insert(x);
255         else if(opt == 2) s.remove(x);
256         else if(opt == 3)    printf("%d\n", s.rank(x));
257         else if(opt == 4) printf("%d\n", s.findKth(x)->data);
258         else if(opt == 5)    printf("%d\n", s.less_bound(x)->data);
259         else printf("%d\n", s.upper_bound(x)->data);
260     }
261 }
262 
263 int main(){
264     solve();
265     return 0;
266 }

[日志]

  2017-1-24,更正1错误,不应该更新cover的值,cover是实际结点的个数

        2017-7-18,更正2处代码中注释中的错别字

posted @ 2017-01-18 22:07  阿波罗2003  阅读(662)  评论(0编辑  收藏  举报