洛谷P3369 【模板】普通平衡树 题解 权值线段树写法

题目链接:https://www.luogu.com.cn/problem/P3369

参考:https://www.cnblogs.com/trsins/p/17970745 (思路是自己想的)

特别需要注意的点

因为区间内存在负数,所以 mid 的取法应该写成:

mid = (l + r) >> 1

(或者 mid = l + (r - l) / 2 也行)

千万不要写成 mid = (l + r) / 2

原因可以参考下方的程序:

#include <bits/stdc++.h>
using namespace std;

int main() {

    int l = -5, r = -4;

    int mid = (l + r) / 2;
    cout << mid << endl;    // 输出-4

    mid = (l + r) >> 1;
    cout << mid << endl;    // 输出-5

    mid = l + (r - l) / 2;
    cout << mid << endl;    // 输出 -5

    return 0;
}

示例程序

#include <bits/stdc++.h>
using namespace std;
const int maxm = 1e7 + 5, inf = 1e7;

int n, tr[maxm], ls[maxm], rs[maxm], rt = 1, idx = 1;

void push_up(int u) {
    tr[u] = tr[ls[u]] + tr[rs[u]];
}

void update(int p, int c, int l, int r, int &u) {
    if (!u)
        u = ++idx;
    if (l == r) {
        tr[u] = max(0, tr[u] + c);
        return;
    }
    int mid = (l + r) >> 1;
    (p <= mid) ? update(p, c, l, mid, ls[u]) : update(p, c, mid+1, r, rs[u]);
    push_up(u);
}

// 查询有多少个数比p小
int query(int p, int l, int r, int u) {
    if (!u || l >= p)
        return 0;
    if (r < p)
        return tr[u];
    int mid = (l + r) >> 1;
    return query(p, l, mid, ls[u]) + query(p, mid+1, r, rs[u]);
}

// 查询排第k的数
int rnk(int k, int l, int r, int u) {
    if (l == r)
        return l;
    int mid = (l + r) >> 1;
    int cnt = tr[ls[u]];
    if (cnt >= k)
        return rnk(k, l, mid, ls[u]);
    return rnk(k-cnt, mid+1, r, rs[u]);
}

int main() {
    scanf("%d", &n);
    while (n--) {
        int op, x;
        scanf("%d%d", &op, &x);
        if (op == 1)
            update(x, 1, -inf, inf, rt);
        else if (op == 2)
            update(x, -1, -inf, inf, rt);
        else if (op == 3) {
            int cnt = query(x, -inf, inf, rt);
            printf("%d\n", cnt + 1);
        }
        else if (op == 4) {
            int rk = rnk(x, -inf, inf, rt);
            printf("%d\n", rk);
        }
        else if (op == 5) {
            int cnt = query(x, -inf, inf, rt);
            int ans = rnk(cnt, -inf, inf, rt);
            printf("%d\n", ans);
        }
        else {
            int cnt = query(x+1, -inf, inf, rt) + 1;
            int ans = rnk(cnt, -inf, inf, rt);
            printf("%d\n", ans);
        }
    }
    return 0;
}
posted @ 2025-03-07 21:36  quanjun  阅读(30)  评论(0)    收藏  举报