PAT 1057. Stack (30)
http://www.patest.cn/contests/pat-a-practise/1057
用 Treap 求第 k 小的数（即栈中元素的中位数）：
#include<cstdio>
#include<stack>
#include<cstring>
#include<ctime>
#include<cstdlib>

using namespace std;

// PAT 1057 "Stack": a stack that also answers PeekMedian queries.
// The stack itself is a std::stack; a treap (randomized BST) mirrors the same
// multiset of keys so the k-th smallest key can be found by order statistics.

// Treap node: BST-ordered by `value`, max-heap-ordered by the random `rank`.
// Duplicate keys share one node and are counted in `cnt`.
struct Node {
    Node *ch[2];  // ch[0]: subtree of smaller keys, ch[1]: subtree of larger keys
    int rank;     // random priority; a parent's rank is >= its children's
    int value;    // the key
    int cnt;      // multiplicity of `value` in this node
    int size;     // total multiplicity stored in this subtree

    // Children start out NULL so a default-constructed node is a valid leaf.
    Node():rank(0), value(0), cnt(0), size(0) {
        ch[0] = NULL;
        ch[1] = NULL;
    }

    bool operator < (const Node& other) const {
        return rank < other.rank;
    }

    // Where x lies relative to this node: -1 = equal, 0 = go left, 1 = go right.
    int cmp(int x) const {
        if (x == value) {
            return -1;
        }
        return x < value ? 0 : 1;
    }

    // Recomputes `size`; the children's sizes must already be up to date.
    void maintain() {
        size = cnt;
        if (ch[0] != NULL) size += ch[0]->size;
        if (ch[1] != NULL) size += ch[1]->size;
    }
};

// Rotates the subtree rooted at o: child ch[d ^ 1] becomes the new root and
// the old root becomes its ch[d] child. Sizes are recomputed bottom-up.
void rotate(Node* &o, int d) {
    Node* k = o->ch[d ^ 1];
    o->ch[d ^ 1] = k->ch[d];
    k->ch[d] = o;
    o->maintain();
    k->maintain();
    o = k;
}

// Inserts one copy of x into the treap rooted at o.
void insert(Node* &o, int x) {
    if (o == NULL) {
        o = new Node();
        o->value = x;
        o->cnt = 1;
        o->size = 1;
        o->rank = rand();
#ifdef DEBUG
        printf("rank = %d\n", o->rank);
#endif
    } else {
        int d = o->cmp(x);
        if (d == -1) {
            o->cnt++;
        } else {
            insert(o->ch[d], x);
            // Restore the heap property: lift the child if it now outranks o.
            if (o->rank < o->ch[d]->rank) {
                rotate(o, d ^ 1);
            }
        }
    }
    o->maintain();
}

// Removes one copy of x from the treap rooted at o; does nothing if x is absent.
void remove(Node* &o, int x) {
    if (o == NULL) return;
    int d = o->cmp(x);
    if (d == -1) {
        Node* u = o;
        if (o->cnt > 1) {
            o->cnt--;
        } else if (o->ch[0] != NULL && o->ch[1] != NULL) {
            // Two children: rotate the higher-ranked child up, then keep
            // pushing the doomed node down until it has at most one child.
            int d2 = (o->ch[0]->rank > o->ch[1]->rank ? 1 : 0);
            rotate(o, d2);
            remove(o->ch[d2], x);
        } else {
            // At most one child: splice it into o's place and free the node.
            if (o->ch[0] == NULL) o = o->ch[1];
            else o = o->ch[0];
            delete u;
        }
    } else {
        remove(o->ch[d], x);
    }
    if (o != NULL) o->maintain();
}

// Debug helper: returns 1 if x is stored in the treap rooted at o, else 0.
int find(Node* o, int x) {
    while (o != NULL) {
        int d = o->cmp(x);
        if (d == -1) return 1;
        else o = o->ch[d];
    }
    return 0;
}

// Returns the k-th smallest key (1-based, counting duplicates) in the treap
// rooted at o, or 0 if o is empty or k is out of range.
int kth(Node* o, int k) {
    if (o == NULL || k <= 0 || k > o->size) return 0;
    int s = o->ch[0] == NULL ? 0 : o->ch[0]->size;
    if (k > s && k <= s + o->cnt) {
        return o->value;
    } else if (k <= s) {
        return kth(o->ch[0], k);
    } else return kth(o->ch[1], k - s - o->cnt);
}

int mymax(int a, int b) {
    return a > b ? a : b;
}

// Debug helper: prints the lengths of the leftmost and rightmost spines
// (not the true tree height) and returns the larger of the two.
int max_depth(Node* root) {
    if (root == NULL) {
        return 0;
    }
    int left = 0;
    for (Node* p = root; p != NULL; p = p->ch[0]) ++left;
    printf("left = %d ", left);
    int right = 0;
    for (Node* p = root; p != NULL; p = p->ch[1]) ++right;
    printf("right = %d\n", right);
    return mymax(left, right);
}

