Loading

CF1616H.Keep XOR Low题解

题意

给你 \(n\) 个整数 \(a_1,a_2,\cdots,a_n\) 和一个整数 \(x\)

你需要求出 \(\{1,2,\cdots,n\}\) 的一个子集 \(S\),满足 \(S\) 中任意两个不同的元素 \(i,j\),满足 \(a_i~{\rm xor}~a_j\le x\)

求选取 \(S\) 的方案数,对 \(998244353\) 取模的结果。

\(1\le n\le 150000,0\le a_i,x< 2^{30}\)

题解

对于这种异或问题, 应该想到trie树。

考虑计数dp, 我们可以设 \(f_{i,j}\)表示前面位相同, 第\(i\)子树和第\(j\)子树的计数(要求\(i,j\)处于同一层)。

具体实现可以用搜索。

\(p[i] = 2 ^ i\)

\(solve1(d, u)\)

  • 表示在\(u\)子树中取不为空集
  • \(x\)的第\(d\)位为\(1\), 值为\(p[sz[lc]] + p[sz[rc]] - 2 + solve2(d - 1, lc, rc)\)
  • \(x\)的第\(d\)位为\(0\), 值为\((solve1(d - 1, lc) + solve1(d - 1, rc))\)

\(solve2(d, u1, u2)\)

  • 表示在\(u1,u2\)子树中取, 且两个子树都至少分别取\(1\)
  • \(x\)的第\(d\)位为\(1\), 设\(A = solve2(d - 1, lc1, rc2), B = solve2(d - 1, lc2, rc1)\)
    值为\(A * B + A * (p[sz[lc2]] + p[sz[rc1]] - 1) + B * (p[sz[lc1]] + p[sz[rc2]] - 1) + (p[sz[lc1]] - 1) * (p[sz[lc2]] - 1) + (p[sz[rc1]] - 1) * (p[sz[rc2]] - 1)\)
    之所以这里\(A * B\) 是因为如果\(lc1\)\(rc2\) 可行, \(lc1和rc1\)中任意一个都可以, 因为\(lc1和rc2\)比和\(rc1\)分开的更早。
  • \(x\)的第\(d\)位位\(0\), 值为\(solve2(d - 1, lc1, lc2) + solve2(d - 1, rc1, rc2)\)

代码

点击查看代码
#include <stdio.h>
#define LL long long
const int N = 150000;
const LL Mod = 998244353ll;

int n, x, a[N + 5], tot, ch[N * 31][2],sz[N  * 31];
LL p[N + 5];
void insert(int v) {
	int rt = 1; sz[rt]++;
	for(int i = 29; i >= 0; --i) {
		int p = (v >> i) & 1;
		if (!ch[rt][p]) ch[rt][p] = ++tot;
		sz[rt = ch[rt][p]]++;
	}
}

LL solve2(int d, int u1, int u2) {
	if (d < 0 || !u1 || !u2) return (p[sz[u1]] - 1) * (p[sz[u2]] - 1) % Mod;
	int lc1 = ch[u1][0], rc1 = ch[u1][1], lc2 = ch[u2][0], rc2 = ch[u2][1];
	if ((x >> d) & 1) {
		LL res = (p[sz[lc1]] - 1) * (p[sz[lc2]] - 1) % Mod + (p[sz[rc1]] - 1) * (p[sz[rc2]] - 1) % Mod; res %= Mod;
		LL A = solve2(d - 1, lc1, rc2), B = solve2(d - 1, lc2, rc1);
		(res += (A * B % Mod + A * (p[sz[lc2]] + p[sz[rc1]] - 1) % Mod + B * (p[sz[lc1]] + p[sz[rc2]] - 1) % Mod)%Mod) %= Mod;
		return res;
	}
	else return (solve2(d - 1, lc1, lc2) + solve2(d - 1, rc1, rc2)) % Mod;
}

LL solve1(int d, int u) {
	if (d < 0 || !u) return p[sz[u]] - 1;
	int lc = ch[u][0], rc = ch[u][1];
	if ((x >> d) & 1) 
		return (p[sz[lc]] + p[sz[rc]] - 2 + solve2(d - 1, lc, rc)) % Mod;
	else return (solve1(d - 1, lc) + solve1(d - 1, rc)) % Mod;
}

int main() {
	// freopen("t.in", "r", stdin);
	scanf("%d%d",&n, &x); ++tot; p[0] = 1;
	for(int i = 1; i <= n; ++i) p[i] = p[i - 1] * 2 % Mod;
	for(int i = 1, v; i <= n; ++i) {
		scanf("%d", &v);
		insert(v);
	}
	LL ans = solve1(29, 1);
	printf("%lld\n", solve1(29, 1));
	return 0;
}
/*

*/
posted @ 2022-12-03 18:54  Absolutey  阅读(32)  评论(0)    收藏  举报