Fork me on GitHub

树套树全家桶

树套树

概念

顾名思义,一个树套着另一个树(bushi)

eg. 维护一个线段树,并且对于每一个节用平衡树进行维护

树套树有很多种,外层的树可能有很多种,常见的是线段树与树状数组,内层的树最常见的是平衡树,也有可能是其他的

例题

T1

有以下的两种操作:

  • \(1 \ pos \ x\)\(pos\) 位置的数改成 \(x\)
  • \(2\ l\ r\ x\) 查询 \(x\)\([l,r]\) 小于 \(x\) 的最大值
分析

求区间内的小于 \(x\) 的最大值很容易想到用 \(multiset\) 中的 \(bound\) 来维护,但是如果这个区间不固定,那就只能再套一层线段树来维护了,对于任意一个区间,用线段树来凑就好了,对于查询,将所覆盖的区间的 \(multiset\) 进行调用,时间复杂度:\(O(log_n^2)\),对于修改:将包含这个点的左右区间的 \(multiset\) 先删去原来的数,再插入新的数,时间复杂度一样。

代码
真的很难调.....
#include <bits/stdc++.h>

#define int long long

using namespace std;

const int N = 50010;
const int M = N << 2;
const int INF = 1e9;

int n, m;
struct Tree
{
    int l, r;
    multiset<int> s;
}tr[M];
int w[N];

void build(int u, int l, int r)
{
    tr[u] = {l, r};
    tr[u].s.insert(-INF);
    tr[u].s.insert(INF);
    for(int i = l ; i <= r ; i ++ ) tr[u].s.insert(w[i]);
    if(l == r) return;
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
}

void change(int u, int p, int x)
{
    tr[u].s.erase(tr[u].s.find(w[p]));
    tr[u].s.insert(x);
    if(tr[u].l == tr[u].r) return;
    int mid = tr[u].l + tr[u].r >> 1;
    if(p <= mid) change(u << 1, p, x);
    else change(u << 1 | 1, p, x);
}

int query(int u, int a, int b, int x)
{
    if(tr[u].l >= a && tr[u].r <= b)
    {
        auto it = tr[u].s.lower_bound(x);
        --it;
        return *it;
    }
    int mid = tr[u].l + tr[u].r >> 1, res = -INF;
    if(a <= mid) res = max(res, query(u << 1, a, b, x));
    if(b > mid) res = max(res, query(u << 1 | 1, a, b, x));
    return res;
}

signed main()
{
    cin >> n >> m;
    for(int i = 1 ; i <= n ; i ++ ) cin >> w[i];
    build (1, 1, n);
    while (m -- )
    {
        int op, a, b, x;
        cin >> op;
        if (op == 1)
        {
            cin >> a >> x;
            change(1, a, x);
            w[a] = x;
        }
        else
        {
            cin >> a >> b >> x;
            cout << query(1, a, b, x) << endl;
        }
    }
    return 0;
}

T2

  • \(1 \ l \ r \ k\) 查询 \(x\)\(l, r\) 中的排名
  • \(2 \ l \ r\ k\) 查询 \(l, r\) 中排名为 \(k\) 的值
  • \(3\ pos\ x\)\(pos\) 的位置上的数改为 \(x\)
  • \(4\ l \ r \ x\) 查询 \(x\)\(l, r\) 中的前驱
  • \(5\ l\ r\ x\) 查询 \(x\)\(l, r\) 中的后继
分析

后两个操作就很简单用线段树来套 \(multiset\) 就好了, 但是 还有前两个操作,就不能偷懒了 \(QWQ\), 就只能用平衡树了,(因为平衡树的本质就是 动态 去维护一个区间), 那怎么维护呢?对于排名:其实就是算有几个数比 \(x\) 小,把所包含的区间的平衡树调用出来,然后加起来就好了, 时间复杂度 : \(O(log_n^2)\),对于第 \(k\) 小数:我们是没有办法照猫画虎, 不能将区间先划分出来,然后;累加在一起我们只能用 二分答案, 用上第一问的操作,如果\(mid\)\(x\) 小就往大的去二分,否则就往小的去二分,时间复杂度:\(O(log_n^3)\) (学过权值线段树套线段树的别叫!),至于哪种平衡树? \(treap\) 行,\(splay\) 行, \(fhq-treap\) 也行....,剩下的就是普通操作了。

代码
又臭又长!!!
我吐啦!!!!!!!!
#include <bits/stdc++.h>

#define int long long

using namespace std;

const int N = 2e6 + 10;
const int INF = 1e9;

int n, m;
struct Node
{
    int s[2], p, v;
    int size;
    void init(int _v, int _p)
    {
        v = _v, p = _p;
        size = 1;
    }
}tr[N];
int L[N], R[N], T[N], idx;
int w[N];

void pushup(int x)
{
    tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + 1;
}

void rotate(int x)
{
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;
    tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
    tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
    tr[x].s[k ^ 1] = y, tr[y].p = x;
    pushup(y), pushup(x);
}

