可持久化 01-trie 简记

本文略过了 trie 和 可持久化的介绍,如果没学过请先自学。

在求给定一个值 \(k\) 与区间中某些值的异或最大值时,可以考虑使用在线的数据结构可持久化 01-trie 来维护。

01-trie

01-trie 本身是用以求异或最大值的数据结构。

考虑板子题:给定 \(n\) 个数,\(m\) 次询问 \(k\) 与某个数异或的最大值。

我们肯定不能直接拿 \(k\) 与每个数异或,考虑把 \(k\) 拆成二进制,对于每一位单独做。
从高位到低位贪心。如果 \(k\) 的这一位是 \(1\) 那么我们就要尽可能的让选到的 \(a_i\) 这一位为 \(0\);反之要让 \(a_i\) 这一位为 \(1\)。这样我们取到的 \(a_i\) 一定是使 \(k\oplus a_i\) 最大的值。
这样我们就需要一种可以保留所有二进制串信息且占用空间尽可能小的数据结构,而 trie 恰好符合这一点。
01-trie 就是字符集为 \(\{0, 1\}\) 的 trie。
我们把每个数转成二进制从高位开始加入到 01-trie 中,每次根据 \(k\) 二进制位是 0/1 来贪心地更新答案即可。

复杂度 \(O(n\log v+m\log v)\)\(v\) 是值域。

可持久化思想

如果从一个数变成一个区间与 \(k\) 的异或最大值,我们就要使用可持久化思想来升级 01-trie。

考虑升级版板子题:给定 \(n\) 个数,\(m\) 次询问 \(k\)\(a_l\sim a_r\) 异或后缀异或的最大值,或者插入一个数。
我们也不能对于每个异或后缀开一棵 01-trie 来统计,这样时间空间双双爆炸。考虑差分,转化成前缀信息来做。
\(s_i=\oplus_{j=1}^{i}\),询问时找 \(k\oplus s_n\oplus s_{i-1}\) 最大值即是 \(k\)\(a_i\) 后缀异或最大值。
用可持久化思想,每次插入一个数就在前一个版本基础上更新,把新的 \(s_n\) 插入到 01-trie 里面。
根据上面的转化,查询时就可以直接查 \(k\oplus s_n\)\(rt_l\sim rt_r\) 这些版本的 01-trie 的异或最大值。
与没有持久化的 01-trie 类似,考虑对于两个版本的 01-trie 该怎么判断这一位能不能取 \(1\)
对 01-trie 上每条边 0/1 额外维护一个 sz[] ,表示这条边 0/1 在历史版本中一共出现的次数,加入新节点时就在上一个版本对应边基础上 +1。查询时就可以作差求出需要的 0/1 边是否存在。

这样我们就用同样的单 \(\log\) 时间复杂度解决了这个问题,唯一区别是可持久化占用的空间从线性 \(O(n)\) 变成了单 \(\log\) \(O(n\log n)\)

代码实现

原题跳转:Luogu P4735 最大异或和

注意数组范围要开大 \(64\) 倍防止越界。

#include<bits/stdc++.h>
using namespace std;

const int maxn = 6e5 + 10;
int n, m, s[maxn << 1];
int tot, ch[maxn << 5][2], sz[maxn << 5], rt[maxn << 1];

void add(int u, int v, int t, int x) {
	if(t < 0) return;
	int i = (x >> t) & 1;
	ch[u][i] = ++tot, ch[u][i ^ 1] = ch[v][i ^ 1];
	sz[ch[u][i]] = sz[ch[v][i]] + 1;
	add(ch[u][i], ch[v][i], t - 1, x);
}
int ask(int u, int v, int t, int x) {
	if(t < 0) return 0;
	int i = (x >> t) & 1;
	if(sz[ch[u][i ^ 1]] - sz[ch[v][i ^ 1]]) {
		return (1 << t) + ask(ch[u][i ^ 1], ch[v][i ^ 1], t - 1, x);
	}
	else return ask(ch[u][i], ch[v][i], t - 1, x);
}

int main() {
	ios :: sync_with_stdio(false); cin.tie(0); cout.tie(0);
	
	cin >> n >> m;  
	rt[0] = ++tot, add(rt[0], 0, 25, 0);
	for(int i = 1; i <= n; i++) {
		int x; cin >> x;
		s[i] = s[i - 1] ^ x;
		rt[i] = ++tot, add(rt[i], rt[i - 1], 25, s[i]);
	}
	
	for(int i = 1; i <= m; i++) {
		char op; cin >> op;
		if(op == 'A') {
			int x; cin >> x; ++n;
			s[n] = s[n - 1] ^ x, rt[n] = ++tot; add(rt[n], rt[n - 1], 25, s[n]);
		}
		if(op == 'Q') {
			int l, r, x; cin >> l >> r >> x; l--, r--;
			if(l == 0) cout << ask(rt[r], 0, 25, x ^ s[n]) << endl;
			else cout << ask(rt[r], rt[l - 1], 25, x ^ s[n]) << endl;
		}
	}
	
	return 0;
} 

