题解:洛谷 P3380 【模板】树套树

【题目来源】

洛谷:P3380 【模板】树套树 - 洛谷 (luogu.com.cn)

【题目描述】

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:

  1. 查询 \(k\) 在区间内的排名;
  2. 查询区间内排名为 \(k\) 的值;
  3. 修改某一位置上的数值;
  4. 查询 \(k\) 在区间内的前驱(前驱定义为严格小于 \(x\),且最大的数,若不存在输出 -2147483647);
  5. 查询 \(k\) 在区间内的后继(后继定义为严格大于 \(x\),且最小的数,若不存在输出 2147483647)。

对于一组元素,一个数的排名被定义为严格比它小的元素个数加一,而排名为 \(k\) 的数被定义为“将元素从小到大排序后排在第 \(k\) 位的元素值”。

【输入】

第一行两个数 \(n,m\),表示长度为 \(n\) 的有序序列和 \(m\) 个操作。

第二行有 \(n\) 个数,表示有序序列。

下面有 \(m\) 行,\(opt\) 表示操作标号。

\(opt=1\),则为操作 \(1\),之后有三个数 \(l~r~k\),表示查询 \(k\) 在区间 \([l,r]\) 的排名。

\(opt=2\),则为操作 \(2\),之后有三个数 \(l~r~k\),表示查询区间 \([l,r]\) 内排名为 \(k\) 的数。

\(opt=3\),则为操作 \(3\),之后有两个数 \(pos~k\),表示将 \(pos\) 位置的数修改为 \(k\)

\(opt=4\),则为操作 \(4\),之后有三个数 \(l~r~k\),表示查询区间 \([l,r]\)\(k\) 的前驱。

\(opt=5\),则为操作 \(5\),之后有三个数 \(l~r~k\),表示查询区间 \([l,r]\)\(k\) 的后继。

【输出】

对于操作 \(1,2,4,5\),各输出一行,表示查询结果。

【输入样例】

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

【输出样例】

2
4
3
4
9

【算法标签】

《洛谷 P3380 树套树》 #线段树# #平衡树# #树套树#

【代码详解】

#include <bits/stdc++.h>
using namespace std;

#define N 50005
#define INF 2147483647
#define ls(x) tr[x].s[0]  // 左儿子
#define rs(x) tr[x].s[1]  // 右儿子
#define lc u<<1
#define rc u<<1|1

int root[N * 4];  // 记录每个线段树节点对应的splay树的根

// Splay树节点结构
struct Node
{
    int s[2], p;  // 左右儿子,父节点
    int v, siz;  // 节点值,子树大小
    void init(int p1, int v1)
    {
        p = p1;
        v = v1;
        siz = 1;
    }
}tr[N * 40];  // 每个线段树节点对应一个Splay树

int n, m, w[N], idx;  // n: 数组长度,m: 操作数,w: 原始数组,idx: 节点计数器

// 更新节点信息
inline void pushup(int x)
{
    tr[x].siz = tr[ls(x)].siz + tr[rs(x)].siz + 1;
}

// Splay旋转操作
inline void rotate(int x)
{
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;
    tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
    tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
    tr[x].s[k ^ 1] = y, tr[y].p = x;
    pushup(y), pushup(x);
}

// Splay操作,将x旋转到k的子节点
inline void splay(int &root, int x, int k)
{
    while (tr[x].p != k)
    {
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
        {
            if ((rs(y) == x) ^ (rs(z) == y))
            {
                rotate(x);
            }
            else
            {
                rotate(y);
            }
        }
        rotate(x);
    }
    if (!k)
    {
        root = x;
    }
}

// 向Splay树中插入值v
inline void insert(int &root, int v)
{
    int u = root, p = 0;
    while (u)
    {
        p = u, u = tr[u].s[v > tr[u].v];
    }
    u = ++idx;
    tr[p].s[v > tr[p].v] = u;
    tr[u].init(p, v);
    splay(root, u, 0);
}

// 构建线段树,每个节点维护一个Splay树
void build(int u, int l, int r)
{
    insert(root[u], -INF), insert(root[u], INF);  // 插入哨兵节点
    for (int i = l; i <= r; i++)
    {
        insert(root[u], w[i]);  // 插入区间内的所有值
    }
    if (l == r) return;
    int mid = l + r >> 1;
    build(lc, l, mid);
    build(rc, mid + 1, r);
}

// 在Splay树中查询小于v的数的个数
inline int getrank(int root, int v)
{
    int u = root, res = 0;
    while (u)
    {
        if (tr[u].v < v)
        {
            res += tr[ls(u)].siz + 1, u = rs(u);
        }
        else
        {
            u = ls(u);
        }
    }
    return res;
}

