【POJ2104】K-th Number

【POJ2104】K-th Number

题面

virtual judge

题解

其实就是一道主席树\(sb\)

但是为了学习整体二分的需要就用整体二分写了。。。

所以主要利用此题讲一下整体二分到底是个啥(以下部分参考李煜东《算法竞赛进阶指南》):

两个例子

\(Eg1\)

给定一个正整数序列\(A\)及固定的整数\(S\),执行\(M\)此操作

每次查询\(l\)~\(r\)间不大于\(S\)的数或将\(A[x]\)改为\(y\)

很简单吧。。。

用树状数组维护一下就好了吧。。。

\(Eg2\)

给定一个正整数序列,求此序列的第\(K\)小数是多少。

看到这里也许你觉得我是个傻逼。。。直接排一遍序就好了啊

但是,为了引入整体二分使问题复杂化,我们采用第二种方法:

二分答案,设当前二分值为\(mid\),统计有多少个数\(\leq mid\),记为\(cnt\)

1.若\(K \leq cnt\),则说明K小数值一定\(\in\)\([l, mid]\),可在左半区间继续二分

2.若\(K > cnt\),则最小数一定\(\in\)\([mid+1, r]\),等价于在值域\([mid+1,r]\)下寻找\(K-cnt\)小的数

复杂度\(N\) \(logSIZE\)

回到\(POJ2104\),要求\(M\)个形如“求序列\(A\)\(l\)\(r\)个数中第\(k\)小的数”,这样做\(M\)次显然是不行的

而这样做\(M\)次中会有大量冗余状态,于是就有了---整体二分

整体二分

对于此题,

我们套用\(Eg2\)的做法

尝试在序列\(A\)中值域\([MINA,MAXA]\)二分答案\(mid\)

记区间\(l_i\)\(r_i\)中小于等于\(mid\)的数有\(c_i\)

然后将这些询问分类:

1.若\(k_i\) \(\leq\) \(c_i\),则说明第\(i\)个询问的答案在\([MINA,mid]\)

2.若\(k_i>c_i\),则说明第\(i\)个询问的答案在\([mid+1,MAXA]\)中,且等价于在值域\([mid+1,MAXA]\)中查询第\(k_i-c_i\)小的数

然后分别把上面两类分为子序列\(LA\)\(RA\),分开处理即可

对于统计\(c_i\)可以利用\(Eg1\)中的树状数组维护

具体实现还是看代码吧,感觉越讲越懵啊。。。

代码

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;