例题

Luogu P5795 异或运算

给定 \(x\)\(y\) 数组,询问给定两个范围 \(u,d\)\(l,r\)\(k\),求 \(x_u\sim x_d\) 分别异或 \(y_l\sim y_r\) 的第 \(k\) 大值。其中 \(x\) 长度不超过 \(10^3\)\(y\) 长度不超过 \(3\times 10^5\),询问不超过 \(5\times 10^2\)

注意这个题数据范围非常抽象。两维分开维护很困难,但是一维可以直接可持久化 01-trie。那么考虑对 \(x\) 暴力,对 \(y\) 用可持久化 01-trie 维护。每次询问,暴力统计 \(x\) 数组中的贡献。
具体地,把 \(x_i\) 的每一位单独拿出来,由于是第 \(k\) 大,统计这一位异或起来为 \(1\) 的个数,如果大于等于 \(k\),就都走二进制位不同的那一位并且统计这一位的贡献;小于 \(k\) 就走二进制位相同的那一项。所有贡献加起来即可。

代码实现
#include<bits/stdc++.h>
using namespace std;

const int maxn = 1e3 + 10, maxm = 3e5 + 10;
int n, m, q, x[maxn], y[maxm];

int tot, ch[maxm << 6][2], sz[maxm << 6], rt[maxm], posl[maxn], posr[maxn];
void add(int u, int v, int t, int k) {
	if(t < 0) return;
	int i = (k >> t) & 1;
	ch[u][i] = ++tot, ch[u][i ^ 1] = ch[v][i ^ 1];
	sz[ch[u][i]] = sz[ch[v][i]] + 1;
	return add(ch[u][i], ch[v][i], t - 1, k), void(0);
}
int ask(int xl, int xr, int yl, int yr, int k) {
	for(int j = xl; j <= xr; j++) posl[j] = rt[yl - 1], posr[j] = rt[yr];
	
	int res = 0;
	for(int t = 31; t >= 0; t--) {
		int cnt = 0;
		for(int j = xl; j <= xr; j++) {
			int i = (x[j] >> t) & 1;
			cnt += sz[ch[posr[j]][i ^ 1]] - sz[ch[posl[j]][i ^ 1]];
		}
		if(cnt >= k) {
			res |= (1 << t);
			for(int j = xl; j <= xr; j++) {
				int i = (x[j] >> t) & 1;
				posl[j] = ch[posl[j]][i ^ 1];
				posr[j] = ch[posr[j]][i ^ 1];
			}
		}
		else {
			k -= cnt;
			for(int j = xl; j <= xr; j++) {
				int i = (x[j] >> t) & 1;
				posl[j] = ch[posl[j]][i];
				posr[j] = ch[posr[j]][i];
			}
		}
		
	}
	
	return res;
}

int main() {
	ios :: sync_with_stdio(false); cin.tie(0); cout.tie(0);
	
	cin >> n >> m;
	for(int i = 1; i <= n; i++) cin >> x[i];
	for(int i = 1; i <= m; i++) cin >> y[i], rt[i] = ++tot, add(rt[i], rt[i - 1], 31, y[i]);
	
	cin >> q;
	for(int i = 1; i <= q; i++) {
		int u, d, l, r, k; cin >> u >> d >> l >> r >> k;
		cout << ask(u, d, l, r, k) << endl;
	}
	
	return 0;
}

lxl 上课讲到的题,如果学过可持久化 01-trie 就是宝宝题

Loj #6144. 「2017 山东三轮集训 Day6」C

给定 \(n\) 个数的数组,\(m\) 次询问,每次全局按位异或/按位与/按位或上 \(x\),或者查询 \([l,r]\)\(k\) 大。\(n,m\le5\times 10^4\)

操作难以直接维护,因为手玩会发现几个按位运算放在一起并不具备交换律,这意味着直接维护需要在线处理标记。但是单独按位异或操作是好维护的,因为异或上某一位的要么不变要么 \(0/1\) 互换,同时按位异或也是可以离线下来等到查询时一起处理的。

考虑记翻转标记 rev ,二进制位为 \(1\) 表示进行 \(0/1\) 翻转,为 \(0\) 时操作不会造成影响。每次要全局异或上 \(x\) 时就可以异或在 rev 上面,查询时就根据标记来判断往那边走。