// 查询区间[l,r]中比v小的数的个数
int queryrank(int u, int l, int r, int x, int y, int v)
{
    if (x <= l && r <= y)
    {
        return getrank(root[u], v) - 1;  // 减去哨兵
    }
    int mid = l + r >> 1, res = 0;
    if (x <= mid)
    {
        res += queryrank(lc, l, mid, x, y, v);
    }
    if (y > mid)
    {
        res += queryrank(rc, mid + 1, r, x, y, v);
    }
    return res;
}

// 查询区间[l,r]中第k小的数
int queryval(int u, int x, int y, int k)
{
    int l = 0, r = 1e8, ans;  // 二分答案
    while (l <= r)
    {
        int mid = l + r >> 1;
        if (queryrank(1, 1, n, x, y, mid) + 1 <= k)
        {
            l = mid + 1, ans = mid;
        }
        else
        {
            r = mid - 1;
        }
    }
    return ans;
}

// 从Splay树中删除值v
inline void del(int &root, int v)
{
    int u = root;
    while (u)
    {
        if (tr[u].v == v) break;
        if (tr[u].v < v)
        {
            u = rs(u);
        }
        else
        {
            u = ls(u);
        }
    }
    splay(root, u, 0);
    int l = ls(u), r = rs(u);
    while (rs(l))
    {
        l = rs(l);
    }
    while (ls(r))
    {
        r = ls(r);
    }
    splay(root, l, 0);
    splay(root, r, l);
    ls(r) = 0;
    splay(root, r, 0);
}

// 修改位置pos的值为v
void change(int u, int l, int r, int pos, int v)
{
    del(root[u], w[pos]);  // 删除旧值
    insert(root[u], v);  // 插入新值
    if (l == r) return;
    int mid = l + r >> 1;
    if (pos <= mid)
    {
        change(lc, l, mid, pos, v);
    }
    else
    {
        change(rc, mid + 1, r, pos, v);
    }
}

// 在Splay树中查询小于v的最大值
inline int getpre(int root, int v)
{
    int u = root, res = -INF;
    while (u)
    {
        if (tr[u].v < v)
        {
            res = tr[u].v, u = rs(u);
        }
        else
        {
            u = ls(u);
        }
    }
    return res;
}

// 查询区间[l,r]中比v小的最大值
int querypre(int u, int l, int r, int x, int y, int v)
{
    if (x <= l && r <= y)
    {
        return getpre(root[u], v);
    }
    int mid = l + r >> 1, res = -INF;
    if (x <= mid)
    {
        res = max(res, querypre(lc, l, mid, x, y, v));
    }
    if (y > mid)
    {
        res = max(res, querypre(rc, mid + 1, r, x, y, v));
    }
    return res;
}

// 在Splay树中查询大于v的最小值
inline int getnxt(int root, int v)
{
    int u = root, res = INF;
    while (u)
    {
        if (tr[u].v > v)
        {
            res = tr[u].v, u = ls(u);
        }
        else
        {
            u = rs(u);
        }
    }
    return res;
}

// 查询区间[l,r]中比v大的最小值
int querynxt(int u, int l, int r, int x, int y, int v)
{
    if (x <= l && r <= y)
    {
        return getnxt(root[u], v);
    }
    int mid = l + r >> 1, res = INF;
    if (x <= mid)
    {
        res = min(res, querynxt(lc, l, mid, x, y, v));
    }
    if (y > mid)
    {
        res = min(res, querynxt(rc, mid + 1, r, x, y, v));
    }
    return res;
}

int main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
    {
        cin >> w[i];
    }
    build(1, 1, n);  // 建树
  
    for (int i = 1; i <= m; i++)
    {
        int op, l, r, k, pos;
        cin >> op;
        if (op == 1)  // 查询区间[l,r]中k的排名
        {
            cin >> l >> r >> k;
            cout << queryrank(1, 1, n, l, r, k) + 1 << endl;
        }
        else if (op == 2)  // 查询区间[l,r]中第k小的数
        {
            cin >> l >> r >> k;
            cout << queryval(1, l, r, k) << endl;
        }
        else if (op == 3)  // 修改位置pos的值为k
        {
            cin >> pos >> k;
            change(1, 1, n, pos, k);
            w[pos] = k;
        }
        else if (op == 4)  // 查询区间[l,r]中比k小的最大值
        {
            cin >> l >> r >> k;
            cout << querypre(1, 1, n, l, r, k) << endl;
        }
        else  // 查询区间[l,r]中比k大的最小值
        {
            cin >> l >> r >> k;
            cout << querynxt(1, 1, n, l, r, k) << endl;
        }
    }
    return 0;
}

【运行结果】

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
2
3 4 10
2 1 4 3
4
1 2 5 9
3
4 3 9 5
4
5 2 8 5
9
posted @ 2026-02-20 21:13  团爸讲算法  阅读(2)  评论(0)    收藏  举报