Node* root = NULL;  // treap holding the same multiset of keys as st
stack<int> st;      // the stack being simulated
int main() {
    root = NULL;
#ifdef DEBUG
    printf("rand_max = %d\n", RAND_MAX);
#endif
    // Read from a local "input" file only when it exists: freopen() closes
    // stdin before opening the new file, so an unconditional call leaves the
    // program with no input wherever that file is missing (e.g. the judge).
    if (FILE* probe = fopen("input", "r")) {
        fclose(probe);
        if (freopen("input", "r", stdin) == NULL) return 1;
    }
    srand(static_cast<unsigned>(time(NULL)));
    char op[20];
    int n;
    if (scanf("%d", &n) != 1) return 0;
    for (int i = 0; i < n; ++i) {
        // The width limit keeps an over-long token from overflowing op.
        if (scanf("%19s", op) != 1) break;
#ifdef DEBUG
        puts(op);
#endif
        if (strcmp(op, "Push") == 0) {
            int key;
            if (scanf("%d", &key) != 1) break;
            st.push(key);
            insert(root, key);
#ifdef DEBUG
            printf("find %d %d\n", key, find(root, key));
#endif
        } else if (strcmp(op, "Pop") == 0) {
            if (st.empty()) {
                printf("Invalid\n");
            } else {
                int v = st.top();
                st.pop();
                remove(root, v);
                printf("%d\n", v);
            }
        } else if (strcmp(op, "PeekMedian") == 0) {
            if (st.empty()) {
                printf("Invalid\n");
            } else {
                // The median is the (N/2)-th smallest for even N and the
                // ((N+1)/2)-th for odd N; (N+1)/2 in integer arithmetic covers both.
                int size = static_cast<int>(st.size());
                int k = (size + 1) >> 1;
#ifdef DEBUG
                printf("k = %d\n", k);
#endif
                printf("%d\n", kth(root, k));
            }
        } else printf("Invalid\n");
    }
#ifdef DEBUG
    max_depth(root);
#endif
    return 0;
}
树状数组（Binary Indexed Tree）的解法：用树状数组维护每个键值出现的次数，二分查找第一个前缀和不小于 k 的位置，即为第 k 小的数。
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<stack>

using namespace std;

// PAT 1057 "Stack" solved with a binary indexed (Fenwick) tree over key values.
// Keys are assumed to lie in [1, MAXN - 1]; the problem guarantees 1..100000.
const int MAXN = 100000 + 10;
int c[MAXN]; // Fenwick array: c[i] = number of stored keys in (i - lowbit(i), i]

// Lowest set bit of x (0 for x == 0).
int lowbit(int x) {
    return x & (-x);
}

// Number of stored keys that are <= x.
int sum(int x) {
    int sum = 0;
    while (x > 0) {
        sum += c[x];
        x -= lowbit(x);
    }
    return sum;
}

// Adds x (possibly negative) to the count of key p. Keys outside
// [1, MAXN - 1] are ignored. p <= 0 must be rejected explicitly, because
// lowbit(0) == 0 would make the loop spin forever.
void add(int p, int x) {
    if (p <= 0) return;
    while (p < MAXN) {
        c[p] += x;
        p += lowbit(p);
    }
}

// Returns the k-th smallest stored key: the leftmost position whose prefix
// count reaches k, found by binary search on the monotone prefix sums.
// Callers must ensure 1 <= k <= number of stored keys.
int kth(int k) {
    int low = 1, high = MAXN - 1;
    while (low < high) {
        int mid = low + ((high - low) >> 1);
        int s = sum(mid);
        if (s < k) {
            low = mid + 1;
        } else high = mid;
    }
    return low;
}

stack<int> st;  // the stack being simulated; c[] mirrors its contents
int main() {
    // Read from a local "input" file only when it exists: freopen() closes
    // stdin before opening the new file, so an unconditional call leaves the
    // program with no input wherever that file is missing (e.g. the judge).
    if (FILE* probe = fopen("input", "r")) {
        fclose(probe);
        if (freopen("input", "r", stdin) == NULL) return 1;
    }
    memset(c, 0, sizeof(c));
    char op[20];
    int n;
    if (scanf("%d", &n) != 1) return 0;
    for (int kase = 0; kase < n; ++kase) {
        // The width limit keeps an over-long token from overflowing op.
        if (scanf("%19s", op) != 1) break;
        if (strcmp(op, "Push") == 0) {
            int k;
            if (scanf("%d", &k) != 1) break;
            st.push(k);
            add(k, 1);
        } else if (strcmp(op, "Pop") == 0) {
            if (st.empty()) {
                printf("Invalid\n");
            } else {
                int k = st.top();
                st.pop();
                add(k, -1);
                printf("%d\n", k);
            }
        } else if (strcmp(op, "PeekMedian") == 0) {
            if (st.empty()) {
                printf("Invalid\n");
            } else {
                // (N+1)/2 in integer arithmetic is the problem's median index
                // for both odd and even N.
                int sz = static_cast<int>(st.size());
                printf("%d\n", kth((sz + 1) >> 1));
            }
        } else {
            printf("Invalid\n");
        }
    }
    return 0;
}

浙公网安备 33010602011771号