PAT 1057. Stack (30)

http://www.patest.cn/contests/pat-a-practise/1057

treap 解法：用 treap 维护栈中所有元素，求第 k 小的数（PeekMedian 取 k = (N+1)/2）

  1 #include<cstdio>
  2 #include<stack>
  3 #include<cstring>
  4 #include<ctime>
  5 #include<cstdlib>
  6 
  7 using namespace std;
  8 
  9 struct Node {
 10     Node *ch[2];
 11     int rank;
 12     int value;
 13     int cnt;
 14     int size;
 15 
 16     Node():rank(0), value(0), cnt(0), size(0) {}
 17 
 18     bool operator < (const Node& other)  const {
 19         return rank < other.rank;
 20     }
 21 
 22     int cmp(int x) const {
 23         if (x == value) {
 24             return -1;
 25         }
 26         return x < value? 0 : 1;
 27     }
 28 
 29     void maintain() {
 30         size = cnt;
 31         if (ch[0] != NULL) size += ch[0]->size;
 32         if (ch[1] != NULL) size += ch[1]->size;
 33     }
 34 };
 35 
 36 
 37 void rotate(Node* &o, int d) {
 38     Node* k = o->ch[d ^ 1];
 39     o->ch[d ^ 1] = k->ch[d];
 40     k->ch[d] = o;
 41     o->maintain();
 42     k->maintain();
 43     o = k;
 44 }
 45 
 46 void insert(Node* &o, int x) {
 47     if (o == NULL) {
 48         o = new Node();
 49         o->ch[0] = NULL;
 50         o->ch[1] = NULL;
 51         o->value = x;
 52         o->cnt = 1;
 53         o->size = 1;
 54         o->rank = rand();
 55 #ifdef DEBUG
 56         printf("rank = %d\n", o->rank);
 57 #endif
 58     } else {
 59         int d = o->cmp(x);
 60         if (d == -1) {
 61             o->cnt++;
 62         } else {
 63             insert(o->ch[d], x);
 64             if (o->rank < o->ch[d]->rank) {
 65                 rotate(o, d ^ 1);
 66             }
 67         }
 68     }
 69     o->maintain();
 70 }
 71 
 72 void remove(Node* &o, int x) {
 73     int d = o->cmp(x);
 74     if (d == -1) {
 75         Node* u = o;
 76         if (o->cnt > 1) {
 77             o->cnt--;
 78         } else if (o->ch[0] != NULL && o->ch[1] != NULL) {
 79             int d2 = (o->ch[0]->rank > o->ch[1]->rank ? 1 : 0);
 80             rotate(o, d2);
 81             remove(o->ch[d2], x);
 82         } else {
 83             if (o->ch[0] == NULL) o = o->ch[1];
 84             else o = o->ch[0];
 85             delete u;
 86         }
 87     } else { 
 88         remove(o->ch[d], x);
 89     }
 90     if (o != NULL) o->maintain();
 91 }
 92 
 93 //for debug
 94 int find(Node* o, int x) {
 95     while(o != NULL) {
 96         int d = o->cmp(x);
 97         if (d == -1) return 1;
 98         else o = o->ch[d];
 99     }
100     return 0;
101 }
102 
103 int kth(Node* o, int k) {
104     if (o == NULL || k <= 0 || k > o->size) return 0;
105     int s = o->ch[0] == NULL ? 0 : o->ch[0]->size;
106     if (k > s && k <= s + o->cnt) {
107         return o->value;
108     } else if (k <= s) {
109         return kth(o->ch[0], k);
110     } else return kth(o->ch[1], k - s - o->cnt);
111 }
112 
// Returns the larger of a and b.
int mymax(int a, int b) {
    if (b >= a) {
        return b;
    }
    return a;
}
116 
117 //for debug
118 int max_depth(Node* root) {
119     if (root == NULL) {
120         return 0;
121     }
122     Node *p = root;
123     int dep = 0;
124     while(NULL != p) {
125         p = p->ch[0];
126         ++dep;
127     }
128     printf("left = %d    ", dep);
129     p = root;
130     dep = 0;
131     while(NULL != p) {
132         p = p->ch[1];
133         ++dep;
134     }
135     printf("right = %d\n", dep);
136 }
137 
138 Node* root = NULL;
139 stack<int> st;
140 int main() {
141     root = NULL;
142 #ifdef DEBUG
143     printf("rand_max = %d\n", RAND_MAX);
144 #endif
145     freopen("input", "r",stdin);
146     srand(time(NULL));
147     char op[20];
148     int n;
149     scanf("%d", &n);
150     for (int i = 0; i < n; ++i) {
151         scanf("%s", op);
152 #ifdef DEBUG
153         puts(op);
154 #endif
155         if (strcmp(op, "Push") == 0) {
156             int key;
157             scanf("%d", &key);
158             st.push(key);
159             insert(root, key);
160 #ifdef DEBUG
161             printf("find %d  %d\n", key, find(root, key));
162 #endif
163         } else if (strcmp(op, "Pop") == 0) {
164             if (st.empty()) {
165                 printf("Invalid\n");
166             } else {
167                 int v = st.top();
168                 st.pop();
169                 remove(root, v);
170                 printf("%d\n", v);
171             }
172         } else if (strcmp(op, "PeekMedian") == 0) {
173             if (st.empty()) {
174                 printf("Invalid\n");
175             } else {
176                 int size = st.size();
177                 int k;
178                 if (size & 1) {
179                     k = (size + 1) >> 1;
180                 } else k = size >> 1;
181 #ifdef DEBUG
182                 printf("k = %d\n", k);
183 #endif
184                 printf("%d\n", kth(root, k));
185             }
186         } else printf("Invalid\n");
187     }
188 #ifdef DEBUG
189     max_depth(root);
190 #endif
191     return 0;
192 }

 

 

