bzoj3224 普通平衡树

 

这道题我先写了值域线段树,一直WA,和网上的标程对拍,也拍不出错误,然后改写SPALY,又WA,也拍不出错误,最后只能用vector水过了。

我把我写的值域线段树、Splay、vector、数据生成器放在下面,若有哪位好心人愿意帮我看看,感激不尽。

 

 

值域线段树:

  1 /**************************************************************
  2 Problem: 3224
  3 User: idy002
  4 Language: C++
  5 Result: Wrong_Answer
  6  ****************************************************************/
  7 
  8 #include <cstdio>
  9 #define fprintf(...)
 10 #define maxn 7000000
 11 #define minv 0
 12 #define inc 10000000
 13 #define maxv 20000000
 14 
 15 int son[maxn][2], siz[maxn], ntot;
 16 
 17 int newnode() {
 18     int nd = ++ntot;
 19     son[nd][0] = son[nd][1] = siz[nd] = 0;
 20     return nd;
 21 }
 22 void pushup( int nd ) {
 23     siz[nd] = siz[son[nd][0]]+siz[son[nd][1]];
 24 }
 25 void insert( int v, int nd, int lf, int rg ) {
 26     fprintf( stderr, "insert( %d %d %d %d )\n", v, nd, lf, rg );
 27     if( lf==rg ) {
 28         siz[nd]++;
 29         return;
 30     }
 31     int mid = (lf+rg)>>1;
 32     if( !son[nd][ v>mid ] ) son[nd][ v>mid ] = newnode();
 33     if( v<=mid ) insert( v, son[nd][0], lf, mid );
 34     else insert( v, son[nd][1], mid+1, rg );
 35     pushup(nd);
 36 }
 37 void erase( int v, int nd, int lf, int rg ) {
 38     fprintf( stderr, "erase( %d %d %d %d )\n", v, nd, lf, rg );
 39     if( lf==rg ) {
 40         siz[nd]--;
 41         return;
 42     }
 43     int mid = (lf+rg)>>1;
 44     if( v<=mid ) erase( v, son[nd][0], lf, mid );
 45     else erase( v, son[nd][1], mid+1, rg );
 46     pushup(nd);
 47 }
 48 int rank( int v, int nd, int lf, int rg ) {
 49     fprintf( stderr, "rank( %d %d %d %d )\n", v, nd, lf, rg );
 50     if( lf==rg ) return 1;
 51     int mid = (lf+rg)>>1;
 52     if( v<=mid ) return rank(v,son[nd][0],lf,mid);
 53     else return siz[son[nd][0]]+rank(v,son[nd][1],mid+1,rg);
 54 }
 55 int nth( int n, int nd, int lf, int rg ) {
 56     fprintf( stderr, "nth( %d %d %d %d )\n", n, nd, lf, rg );
 57     if( lf==rg ) return lf;
 58     int ls = siz[son[nd][0]];
 59     int mid = (lf+rg)>>1;
 60     if( n<=ls ) return nth(n,son[nd][0],lf,mid);
 61     else return nth(n-ls,son[nd][1],mid+1,rg);
 62 }
 63 int gnext( int v, int nd, int lf, int rg ) {
 64     fprintf( stderr, "gnext( %d %d %d %d )\n", v, nd, lf, rg );
 65     if( !nd ) return v;
 66     if( lf==rg ) return lf>v ? lf : v;
 67     int mid = (lf+rg)>>1;
 68     if( v<=mid ) {
 69         int rt = gnext( v, son[nd][0], lf, mid );
 70         if( rt==v ) return gnext( v, son[nd][1], mid+1, rg );
 71         else return rt;
 72     } 
 73     return gnext( v, son[nd][1], mid+1, rg );
 74 }
 75 int gprev( int v, int nd, int lf, int rg ) {
 76     fprintf( stderr, "gprev( %d %d %d %d )\n", v, nd, lf, rg );
 77     if( !nd ) return v;
 78     if( lf==rg ) return lf<v ? lf : v;
 79     int mid = (lf+rg)>>1;
 80     if( v>mid ) {
 81         int rt = gprev( v, son[nd][1], mid+1, rg );
 82         if( rt==v ) return gprev( v, son[nd][0], lf, mid );
 83         else return rt;
 84     }
 85     return gprev( v, son[nd][0], lf, mid );
 86 }
 87 
 88 int n;
 89 
 90 int main() {
 91     scanf( "%d", &n );
 92     newnode();
 93     for( int i=1,opt,x; i<=n; i++ ) {
 94         scanf( "%d%d", &opt, &x );
 95         x += inc;
 96         switch(opt) {
 97             case 1:
 98                 insert( x, 1, minv, maxv );
 99                 break;
100             case 2:
101                 erase( x, 1, minv, maxv );
102                 break;
103             case 3:
104                 printf( "%d\n", rank(x,1,minv,maxv) );
105                 break;
106             case 4:
107                 printf( "%d\n", nth(x-inc,1,minv,maxv)-inc );
108                 break;
109             case 5:
110                 printf( "%d\n", gprev(x,1,minv,maxv)-inc );
111                 break;
112             case 6:
113                 printf( "%d\n", gnext(x,1,minv,maxv)-inc );
114                 break;
115         }
116     }
117 }
View Code