inline int gi() {
    register int data = 0, w = 1;
    register char ch = 0;
    while (ch != '-' && (ch > '9' || ch < '0')) ch = getchar();
    if (ch == '-') w = -1 , ch = getchar();
    while (ch >= '0' && ch <= '9') data = data * 10 + (ch ^ 48), ch = getchar();
    return w * data;
}
const int MAX_N = 200005, INF = 1e9; 
struct rec {int op, x, y, z; } q[MAX_N << 1], lq[MAX_N << 1], rq[MAX_N << 1]; 
int N, M, tot, c[MAX_N], ans[MAX_N]; 
inline int lb(int x) { return x & -x; } 
void add(int x, int v) { while (x <= N) c[x] += v, x += lb(x); } 
int sum(int x) { int res = 0; while (x > 0) res += c[x], x -= lb(x); return res; } 
int X[MAX_N], cnt; 
void Div(int lval, int rval, int st, int ed) { 
    if (st > ed) return ; 
    if (lval == rval) {  
        for (int i = st; i <= ed; i++) 
            if (q[i].op > 0) ans[q[i].op] = lval; 
        return ; 
    } 
    int mid = (lval + rval) >> 1; 
    int lt = 0, rt = 0; 
    for (int i = st; i <= ed; i++) { 
    	if (q[i].op == 0) { 
    	    if (q[i].y <= mid) add(q[i].x, 1), lq[++lt] = q[i]; 
    	    else rq[++rt] = q[i]; 
        } else { 
            int res = sum(q[i].y) - sum(q[i].x - 1); 
            if (res >= q[i].z) lq[++lt] = q[i]; 
            else q[i].z -= res, rq[++rt] = q[i]; 
        } 
    } 
    for (int i = st; i <= ed; i++) { 
        if (q[i].op == 0 && q[i].y <= mid) add(q[i].x, -1); 
    } 
    for (int i = 1; i <= lt; i++) q[st + i - 1] = lq[i]; 
    for (int i = 1; i <= rt; i++) q[st + lt + i - 1] = rq[i]; 
    Div(lval, mid, st, st + lt - 1); 
    Div(mid + 1, rval, st + lt, ed); 
} 
int main () { 
    N = gi(), M = gi(); 
    for (int i = 1; i <= N; i++) { 
        int v = gi(); 
        q[++tot].op = 0, q[tot].x = i, X[++cnt] = q[tot].y = v; 
    } 
    sort(&X[1], &X[cnt + 1]); cnt = unique(&X[1], &X[cnt + 1]) - X - 1; 
    for (int i = 1; i <= N; i++) q[i].y = lower_bound(&X[1], &X[cnt + 1], q[i].y) - X; 
    for (int i = 1; i <= M; i++) { 
        q[++tot].op = i, q[tot].x = gi(), q[tot].y = gi(), q[tot].z = gi(); 
    } 
    Div(1, N, 1, tot); 
    for (int i = 1; i <= M; i++) printf("%d\n", X[ans[i]]); 
    return 0; 
} 

另附主席树代码

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;

inline int gi() {
    register int data = 0, w = 1;
    register char ch = 0;
    while (ch != '-' && (ch > '9' || ch < '0')) ch = getchar();
    if (ch == '-') w = -1 , ch = getchar();
    while (ch >= '0' && ch <= '9') data = data * 10 + (ch ^ 48), ch = getchar();
    return w * data;
}
#define MAX_N 200005
struct Node {
    int ls, rs, val; 
} t[MAX_N << 5];
int cnt = 0, rt[MAX_N << 5];
void build(int &o, int l, int r) {
    o = ++cnt;
    if (l == r) return ;
    int mid = (l + r) >> 1; 
    build(t[o].ls, l, mid); 
    build(t[o].rs, mid + 1, r); 
} 
void insert(int &o, int pre, int l, int r, int x) { 
    o = ++cnt;
    t[o].ls = t[pre].ls; t[o].rs = t[pre].rs; t[o].val = t[pre].val; 
    t[o].val++; 
    if (l == r) return ; 
    int mid = (l + r) >> 1; 
    if (x <= mid) insert(t[o].ls, t[pre].ls, l, mid, x);
    else insert(t[o].rs, t[pre].rs, mid + 1, r, x); 
}
int query(int u, int v, int l, int r, int k) {
    if (l == r) return l;
    int sz = t[t[u].ls].val - t[t[v].ls].val;
    int mid = (l + r) >> 1; 
    if (sz < k) return query(t[u].rs, t[v].rs, mid + 1, r, k - sz);
    else return query(t[u].ls, t[v].ls, l, mid, k); 
}
int N, M, a[MAX_N];
int X[MAX_N]; 
int main () {
    N = gi(), M = gi();
    for (int i = 1; i <= N; i++) X[i] = a[i] = gi();
    sort(&X[1], &X[N + 1]);
    int size = unique(&X[1], &X[N + 1]) - X - 1; 
    build(rt[0], 1, N); 
    for (int i = 1; i <= N; i++) {
        int x = lower_bound(&X[1], &X[size + 1], a[i]) - X;
        insert(rt[i], rt[i - 1], 1, N, x); 
    }
    while (M--) {
        int l = gi(), r = gi(), v = gi();
        printf("%d\n", X[query(rt[r], rt[l - 1], 1, N, v)]); 
    }
    return 0; 
} 

posted @ 2018-12-14 21:54  heyujun  阅读(276)  评论(0编辑  收藏  举报