洛谷P3380 【模板】树套树

这里是重工业科研场。

众所周知,主席树利用了前缀和的思想,\(Tree[root[r]].sum - Tree[root[l-1]].sum\) 等价于前缀和数组中的 \(sum[r] - sum[l-1]\)

这道题难点是:我们不能(无法想到)怎样动态修改主席树。

但是我们学过可以动态修改前缀和,用树状数组就可以。

所以这里的写法是用树状数组维护主席树节点的值。

具体怎么做呢,我们将 \(Tree[root[i]].sum\) 的值域进行变换。

\[[1,i] \to [i-lowbit(i)+1,i] \]

这样就可以用树状数组的方法求解啦~

讲解一下代码细节:

int qnum(int l,int r,int k,const int cnt0,const int cnt1){
	if(l==r) return l;
	int mid = (l+r)>>1,x = 0;
	for(int i=1;i<=cnt0;i++) x += T[T[tmp[0][i]].lson].sum;
	for(int i=1;i<=cnt1;i++) x -= T[T[tmp[1][i]].lson].sum;	
	if(k<=x){
		for(int i=1;i<=cnt0;i++) tmp[0][i] = T[tmp[0][i]].lson;
		for(int i=1;i<=cnt1;i++) tmp[1][i] = T[tmp[1][i]].lson;
		return qnum(l,mid,k,cnt0,cnt1);
	}
	else{
		for(int i=1;i<=cnt0;i++) tmp[0][i] = T[tmp[0][i]].rson;
		for(int i=1;i<=cnt1;i++) tmp[1][i] = T[tmp[1][i]].rson;
		return qnum(mid+1,r,k-x,cnt0,cnt1);
	}
}
int BITqnum(int l,int r,int k){
	l--;
	int cnt0 = 0,cnt1 = 0;
	while(r){
		tmp[0][++cnt0] = root[r];
		r -= r&-r;
	}
	while(l){
		tmp[1][++cnt1] = root[l];
		l -= l&-l;
	}
	return qnum(1,siz,k,cnt0,cnt1);
}

\(\textbf{qnum}\) 函数,在主席树上进行由排名查询该数的操作。

其中 \(\textbf{cnt0,cnt1,tmp[0/1]}\) 是很重要的变量。

因为你实际上要在很多棵主席树上跑当前值域的统计,按照树状数组,有 \(\log n\) 个主席树。

所以你得把当前所有的主席树节点编号全部存下来。

(因为线段树的编号你只知道 \(root\) 剩下的只能通过 \(lson,rson\) 进行访问了。)

所以最开始我们将所有的 \(root[i]\) 全部存下来,之后就每次递归前节点 \(u \to Tree[u].lson\) 就行了。

剩下的看代码吧。

\(\huge \mathscr{Code}\)