splay:

  1 #include <cstdio>
  2 #include <iostream>
  3 #define maxn 100020
  4 using namespace std;
  5 
  6 
  7 struct Splay {
  8     int key[maxn], pre[maxn], son[maxn][2], siz[maxn], cnt[maxn], root, ntot;
  9 
 10     int newnode( int k, int p ) {
 11         int nd = ++ntot;
 12         key[nd] = k;
 13         pre[nd] = p;
 14         son[nd][0] = son[nd][1] = 0;
 15         siz[nd] = cnt[nd] = 1;
 16         return nd;
 17     }
 18     void update( int nd ) {
 19         siz[nd] = siz[son[nd][0]] + siz[son[nd][1]] + cnt[nd];
 20     }
 21     void rotate( int nd, int d ) {
 22         int p = pre[nd];
 23         int s = son[nd][!d];
 24         int ss = son[s][d];
 25 
 26         son[nd][!d] = ss;
 27         son[s][d] = nd;
 28         if( p ) son[p][ nd==son[p][1] ] = s;
 29         else root = s;
 30 
 31         pre[nd] = s;
 32         pre[s] = p;
 33         if( ss ) pre[ss] = nd;
 34 
 35         update(nd);
 36         update(s);
 37     }
 38     void splay( int nd, int top ) {
 39         while( pre[nd]!=top ) {
 40             int p = pre[nd];
 41             int nl = nd==son[p][0];
 42             if( pre[p]==top ) {
 43                 rotate( p, nl );
 44             } else {
 45                 int pp = pre[p];
 46                 int pl = p==son[pp][0];
 47                 if( nl==pl ) {
 48                     rotate( pp, pl );
 49                     rotate( p, nl );
 50                 } else {
 51                     rotate( p, nl );
 52                     rotate( pp, pl );
 53                 }
 54             }
 55         }
 56     }
 57     void insert( int k ) {
 58         if( !root ) {
 59             root = newnode( k, 0 );
 60             return;
 61         }
 62         int nd = root;
 63         while( 1 ) {
 64             if( k==key[nd] ) {
 65                 cnt[nd]++;
 66                 break;
 67             } else {
 68                 if( !son[nd][ k>key[nd] ] ) {
 69                     son[nd][ k>key[nd] ] = newnode( k, nd );
 70                     break;
 71                 } 
 72                 nd = son[nd][ k>key[nd] ];
 73             }
 74         }
 75         update( nd );
 76         splay( nd, 0 );
 77     }
 78     int find( int k ) {
 79         int nd = root;
 80         while( 1 ) {
 81             if( k!=key[nd] ) nd = son[nd][ k>key[nd] ];
 82             else break;
 83         }
 84         return nd;
 85     }
 86     void erase( int k ) {
 87         int nd = find(k);
 88         cnt[nd]--;
 89         update(nd);
 90         splay(nd,0);
 91     }
 92     int gnext( int k ) {
 93         insert( k );
 94         int nd = find( k );
 95         cnt[nd]--;
 96         update(nd);
 97         splay(nd,0);
 98         int rnd = son[nd][1];
 99         if( !rnd ) fprintf( stderr, "gnext k=%d\n", k );
100         int rt = key[rnd];
101         while( son[rnd][0] ) {
102             rnd = son[rnd][0];
103             rt = min( rt, key[rnd] );
104         }
105         splay( rnd,0 );
106         return rt;
107     }
108     int gprev( int k ) {
109         insert( k );
110         int nd = find( k );
111         cnt[nd]--;
112         update(nd);
113         splay(nd,0);
114         int lnd = son[nd][0];
115         if( !lnd ) fprintf( stderr, "gprev k=%d\n", k );
116         int rt = key[lnd];
117         while( son[lnd][1] ) {
118             lnd = son[lnd][1];
119             rt = max( rt, key[lnd] );
120         }
121         splay(lnd,0);
122         return rt;
123     }
124     int rank( int k ) {
125         int nd = root;
126         int rt = 0;
127         while(1) {
128             int ls = siz[son[nd][0]];
129             int cs = cnt[nd];
130             if( k==key[nd] ) {
131                 rt += ls+1;
132                 splay( nd,0 );
133                 return rt;
134             }
135             if( k<key[nd] ) {
136                 nd = son[nd][0];
137             } else {
138                 nd = son[nd][1];
139                 rt += ls+cs;
140             }
141         }
142     }
143     int nth( int n ) {
144         int nd = root;
145         while(1) {
146             int ls = siz[son[nd][0]];
147             int cs = cnt[nd];
148             if( n<=ls ) {
149                 nd = son[nd][0];
150             } else if( n<=ls+cs ) {
151                 splay( nd, 0 );
152                 return key[nd];
153             } else {
154                 n -= ls+cs;
155                 nd = son[nd][1];
156             }
157         }
158     }
159     void print( int nd ) {
160         if( !nd ) return;
161         print( son[nd][0] );
162         fprintf( stderr, "%d(%d)  ", key[nd], cnt[nd] );
163         print( son[nd][1] );
164     }
165 };
166 
167 int n;
168 Splay T;
169 int main() {
170     scanf( "%d", &n );
171     while(n--) {
172         int opt, x;
173         scanf( "%d%d", &opt, &x );
174         switch( opt ) {
175             case 1:
176     //            fprintf( stderr, "insert(%d)\n", x );
177                 T.insert(x);
178                 break;
179             case 2:
180     //            fprintf( stderr, "erase(%d)\n", x );
181                 T.erase(x);
182                 break;
183             case 3:
184                 printf( "%d\n", T.rank(x) );
185                 break;
186             case 4:
187                 printf( "%d\n", T.nth(x) );
188                 break;
189             case 5:
190                 printf( "%d\n", T.gprev(x) );
191                 break;
192             case 6:
193                 printf( "%d\n", T.gnext(x) );
194                 break;
195         }
196         //T.print( T.root );
197         //fprintf( stderr, "\n" );
198     }
199 }
View Code