void splay(int& root, int x, int k)
{
    while(tr[x].p != k)
    {
        int y = tr[x].p, z = tr[y].p;
        if(z != k)
            if((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
            else rotate(y);
        rotate(x);
    }
    if(k == 0) root = x;
}

void insert(int& root, int v)
{
    int u = root, p = 0;
    while(u) p = u, u = tr[u].s[v > tr[u].v];
    u = ++ idx;
    if(p) tr[p].s[v > tr[p].v] = u;
    tr[u].init(v, p);
    splay(root, u, 0);
}

int get_k(int root, int v)
{
    int u = root, res = 0;
    while(u)
    {
        if(tr[u].v < v) res += tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
        else u = tr[u].s[0];
    }
    return res;
}

void update(int& root, int x, int y)
{
    int u = root;
    while(u)
    {
        if(tr[u].v == x) break;
        if(tr[u].v < x) u = tr[u].s[1];
        else u = tr[u].s[0];
    }
    splay(root, u, 0);
    int l = tr[u].s[0], r = tr[u].s[1];
    while(tr[l].s[1]) l = tr[l].s[1];
    while(tr[r].s[0]) r = tr[r].s[0];
    splay(root, l, 0), splay(root, r, l);
    tr[r].s[0] = 0;
    pushup(r), pushup(l);
    insert(root, y);
}

void build(int u, int l, int r)
{
    L[u] = l, R[u] = r;
    insert(T[u], -INF);
    insert(T[u], INF);
    for(int i = l; i <= r; i ++ ) insert(T[u], w[i]);
    if(l == r) return;
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
}

int query(int u, int a, int b, int x)
{
    if(L[u] >= a && R[u] <= b) return get_k(T[u], x) - 1;
    int mid = L[u] + R[u] >> 1, res = 0;
    if(a <= mid) res += query(u << 1, a, b, x);
    if(b > mid) res += query(u << 1 | 1, a, b, x);
    return res;
}

void change(int u, int p, int x)
{
    update(T[u], w[p], x);
    if(L[u] == R[u]) return;
    int mid = L[u] + R[u] >> 1;
    if(p <= mid) change(u << 1, p, x);
    else change(u << 1 | 1, p, x);
}

int get_pre(int root, int v)
{
    int u = root, res = -INF;
    while(u)
    {
        if(tr[u].v < v) res = max(res, tr[u].v), u = tr[u].s[1];
        else u = tr[u].s[0];
    }
    return res;
}

int get_suc(int root, int v)
{
    int u = root, res = INF;
    while(u)
    {
        if(tr[u].v > v) res = min(res, tr[u].v), u = tr[u].s[0];
        else u = tr[u].s[1];
    }
    return res;
}

int query_pre(int u, int a, int b, int x)
{
    if(L[u] >= a && R[u] <= b) return get_pre(T[u], x);
    int mid = L[u] + R[u] >> 1, res = -INF;
    if(a <= mid) res = max(res, query_pre(u << 1, a, b, x));
    if(b > mid) res = max(res, query_pre(u << 1 | 1, a, b, x));
    return res;
}

int query_suc(int u, int a, int b, int x)
{
    if(L[u] >= a && R[u] <= b) return get_suc(T[u], x);
    int mid = L[u] + R[u] >> 1, res = INF;
    if(a <= mid) res = min(res, query_suc(u << 1, a, b, x));
    if(b > mid) res = min(res, query_suc(u << 1 | 1, a, b, x));
    return res;
}

signed main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i ++ ) cin >> w[i];
    build(1, 1, n);

    while(m -- )
    {
        int op, a, b, x;
        cin >> op;
        if (op == 1)
        {
            cin >> a >> b >> x;
            cout << query(1, a, b, x) + 1 << endl;
        }
        else if (op == 2)
        {
            cin >> a >> b >> x;
            int l = 0, r = 1e8;
            while (l < r)
            {
                int mid = l + r + 1 >> 1;
                if(query(1, a, b, mid) + 1 <= x) l = mid;
                else r = mid - 1;
            }
            cout << r << endl;
        }
        else if (op == 3)
        {
            cin >> a >> x;
            change(1, a, x);
            w[a] = x;
        }
        else if (op == 4)
        {
            cin >> a >> b >> x;
            cout << query_pre(1, a, b, x) << endl;
        }
        else
        {
            cin >> a >> b >> x;
            cout << query_suc(1, a, b, x) << endl;
        }
    }

    return 0;
}

T3

\(1\ a\ b\ c\): 将 \(a\)\(b\) 中的每个位置都加长一个数 \(c\)

\(2\ a\ b\ c\): 询问 \(a\)\(b\) 位置中的第 \(k\) 大数

请注意,这个位置上可以放很多的数

分析

用线段树套平衡树好像不太好做的样子~主要是线段树套平衡树的时间复杂度是 $O(nlog_n^3),时间太慢,我们就做一个 权值线段树(又称主席树)套线段树