接下来考虑按位与和按位或操作的维护。发现其都具有合并的性质:具体地,无论 \(0/1\) 按位与上 \(0\)都会得到 \(0\),无论 \(0/1\) 按位或上 \(1\) 都会得到 \(1\),相当于 01-trie 上把两条边缩成一条边;并且合并不可逆的,因为无论什么操作都不能再将这一位分开。同时由于其会影响 rev 标记,我们需要同步更新:按位与 \(0\)rev 这一位也变成 \(0\);按位或 \(1\)rev 这一位也变成 \(1\)。至于按位与 \(1\) 和按位或 \(0\) 不会造成任何影响。

对于具有合并性质的操作,可能改变 \(O(n\log n)\) 的信息,所以可以考虑直接暴力重构。一次合并会把所有数的这一位变得相同,我们用 vis 记录下二进制位是否经历过合并,对同一位的重复合并不会带来更多便利,没有意义。这样重构至多只会有 \(O(\log v)\) 次,这一部分总复杂度 \(O(n\log^2v)\)

对于经历过合并的二进制位,我们单独处理它的值。记 tag 表示二进制位被合并之后的值。一个很好的性质是 tag 的值就是每次操作的叠加:按位与 \(0\) 之后它们都变为 \(0\);按位或 \(1\) 之后它们都变为 \(1\);按位异或 \(1\) 之后它们都 \(0/1\) 翻转;其他操作后它们都不发生改变。所以我们只需要每次操作之后更新 tag 就好了。

对于查询,与上道题目类似。查询第 \(k\) 小就找异或上 rev 使这一位为 \(0\) 的数的个数,比 \(k\) 小就统计贡献,并减去个数接着找;对于经历过合并的直接统计贡献即可。

注意加数的时候要特判掉合并过的二进制位

代码实现
#include<bits/stdc++.h>
using namespace std;
#define int long long

const int maxn = 5e4 + 10;
int n, m, a[maxn];

int tot, ch[maxn << 6][2], sz[maxn << 6], rt[maxn];
int tag, rev, vis;
void add(int u, int v, int t, int k) {
	if(t < 0) return;
	int i = (k >> t) & 1;
	if((vis >> t) & 1) i = 0;//合并过的不考虑值 
	ch[u][i] = ++tot, ch[u][i ^ 1] = ch[v][i ^ 1];
	sz[ch[u][i]] = sz[ch[v][i]] + 1;
	return add(ch[u][i], ch[v][i], t - 1, k), void(0);
}
void rebuild() {
	memset(ch, 0, sizeof ch);
	memset(sz, 0, sizeof sz);
	memset(rt, 0, sizeof rt);
	tot = 0;
	for(int i = 1; i <= n; i++) rt[i] = ++tot, add(rt[i], rt[i - 1], 31, a[i] ^ rev);
	return rev = 0, void(0);
}
int ask(int u, int v, int t, int k) {
	if(t < 0) return 0;
	if((vis >> t) & 1) {
		return (tag & (1ll << t)) | ask(ch[u][0], ch[v][0], t - 1, k);
	}
	int i = (rev >> t) & 1, cnt = sz[ch[u][i]] - sz[ch[v][i]];
	if(k > cnt) return (1ll << t) | ask(ch[u][i ^ 1], ch[v][i ^ 1], t - 1, k - cnt);
	else return ask(ch[u][i], ch[v][i], t - 1, k);
}

signed main() {
	ios :: sync_with_stdio(false); cin.tie(0); cout.tie(0);
	
	cin >> n >> m;
	for(int i = 1; i <= n; i++) cin >> a[i], rt[i] = ++tot, add(rt[i], rt[i - 1], 31, a[i]);
	
	for(int i = 1; i <= m; i++) {
		string op; int x, l, r, k; bool flag = false;
		
		cin >> op;
		if(op == "Xor") {
			cin >> x;
			tag ^= x, rev ^= x;
		}
		if(op == "And") {
			cin >> x;
			tag &= x;
			for(int t = 31; t >= 0; t--) {
				if(!((x >> t) & 1)) {
					if(!((vis >> t) & 1)) vis |= (1ll << t), flag = true;
					rev -= rev & (1ll << t);// &0:rev第t位变成0 
				}
			}
		}
		if(op == "Or") {
			cin >> x;
			tag ^= x;
			for(int t = 31; t >= 0; t--) {
				if((x >> t) & 1) {
					if(!((vis >> t) & 1)) vis |= (1ll << t), flag = true;
					rev |= (1ll << t);// |1:rev第t位变成1 
				}
			}
		}
		if(op == "Ask") {
			cin >> l >> r >> k;
			cout << ask(rt[r], rt[l - 1], 31, k) << endl;
		}
		if(flag) rebuild();
	}
	
	return 0;
} 

但是只有40pts,还没调出来。

posted @ 2025-03-13 18:12  Ydoc770  阅读(81)  评论(0)    收藏  举报