vector:

 1 #include <cstdio>
 2 #include <vector>
 3 #include <algorithm>
 4 using namespace std;
 5 
 6 
 7 vector<int> vc;
 8 void insert( int x ) {
 9     vc.insert( lower_bound(vc.begin(),vc.end(),x), x );
10 }
11 void erase( int x ) {
12     vc.erase( lower_bound(vc.begin(),vc.end(),x) );
13 }
14 int rank( int x ) {
15     return lower_bound(vc.begin(),vc.end(),x)-vc.begin()+1;
16 }
17 int nth( int n ) {
18     return vc[n-1];
19 }
20 int prev( int x ) {
21     return *--lower_bound(vc.begin(),vc.end(),x);
22 }
23 int next( int x ) {
24     return *upper_bound(vc.begin(),vc.end(),x);
25 }
26 
27 int main() {
28     int n;
29     scanf( "%d", &n );
30     while(n--) {
31         int opt, x;
32         scanf( "%d%d", &opt, &x );
33         switch(opt) {
34             case 1:
35                 insert( x );
36                 break;
37             case 2:
38                 erase(x);
39                 break;
40             case 3:
41                 printf( "%d\n", rank(x) );
42                 break;
43             case 4:
44                 printf( "%d\n", nth(x) );
45                 break;
46             case 5:
47                 printf( "%d\n", prev(x) );
48                 break;
49             case 6:
50                 printf( "%d\n", next(x) );
51                 break;
52         }
53     }
54 }
View Code

