AcWing 2476. 树套树(线段树套splay)

题目链接

解题思路

  板子题,二分那里坑了,一定要注意一下。

const int maxn = 1e5+10;
int n, m, a[maxn];
struct Node {
    int s[2], v, sz, p;
    void init(int _v, int _p) {
        v = _v, p = _p;
        sz = 1;
    }
} tr[maxn*200];
int idx, rts[maxn*200];
inline void push_up(int u) {
    tr[u].sz = tr[tr[u].s[0]].sz+tr[tr[u].s[1]].sz+1;
}
inline void rotate(int x) {
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;
    tr[z].s[tr[z].s[1]==y] = x, tr[x].p = z;
    tr[y].s[k] = tr[x].s[k^1], tr[tr[x].s[k^1]].p = y;
    tr[x].s[k^1] = y, tr[y].p = x;
    push_up(y), push_up(x);
}
void splay(int &rt, int x, int k) {
    while(tr[x].p!=k) {
        int y = tr[x].p, z = tr[y].p;
        if (z!=k)
            if ((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
            else rotate(y);
        rotate(x);
    }
    if (!k) rt = x;
}
void insert(int &rt, int v) {
    int u = rt, p = 0;
    while(u) p = u, u = tr[u].s[v>tr[u].v];
    u = ++idx;
    if (p) tr[p].s[v>tr[p].v] = u;
    tr[u].init(v, p);
    splay(rt, u, 0);
}
void change(int &rt, int pre, int now) {
    int u = rt; 
    //因为要保证有序性,所以不能直接修改原来的值,应该先删除再插入
    while(u) {
        if (pre>tr[u].v) u = tr[u].s[1];
        else if (pre==tr[u].v) break;
        else u = tr[u].s[0];
    }
    splay(rt, u, 0);
    int l = tr[u].s[0], r = tr[u].s[1];
    //找到原来的数的前驱和后继,后缀的左儿子就是要删的数
    while(tr[l].s[1]) l = tr[l].s[1];
    while(tr[r].s[0]) r = tr[r].s[0];
    splay(rt, l, 0), splay(rt, r, l);
    tr[r].s[0] = 0;
    push_up(r), push_up(l);
    insert(rt, now);
}
int get_rank(int rt, int x) {
    int u = rt, sum = 0; //查询所有小于x的数的数量,+1即是排名
    while(u) {
        if (x>tr[u].v)  {
            sum += tr[tr[u].s[0]].sz+1;
            u = tr[u].s[1];
        }
        else u = tr[u].s[0];
    }
    return sum;
}
int get_pre(int rt, int x) {
    int u = rt, maxx = -1;
    while(u) {
        if (x>tr[u].v) {
            maxx = max(maxx, tr[u].v);
            u = tr[u].s[1];
        }
        else u = tr[u].s[0];
    }
    return maxx;
}
int get_suc(int rt, int x) {
    int u = rt, minn = INF;
    while(u) {
        if (x<tr[u].v) {
            minn = min(minn, tr[u].v);
            u = tr[u].s[0];
        }
        else u = tr[u].s[1];
    }
    return minn;
}
void build(int rt, int l, int r) {
    insert(rts[rt], -INF);
    insert(rts[rt], INF);
    for (int i = l; i<=r; ++i) insert(rts[rt], a[i]);
    if (l==r) return;
    int mid = (l+r)>>1;
    build(rt<<1, l, mid);
    build(rt<<1|1, mid+1, r);
}
void update(int rt, int l, int r, int pos, int p, int q) {
    change(rts[rt], p, q);
    if (l==r) return;
    int mid = (l+r)>>1;
    if (pos<=mid) update(rt<<1, l, mid, pos, p, q);
    else update(rt<<1|1, mid+1, r, pos, p, q);
}
int query_rank(int rt, int l, int r, int L, int R, int x) {
    if (l>=L && r<=R) return get_rank(rts[rt], x)-1;
    int mid = (l+r)>>1, sum = 0;
    if (L<=mid) sum += query_rank(rt<<1, l, mid, L, R, x);
    if (R>mid) sum += query_rank(rt<<1|1, mid+1, r, L, R, x);
    return sum;
}
int query_pre(int rt, int l, int r, int L, int R, int x) {
    if (l>=L && r<=R) return get_pre(rts[rt], x);
    int mid = (l+r)>>1, maxx = -1;
    if (L<=mid) maxx = max(maxx, query_pre(rt<<1, l, mid, L, R, x));
    if (R>mid) maxx = max(maxx, query_pre(rt<<1|1, mid+1, r, L, R, x));
    return maxx;
}
int query_suc(int rt, int l, int r, int L, int R, int x) {
    if (l>=L && r<=R) return get_suc(rts[rt], x);
    int mid = (l+r)>>1, minn = INF;
    if (L<=mid) minn = min(minn, query_suc(rt<<1, l, mid, L, R, x));
    if (R>mid) minn = min(minn, query_suc(rt<<1|1, mid+1, r, L, R, x));
    return minn;
}
int main() {
    IOS;
    cin >> n >> m;
    for (int i = 1; i<=n; ++i) cin >> a[i];
    build(1, 1, n);
    int op, l, r, x;
    while(m--) {
        cin >> op >> l >> r;
        if (op==1) {
            //查询排名(从小到大数是第几个,相同看第一个)
            cin >> x;
            cout << query_rank(1, 1, n, l, r, x)+1 << endl;
        }
        else if (op==2) {
            cin >> x; //查询排名的值
            int L = 1, R = 10;
            while(L<R) {
                int mid = (L+R+1)>>1;
                if (query_rank(1, 1, n, l, r, mid)+1<=x) L = mid;
                else R = mid-1;
            }
            /*
                下面的二分是错误的,比如
                    9 3
                    4 2 2 1 9 4 0 1 1
                    2 1 4 3
                    3 4 10
                    2 1 4 3
                这组数据,第一个询问二分到3的时候,严格小于3的数有两个,算上3等于3个,但是实际上没有3这个数,
                第二个询问也类似,而用上面的二分就不会有问题。
                while(L<R) {
                    cout << L << ' ' << R << endl;
                    int mid = (L+R)>>1;
                    if (query_rank(1, 1, n, l, r, mid)+1>=x) R = mid;
                    else L = mid+1;
                }
            */
            cout << L << endl;
        }
        else if (op==3) {
            //修改某个位置上的数
            update(1, 1, n, l, a[l], r);
            a[l] = r; //一定要跟新一下原来的数
        }
        else if (op==4) {
            cin >> x; //查询前驱
            cout << query_pre(1, 1, n, l, r, x) << endl;
        }
        else if (op==5) {
            cin >> x; //查询后继
            cout << query_suc(1, 1, n, l, r, x) << endl;
        }
    }
    return 0;
}
posted @ 2021-08-04 10:26  shuitiangong  阅读(46)  评论(0编辑  收藏  举报