#include<bits/stdc++.h>
using namespace std;
const int N = 5e4 + 100;
int n, m, num[N * 2], cpy[N], tot, root[N], tmp[2][N], siz;
struct node {
	int lson, rson, sum;
}T[N * 160];
void pushup(int u) {
	T[u].sum = T[T[u].lson].sum + T[T[u].rson].sum;
}
void update(int& u, int l, int r, int q, int v) {
	if (!u) u = ++tot;
	if (l == r) {
		T[u].sum += v;
		return;
	}
	int mid = (l + r) >> 1;
	if (q <= mid) update(T[u].lson, l, mid, q, v);
	else update(T[u].rson, mid + 1, r, q, v);
	pushup(u);
}
int qnum(int l, int r, int k, const int cnt0, const int cnt1) {
	if (l == r) return l;
	int mid = (l + r) >> 1, x = 0;
	for (int i = 1; i <= cnt0; i++) x += T[T[tmp[0][i]].lson].sum;
	for (int i = 1; i <= cnt1; i++) x -= T[T[tmp[1][i]].lson].sum;
	if (k <= x) {
		for (int i = 1; i <= cnt0; i++) tmp[0][i] = T[tmp[0][i]].lson;
		for (int i = 1; i <= cnt1; i++) tmp[1][i] = T[tmp[1][i]].lson;
		return qnum(l, mid, k, cnt0, cnt1);
	}
	else {
		for (int i = 1; i <= cnt0; i++) tmp[0][i] = T[tmp[0][i]].rson;
		for (int i = 1; i <= cnt1; i++) tmp[1][i] = T[tmp[1][i]].rson;
		return qnum(mid + 1, r, k - x, cnt0, cnt1);
	}
}
int qrank(int l, int r, int k, const int cnt0, const int cnt1) {
	if (l == r) return 0;
	int mid = (l + r) >> 1, x = 0;
	if (k <= mid) {
		for (int i = 1; i <= cnt0; i++) tmp[0][i] = T[tmp[0][i]].lson;
		for (int i = 1; i <= cnt1; i++) tmp[1][i] = T[tmp[1][i]].lson;
		return qrank(l, mid, k, cnt0, cnt1);
	}
	else {
		for (int i = 1; i <= cnt0; i++) x += T[T[tmp[0][i]].lson].sum, tmp[0][i] = T[tmp[0][i]].rson;
		for (int i = 1; i <= cnt1; i++) x -= T[T[tmp[1][i]].lson].sum, tmp[1][i] = T[tmp[1][i]].rson;
		return x + qrank(mid + 1, r, k, cnt0, cnt1);
	}
}
void BITupdate(int x, int op) {
	int rec = x;
	while (x <= n) {
		update(root[x], 1, siz, cpy[rec], op);
		x += x & -x;
	}
}
int BITqnum(int l, int r, int k) {
	l--;
	int cnt0 = 0, cnt1 = 0;
	while (r) {
		tmp[0][++cnt0] = root[r];
		r -= r & -r;
	}
	while (l) {
		tmp[1][++cnt1] = root[l];
		l -= l & -l;
	}
	return qnum(1, siz, k, cnt0, cnt1);
}
int BITqrank(int l, int r, int k) {
	l--;
	int cnt0 = 0, cnt1 = 0;
	while (r) {
		tmp[0][++cnt0] = root[r];
		r -= r & -r;
	}
	while (l) {
		tmp[1][++cnt1] = root[l];
		l -= l & -l;
	}
	return qrank(1, siz, k, cnt0, cnt1) + 1;
}
int BITpre(int l, int r, int k) {
	int rk = BITqrank(l, r, k) - 1;
	if (!rk) return 0;
	return BITqnum(l, r, rk);
}
int BITnxt(int l, int r, int k) {
	int rk = BITqrank(l, r, k + 1);
	if (rk == r - l + 2) return siz + 1;
	return BITqnum(l, r, rk);
}
struct ques {
	int op, x, y, z;
}dat[N];
signed main() {
	cin >> n >> m;
	for (int i = 1; i <= n; i++) {
		cin >> cpy[i];
		num[++siz] = cpy[i];
	}
	for (int i = 1; i <= m; i++) {
		cin >> dat[i].op >> dat[i].x >> dat[i].y;
		if (dat[i].op != 3) cin >> dat[i].z;
		else num[++siz] = dat[i].y;
		if (dat[i].op == 4 or dat[i].op == 5) num[++siz] = dat[i].z;
	}
	sort(num + 1, num + siz + 1);
	siz = unique(num + 1, num + siz + 1) - num - 1;
	for (int i = 1; i <= n; i++) {
		cpy[i] = lower_bound(num + 1, num + siz + 1, cpy[i]) - num;
		BITupdate(i, 1);
	}
	num[0] = -2147483647, num[siz + 1] = 2147483647;
	for (int i = 1; i <= m; i++) {
		if (dat[i].op == 1) {
			dat[i].z = lower_bound(num + 1, num + siz + 1, dat[i].z) - num;
			cout << BITqrank(dat[i].x, dat[i].y, dat[i].z) << '\n';
		}
		else if (dat[i].op == 2) {
			cout << num[BITqnum(dat[i].x, dat[i].y, dat[i].z)] << '\n';
		}
		else if (dat[i].op == 3) {
			BITupdate(dat[i].x, -1);
			cpy[dat[i].x] = lower_bound(num + 1, num + siz + 1, dat[i].y) - num;
			BITupdate(dat[i].x, 1);
		}
		else if (dat[i].op == 4) {
			dat[i].z = lower_bound(num + 1, num + siz + 1, dat[i].z) - num;
			cout << num[BITpre(dat[i].x, dat[i].y, dat[i].z)] << '\n';
		}
		else {
			dat[i].z = lower_bound(num + 1, num + siz + 1, dat[i].z) - num;
			cout << num[BITnxt(dat[i].x, dat[i].y, dat[i].z)] << '\n';
		}
	}
	return 0;
}

posted @ 2025-07-22 17:07  OrangeRED  阅读(39)  评论(0)    收藏  举报