树套树全家桶
树套树
概念
顾名思义,一个树套着另一个树(bushi)
eg. 维护一个线段树,并且对于每一个节用平衡树进行维护
树套树有很多种,外层的树可能有很多种,常见的是线段树与树状数组,内层的树最常见的是平衡树,也有可能是其他的
例题
T1
有以下的两种操作:
- \(1 \ pos \ x\) 将 \(pos\) 位置的数改成 \(x\)
- \(2\ l\ r\ x\) 查询 \(x\) 在 \([l,r]\) 小于 \(x\) 的最大值
分析
求区间内的小于 \(x\) 的最大值很容易想到用 \(multiset\) 中的 \(bound\) 来维护,但是如果这个区间不固定,那就只能再套一层线段树来维护了,对于任意一个区间,用线段树来凑就好了,对于查询,将所覆盖的区间的 \(multiset\) 进行调用,时间复杂度:\(O(log_n^2)\),对于修改:将包含这个点的左右区间的 \(multiset\) 先删去原来的数,再插入新的数,时间复杂度一样。
代码
真的很难调.....
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 50010;
const int M = N << 2;
const int INF = 1e9;
int n, m;
struct Tree
{
int l, r;
multiset<int> s;
}tr[M];
int w[N];
void build(int u, int l, int r)
{
tr[u] = {l, r};
tr[u].s.insert(-INF);
tr[u].s.insert(INF);
for(int i = l ; i <= r ; i ++ ) tr[u].s.insert(w[i]);
if(l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
void change(int u, int p, int x)
{
tr[u].s.erase(tr[u].s.find(w[p]));
tr[u].s.insert(x);
if(tr[u].l == tr[u].r) return;
int mid = tr[u].l + tr[u].r >> 1;
if(p <= mid) change(u << 1, p, x);
else change(u << 1 | 1, p, x);
}
int query(int u, int a, int b, int x)
{
if(tr[u].l >= a && tr[u].r <= b)
{
auto it = tr[u].s.lower_bound(x);
--it;
return *it;
}
int mid = tr[u].l + tr[u].r >> 1, res = -INF;
if(a <= mid) res = max(res, query(u << 1, a, b, x));
if(b > mid) res = max(res, query(u << 1 | 1, a, b, x));
return res;
}
signed main()
{
cin >> n >> m;
for(int i = 1 ; i <= n ; i ++ ) cin >> w[i];
build (1, 1, n);
while (m -- )
{
int op, a, b, x;
cin >> op;
if (op == 1)
{
cin >> a >> x;
change(1, a, x);
w[a] = x;
}
else
{
cin >> a >> b >> x;
cout << query(1, a, b, x) << endl;
}
}
return 0;
}
T2
- \(1 \ l \ r \ k\) 查询 \(x\) 在 \(l, r\) 中的排名
- \(2 \ l \ r\ k\) 查询 \(l, r\) 中排名为 \(k\) 的值
- \(3\ pos\ x\) 将 \(pos\) 的位置上的数改为 \(x\)
- \(4\ l \ r \ x\) 查询 \(x\) 在 \(l, r\) 中的前驱
- \(5\ l\ r\ x\) 查询 \(x\) 在 \(l, r\) 中的后继
分析
后两个操作就很简单用线段树来套 \(multiset\) 就好了, 但是 还有前两个操作,就不能偷懒了 \(QWQ\), 就只能用平衡树了,(因为平衡树的本质就是 动态 去维护一个区间), 那怎么维护呢?对于排名:其实就是算有几个数比 \(x\) 小,把所包含的区间的平衡树调用出来,然后加起来就好了, 时间复杂度 : \(O(log_n^2)\),对于第 \(k\) 小数:我们是没有办法照猫画虎, 不能将区间先划分出来,然后;累加在一起我们只能用 二分答案, 用上第一问的操作,如果\(mid\) 比 \(x\) 小就往大的去二分,否则就往小的去二分,时间复杂度:\(O(log_n^3)\) (学过权值线段树套线段树的别叫!),至于哪种平衡树? \(treap\) 行,\(splay\) 行, \(fhq-treap\) 也行....,剩下的就是普通操作了。
代码
又臭又长!!!
我吐啦!!!!!!!!
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 2e6 + 10;
const int INF = 1e9;
int n, m;
struct Node
{
int s[2], p, v;
int size;
void init(int _v, int _p)
{
v = _v, p = _p;
size = 1;
}
}tr[N];
int L[N], R[N], T[N], idx;
int w[N];
void pushup(int x)
{
tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + 1;
}
void rotate(int x)
{
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y, tr[y].p = x;
pushup(y), pushup(x);
}
void splay(int& root, int x, int k)
{
while(tr[x].p != k)
{
int y = tr[x].p, z = tr[y].p;
if(z != k)
if((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
else rotate(y);
rotate(x);
}
if(k == 0) root = x;
}
void insert(int& root, int v)
{
int u = root, p = 0;
while(u) p = u, u = tr[u].s[v > tr[u].v];
u = ++ idx;
if(p) tr[p].s[v > tr[p].v] = u;
tr[u].init(v, p);
splay(root, u, 0);
}
int get_k(int root, int v)
{
int u = root, res = 0;
while(u)
{
if(tr[u].v < v) res += tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
else u = tr[u].s[0];
}
return res;
}
void update(int& root, int x, int y)
{
int u = root;
while(u)
{
if(tr[u].v == x) break;
if(tr[u].v < x) u = tr[u].s[1];
else u = tr[u].s[0];
}
splay(root, u, 0);
int l = tr[u].s[0], r = tr[u].s[1];
while(tr[l].s[1]) l = tr[l].s[1];
while(tr[r].s[0]) r = tr[r].s[0];
splay(root, l, 0), splay(root, r, l);
tr[r].s[0] = 0;
pushup(r), pushup(l);
insert(root, y);
}
void build(int u, int l, int r)
{
L[u] = l, R[u] = r;
insert(T[u], -INF);
insert(T[u], INF);
for(int i = l; i <= r; i ++ ) insert(T[u], w[i]);
if(l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
int query(int u, int a, int b, int x)
{
if(L[u] >= a && R[u] <= b) return get_k(T[u], x) - 1;
int mid = L[u] + R[u] >> 1, res = 0;
if(a <= mid) res += query(u << 1, a, b, x);
if(b > mid) res += query(u << 1 | 1, a, b, x);
return res;
}
void change(int u, int p, int x)
{
update(T[u], w[p], x);
if(L[u] == R[u]) return;
int mid = L[u] + R[u] >> 1;
if(p <= mid) change(u << 1, p, x);
else change(u << 1 | 1, p, x);
}
int get_pre(int root, int v)
{
int u = root, res = -INF;
while(u)
{
if(tr[u].v < v) res = max(res, tr[u].v), u = tr[u].s[1];
else u = tr[u].s[0];
}
return res;
}
int get_suc(int root, int v)
{
int u = root, res = INF;
while(u)
{
if(tr[u].v > v) res = min(res, tr[u].v), u = tr[u].s[0];
else u = tr[u].s[1];
}
return res;
}
int query_pre(int u, int a, int b, int x)
{
if(L[u] >= a && R[u] <= b) return get_pre(T[u], x);
int mid = L[u] + R[u] >> 1, res = -INF;
if(a <= mid) res = max(res, query_pre(u << 1, a, b, x));
if(b > mid) res = max(res, query_pre(u << 1 | 1, a, b, x));
return res;
}
int query_suc(int u, int a, int b, int x)
{
if(L[u] >= a && R[u] <= b) return get_suc(T[u], x);
int mid = L[u] + R[u] >> 1, res = INF;
if(a <= mid) res = min(res, query_suc(u << 1, a, b, x));
if(b > mid) res = min(res, query_suc(u << 1 | 1, a, b, x));
return res;
}
signed main()
{
cin >> n >> m;
for (int i = 1; i <= n; i ++ ) cin >> w[i];
build(1, 1, n);
while(m -- )
{
int op, a, b, x;
cin >> op;
if (op == 1)
{
cin >> a >> b >> x;
cout << query(1, a, b, x) + 1 << endl;
}
else if (op == 2)
{
cin >> a >> b >> x;
int l = 0, r = 1e8;
while (l < r)
{
int mid = l + r + 1 >> 1;
if(query(1, a, b, mid) + 1 <= x) l = mid;
else r = mid - 1;
}
cout << r << endl;
}
else if (op == 3)
{
cin >> a >> x;
change(1, a, x);
w[a] = x;
}
else if (op == 4)
{
cin >> a >> b >> x;
cout << query_pre(1, a, b, x) << endl;
}
else
{
cin >> a >> b >> x;
cout << query_suc(1, a, b, x) << endl;
}
}
return 0;
}
T3
\(1\ a\ b\ c\): 将 \(a\)到 \(b\) 中的每个位置都加长一个数 \(c\)
\(2\ a\ b\ c\): 询问 \(a\) 到 \(b\) 位置中的第 \(k\) 大数
请注意,这个位置上可以放很多的数
分析
用线段树套平衡树好像不太好做的样子~主要是线段树套平衡树的时间复杂度是 $O(nlog_n^3),时间太慢,我们就做一个 权值线段树(又称主席树)套线段树!
普通线段树是以下标为端点,维护下标,对于权值线段树,我们以数值为端点,我们就维护下标,但怎么维护呢?答案是线段树bushi
对于加入,因为一个相同的权值是在权值线段树上的一个点,我们就只需要修改 $O(log_n) $ 个普通的线段树,对于每个普通线段树,就是将这段都加一,其实就是区间修改,用个懒标记就好了!时间复杂度 \(O(log_n^2)\)
再考虑查询第 \(k\) 大数,考虑在线段树上二分,因为是第 \(k\) 大数,所以先考虑大的那一边,那怎么判断下标在 \([a, b]\) 内,权值在 \([l, r]\) 内的数的个数呢?其实这个个数就是这个权值线段数上 \([l,r]\) 这个区间上的普通线段树的 \([a,b]\) 段的数之和,即区间求和,直接做就好了, 时间复杂度 \(O(log_n^2)\)
这样子,你的代码时间复杂度就可以吞掉一个 \(O(log_n)\), 成为 \(nlog_n^2\) 的优秀代码,但是,打起来要吐血呀!!!!
当你做完了这些,结果内存一算,哎呀,爆了\((T(n \times n)\), 能不爆吗?(bushi , 所以我们还要 动态开点, 开不开心~~(/(ㄒoㄒ)/)
代码
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 50010;
const int P = N * 17 * 17;
const int M = N * 4;
int n, m;
struct Tree
{
int l, r;
int sum, add;
}tr[P];
int L[M], R[M], T[M], idx;
struct Query
{
int op, a, b, c;
}q[N];
vector<int> nums;
int get(int x)
{
return lower_bound(nums.begin(), nums.end(), x) - nums.begin();
}
void build(int u, int l, int r)
{
L[u] = l;
R[u] = r;
T[u] = ++ idx;
if(l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
// 因为只要区间加,区间求和,我们就可以标记永久化,但我不会告诉你^v^
int intersection(int a, int b, int c, int d)
{
return min(b, d) - max(a, c) + 1;
}
void update(int u, int l, int r, int pl, int pr)
{
tr[u].sum += intersection(l, r, pl, pr);
if(l >= pl && r <= pr)
{
tr[u].add ++ ;
return;
}
int mid = l + r >> 1;
if(pl <= mid)
{
if(!tr[u].l) tr[u].l = ++ idx;
update(tr[u].l, l, mid, pl, pr);
}
if(pr > mid)
{
if(!tr[u].r) tr[u].r = ++ idx;
update(tr[u].r, mid + 1, r, pl, pr);
}
}
void change(int u, int a, int b, int c)
{
update(T[u], 1, n, a, b);
if(L[u] == R[u]) return;
int mid = L[u] + R[u] >> 1;
if(c <= mid) change(u << 1, a, b, c);
else change(u << 1 | 1, a, b, c);
}
int get_sum(int u, int l, int r, int pl, int pr, int add)
{
if(l >= pl && r <= pr) return tr[u].sum + (r - l + 1) * add;
int mid = l + r >> 1;
int res = 0;
add += tr[u].add;
if(pl <= mid)
{
if(tr[u].l) res += get_sum(tr[u].l, l, mid, pl, pr, add);
else res += intersection(l, mid, pl, pr) * add;
}
if(pr > mid)
{
if(tr[u].r) res += get_sum(tr[u].r, mid + 1, r, pl, pr, add);
else res += intersection(mid + 1, r, pl, pr) * add;
}
return res;
}
int query(int u, int a, int b, int c)
{
if(L[u] == R[u]) return R[u];
int mid = L[u] + R[u] >> 1;
int k = get_sum(T[u << 1 | 1], 1, n, a, b, 0);
if(k >= c) return query(u << 1 | 1, a, b, c);
return query(u << 1, a, b, c - k);
}
signed main()
{
cin >> n >> m;
for(int i = 0; i < m; i ++ )
{
cin >> q[i].op >> q[i].a >> q[i].b >> q[i].c;
if(q[i].op == 1) nums.push_back(q[i].c);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
build(1, 0, nums.size() - 1);
for(int i = 0; i < m; i ++ )
{
int op = q[i].op, a = q[i].a, b = q[i].b, c = q[i].c;
if(op == 1) change(1, a, b, get(c));
else cout << nums[query(1, a, b, c)] << endl;
}
return 0;
}
题单
还没有看到呢...
如果有好题单,不介意的话可以给我,在线等,急,谢谢!

浙公网安备 33010602011771号