数据生成器(有一个参数,是随机数种子):

 1 #include <cstdio>
 2 #include <set>
 3 #include <cstdlib>
 4 #define R(l,r) ((rand())%((r)-(l)+1)+(l))
 5 using namespace std;
 6 
 7 int all = 1000;   //指令数
 8 int n = 500;      //最开始的插入的数的数量
 9 int maxd = 200;  //最多删除多少个数
10 
11 int a[101100], mn, mx;
12 int d[101100];
13 
14 int main( int argc, char **argv ) {
15     srand(atoi(argv[1]));
16 
17     freopen( "input", "w", stdout );
18     printf( "%d\n", all );
19     int p = all-n;
20     for( int i=1; i<=n; i++ ) 
21         printf( "1 %d\n", a[i]=R(-n,n) );
22     int remain = maxd;
23     for( int i=1; i<=p; i++ ) {
24 AGAIN:
25         if( i%10000==0 ) fprintf( stderr, "%d\n", i );
26         int opt = R(2,6);
27         int x;
28         switch( opt ) {
29             case 2:
30                 if( remain ) {
31                     int ind;
32                     do
33                         ind = R(1,n);
34                     while( !(!d[ind] && a[ind]!=mn && a[ind]!=mx) );
35                     remain--;
36                     d[ind] = 1;
37                     printf( "2 %d\n", a[ind] );
38                 } else goto AGAIN;
39                 break;
40             case 3:
41                 {
42                     int ind;
43                     do 
44                         ind = R(1,n);
45                     while( !(!d[ind]) );
46                     printf( "3 %d\n", a[ind] );
47                 }
48                 break;
49             case 4:
50                 {
51                     set<int> st;
52                     for( int i=1; i<=n; i++ ) 
53                         if( !d[i] ) st.insert(a[i]);
54                     int x = R(1,st.size());
55                     printf( "4 %d\n", x );
56                 }
57                 break;
58             case 5:
59                 if( mx==mn ) goto AGAIN;
60                 {
61                     printf( "5 %d\n", R(mn+1,mx) );
62                 }
63                 break;
64             case 6:
65                 if( mx==mn ) goto AGAIN;
66                 {
67                     printf( "6 %d\n", R(mn,mx-1) );
68                 }
69                 break;
70         }
71     }
72 }
View Code

 


20150708

终于填了这个坑。

将splay中rank(k)的返回值定义为小于k的数的个数加一。

然后k的前趋就是nth(rank(k)-1)