树状数组的解法:

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<cstdlib>
 4 #include<stack>
 5 
 6 using namespace std;
 7 
 8 const int MAXN = 100000 + 10;
 9 int c[MAXN]; //binary indexed array
10 
// Returns the value of the lowest set bit of x (e.g. lowbit(12) == 4),
// or 0 when x == 0.
int lowbit(int x) {
    const int negated = -x;  // two's complement: only the lowest set bit survives the AND
    return negated & x;
}
14 
15 int sum(int x) {
16     int sum = 0;
17     while(x > 0) {
18         sum += c[x];
19         x -= lowbit(x);
20     }
21     return sum;
22 }
23 
24 void add(int p, int x) {
25     while(p < MAXN) {
26         c[p] += x;
27         p += lowbit(p);
28     }
29 }
30 
31 int kth(int k) { //find the kth number by binary search, remember find the leftmost 
32     int low = 1, high = MAXN - 1;
33     while(low < high) {
34         int mid = low + ((high - low) >> 1);
35         int s = sum(mid);
36         if(s < k) {
37             low = mid + 1;
38         } else high = mid;
39     }
40     return low;
41 }
42 
43 stack<int> st;
44 int main() {
45     freopen("input", "r", stdin);
46     memset(c, 0, sizeof(c));
47     char op[20];
48     int n;
49     scanf("%d", &n);
50     for (int kase = 0; kase < n; ++kase) {
51         scanf("%s", op);
52         if (op[1] == 'u') {
53             int k;
54             scanf("%d", &k);
55             st.push(k);
56             add(k, 1);
57         } else if (op[1] == 'o') {
58             if (st.empty()) {
59                 printf("Invalid\n");
60             } else {
61                 int k = st.top();
62                 st.pop();
63                 add(k, -1);
64                 printf("%d\n", k);
65             }
66         } else {
67             if (st.empty()) {
68                 printf("Invalid\n");
69             } else {
70                 int sz = st.size();
71                 printf("%d\n", kth((sz + 1) >> 1));
72             }
73         }
74     }
75     return 0;
76 }

 

posted @ 2015-07-31 22:57  ACSeed  Views(254)  Comments(0)    收藏  举报