普通线段树是以下标为端点,维护下标,对于权值线段树,我们以数值为端点,我们就维护下标,但怎么维护呢?答案是线段树bushi

对于加入,因为一个相同的权值是在权值线段树上的一个点,我们就只需要修改 $O(log_n) $ 个普通的线段树,对于每个普通线段树,就是将这段都加一,其实就是区间修改,用个懒标记就好了!时间复杂度 \(O(log_n^2)\)

再考虑查询第 \(k\) 大数,考虑在线段树上二分,因为是第 \(k\) 大数,所以先考虑大的那一边,那怎么判断下标在 \([a, b]\) 内,权值在 \([l, r]\) 内的数的个数呢?其实这个个数就是这个权值线段数上 \([l,r]\) 这个区间上的普通线段树的 \([a,b]\) 段的数之和,即区间求和,直接做就好了, 时间复杂度 \(O(log_n^2)\)

这样子,你的代码时间复杂度就可以吞掉一个 \(O(log_n)\), 成为 \(nlog_n^2\) 的优秀代码,但是,打起来要吐血呀!!!!

当你做完了这些,结果内存一算,哎呀,爆了\((T(n \times n)\), 能不爆吗?(bushi , 所以我们还要 动态开点, 开不开心~~(/(ㄒoㄒ)/

代码
#include <bits/stdc++.h>

#define int long long

using namespace std;

const int N = 50010;
const int P = N * 17 * 17;
const int M = N * 4;

int n, m;
struct Tree
{
    int l, r;
    int sum, add;
}tr[P];
int L[M], R[M], T[M], idx;
struct Query
{
    int op, a, b, c;
}q[N];
vector<int> nums;

int get(int x)
{
    return lower_bound(nums.begin(), nums.end(), x) - nums.begin();
}


void build(int u, int l, int r)
{
    L[u] = l;
    R[u] = r;
    T[u] = ++ idx;
    if(l == r) return;
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
}

// 因为只要区间加,区间求和,我们就可以标记永久化,但我不会告诉你^v^
int intersection(int a, int b, int c, int d)
{
    return min(b, d) - max(a, c) + 1;
}

void update(int u, int l, int r, int pl, int pr)
{
    tr[u].sum += intersection(l, r, pl, pr);
    if(l >= pl && r <= pr)
    {
        tr[u].add ++ ;
        return;
    }
    int mid = l + r >> 1;
    if(pl <= mid)
    {
        if(!tr[u].l) tr[u].l = ++ idx;
        update(tr[u].l, l, mid, pl, pr);
    }
    if(pr > mid)
    {
        if(!tr[u].r) tr[u].r = ++ idx;
        update(tr[u].r, mid + 1, r, pl, pr);
    }
}

void change(int u, int a, int b, int c)
{
    update(T[u], 1, n, a, b);
    if(L[u] == R[u]) return;
    int mid = L[u] + R[u] >> 1;
    if(c <= mid) change(u << 1, a, b, c);
    else change(u << 1 | 1, a, b, c);
}

int get_sum(int u, int l, int r, int pl, int pr, int add)
{
    if(l >= pl && r <= pr) return tr[u].sum + (r - l + 1) * add;
    int mid = l + r >> 1;
    int res = 0;
    add += tr[u].add;
    if(pl <= mid)
    {
        if(tr[u].l) res += get_sum(tr[u].l, l, mid, pl, pr, add);
        else res += intersection(l, mid, pl, pr) * add;
    }
    if(pr > mid)
    {
        if(tr[u].r) res += get_sum(tr[u].r, mid + 1, r, pl, pr, add);
        else res += intersection(mid + 1, r, pl, pr) * add;
    }
    return res;
}

int query(int u, int a, int b, int c)
{
    if(L[u] == R[u]) return R[u];
    int mid = L[u] + R[u] >> 1;
    int k = get_sum(T[u << 1 | 1], 1, n, a, b, 0);
    if(k >= c) return query(u << 1 | 1, a, b, c);
    return query(u << 1, a, b, c - k);
}

signed main()
{
    cin >> n >> m;
    for(int i = 0; i < m; i ++ )
    {
        cin >> q[i].op >> q[i].a >> q[i].b >> q[i].c;
        if(q[i].op == 1) nums.push_back(q[i].c);
    }
    sort(nums.begin(), nums.end());
    nums.erase(unique(nums.begin(), nums.end()), nums.end());

    build(1, 0, nums.size() - 1);

    for(int i = 0; i < m; i ++ )
    {
        int op = q[i].op, a = q[i].a, b = q[i].b, c = q[i].c;
        if(op == 1) change(1, a, b, get(c));
        else cout << nums[query(1, a, b, c)] << endl;
    }

    return 0;
}

题单

还没有看到呢...

如果有好题单,不介意的话可以给我,在线等,急,谢谢!

树套树真好玩 我~%?…,# *'☆&℃$︿★?

posted @ 2025-06-04 15:07  tony0530  阅读(42)  评论(0)    收藏  举报