k的后继就是nth(rank(k+1))

  1 /**************************************************************
  2     Problem: 3224
  3     User: idy002
  4     Language: C++
  5     Result: Accepted
  6     Time:776 ms
  7     Memory:2772 kb
  8 ****************************************************************/
  9  
 10 #include <cstdio>
 11  
 12 const int N = 100000 + 10;
 13  
 14 struct Splay {
 15     int son[N][2], fat[N], key[N], siz[N], ntot, root;
 16  
 17     inline void update( int nd ) {
 18         siz[nd] = siz[son[nd][0]] + siz[son[nd][1]] + 1;
 19     }
 20     void rotate( int nd, int d ) {
 21         int p = fat[nd];
 22         int s = son[nd][!d];
 23         int ss = son[s][d];
 24  
 25         if( p ) son[p][ nd==son[p][1] ] = s;
 26         else root=s;
 27         son[nd][!d] = ss;
 28         son[s][d] = nd;
 29  
 30         fat[nd] = s;
 31         fat[s] = p;
 32         if( ss ) fat[ss]=nd;
 33  
 34         update(nd);
 35         update(s);
 36     }
 37     void splay( int nd, int top=0 ) {
 38         while( fat[nd]!=top ) {
 39             int p=fat[nd];
 40             int nl=nd==son[p][0];
 41             if( fat[p]==top ) {
 42                 rotate( p, nl );
 43             } else {
 44                 int pp=fat[p];
 45                 int pl=p==son[pp][0];
 46                 if( nl==pl ) {
 47                     rotate( pp, pl );
 48                     rotate( p, nl );
 49                 } else {
 50                     rotate( p, nl );
 51                     rotate( pp, pl );
 52                 }
 53             }
 54         }
 55     }
 56     int newnode( int p, int k ) {
 57         int nd = ++ntot;
 58         key[nd] = k;
 59         siz[nd] = 1;
 60         fat[nd] = p;
 61         son[nd][0] = son[nd][1] = 0;
 62         return nd;
 63     }
 64     void insert( int v ) {
 65         if( !root ) {
 66             root = newnode( 0, v );
 67             return;
 68         }
 69         int nd = root;
 70         while( son[nd][v>key[nd]] ) nd=son[nd][v>key[nd]];
 71         son[nd][v>key[nd]] = newnode( nd, v );
 72         update(nd);
 73         splay(nd);
 74     }
 75     int find( int v ) {
 76         int nd = root;
 77         while( key[nd]!=v ) nd=son[nd][ v>key[nd] ];
 78         return nd;
 79     }
 80     void erase( int v ) {
 81         int nd = find(v);
 82         splay(nd);
 83         int lnd = son[nd][0];
 84         int rnd = son[nd][1];
 85         if( !lnd && !rnd ) {
 86             root = 0;
 87         } else if( !lnd ) {
 88             root = rnd;
 89             fat[rnd] = 0;
 90         } else if( !rnd ) {
 91             root = lnd;
 92             fat[lnd] = 0;
 93         } else {
 94             while( son[lnd][1] ) lnd=son[lnd][1];
 95             while( son[rnd][0] ) rnd=son[rnd][0];
 96             splay( lnd, 0 );
 97             splay( rnd, lnd );
 98             son[rnd][0] = 0;
 99             update(rnd);
100             update(lnd);
101         }
102     }
103     int nth( int k ) {
104         int nd=root;
105         while(1) {
106             int lz = siz[son[nd][0]];
107             if( k<=lz ) {
108                 nd = son[nd][0];
109             } else if( k>=lz+2 ) {
110                 k -= lz+1;
111                 nd = son[nd][1];
112             } else {
113                 splay(nd);
114                 return key[nd];
115             }
116         }
117     }
118     int rank( int v ) {
119         int nd=root;
120         int rt = 1;
121         int last_nd;
122         while(nd) {
123             last_nd = nd;
124             if( key[nd]<v ) {
125                 rt += siz[son[nd][0]] + 1;
126                 nd = son[nd][1];
127             } else {
128                 nd = son[nd][0];
129             }
130         }
131         splay(last_nd);
132         return rt;
133     }
134     int prev( int v ) {
135         int k = rank(v);
136         return nth(k-1);
137     }
138     int succ( int v ) {
139         int k = rank(v+1);
140         return nth(k);
141     }
142 }T;
143  
144 int n;
145  
146 int main() {
147     scanf( "%d", &n );
148     for( int i=1,opt,x; i<=n; i++ ) {
149         scanf( "%d%d", &opt, &x );
150         if( opt==1 ) {
151             T.insert(x);
152         } else if( opt==2 ) {
153             T.erase(x);
154         } else if( opt==3 ) {
155             printf( "%d\n", T.rank(x) );
156         } else if( opt==4 ) {
157             printf( "%d\n", T.nth(x) );
158         } else if( opt==5 ) {
159             printf( "%d\n", T.prev(x) );
160         } else {
161             printf( "%d\n", T.succ(x) );
162         }
163     }
164 }
View Code

 

Update:2018—01—18

 

来自好心人:我们发现这个Splay的删除是惰性删除的,并且很多地方的更改是无效的。这也就导致了一个问题:已经存在于这棵Splay里面的一些节点其实在原来的序列中是不存在的。那么这就导致了如果直接在树上找前驱后继时遍历节点会遍历到不在原序列中的数。于是就GG了。

In:

5

1 1
1 2
1 3
2 2
5 3

Out:
2

但应该是1对吧......

posted @ 2015-02-09 18:45  idy002  阅读(296)  评论(0编辑  收藏  举报