CF1616H.Keep XOR Low题解
题意
给你 \(n\) 个整数 \(a_1,a_2,\cdots,a_n\) 和一个整数 \(x\)。
你需要求出 \(\{1,2,\cdots,n\}\) 的一个子集 \(S\),满足 \(S\) 中任意两个不同的元素 \(i,j\),满足 \(a_i~{\rm xor}~a_j\le x\)。
求选取 \(S\) 的方案数,对 \(998244353\) 取模的结果。
\(1\le n\le 150000,0\le a_i,x< 2^{30}\)
题解
对于这种异或问题, 应该想到trie树。
考虑计数dp, 我们可以设 \(f_{i,j}\)表示前面位相同, 第\(i\)子树和第\(j\)子树的计数(要求\(i,j\)处于同一层)。
具体实现可以用搜索。
\(p[i] = 2 ^ i\)
\(solve1(d, u)\)
- 表示在\(u\)子树中取不为空集
- 当\(x\)的第\(d\)位为\(1\), 值为\(p[sz[lc]] + p[sz[rc]] - 2 + solve2(d - 1, lc, rc)\)
- 当\(x\)的第\(d\)位为\(0\), 值为\((solve1(d - 1, lc) + solve1(d - 1, rc))\)
\(solve2(d, u1, u2)\)
- 表示在\(u1,u2\)子树中取, 且两个子树都至少分别取\(1\)个
- 当\(x\)的第\(d\)位为\(1\), 设\(A = solve2(d - 1, lc1, rc2), B = solve2(d - 1, lc2, rc1)\)
值为\(A * B + A * (p[sz[lc2]] + p[sz[rc1]] - 1) + B * (p[sz[lc1]] + p[sz[rc2]] - 1) + (p[sz[lc1]] - 1) * (p[sz[lc2]] - 1) + (p[sz[rc1]] - 1) * (p[sz[rc2]] - 1)\)
之所以这里\(A * B\) 是因为如果\(lc1\)和\(rc2\) 可行, \(lc1和rc1\)中任意一个都可以, 因为\(lc1和rc2\)比和\(rc1\)分开的更早。 - 当\(x\)的第\(d\)位位\(0\), 值为\(solve2(d - 1, lc1, lc2) + solve2(d - 1, rc1, rc2)\)
代码
点击查看代码
#include <stdio.h>
#define LL long long
const int N = 150000;
const LL Mod = 998244353ll;
int n, x, a[N + 5], tot, ch[N * 31][2],sz[N * 31];
LL p[N + 5];
void insert(int v) {
int rt = 1; sz[rt]++;
for(int i = 29; i >= 0; --i) {
int p = (v >> i) & 1;
if (!ch[rt][p]) ch[rt][p] = ++tot;
sz[rt = ch[rt][p]]++;
}
}
LL solve2(int d, int u1, int u2) {
if (d < 0 || !u1 || !u2) return (p[sz[u1]] - 1) * (p[sz[u2]] - 1) % Mod;
int lc1 = ch[u1][0], rc1 = ch[u1][1], lc2 = ch[u2][0], rc2 = ch[u2][1];
if ((x >> d) & 1) {
LL res = (p[sz[lc1]] - 1) * (p[sz[lc2]] - 1) % Mod + (p[sz[rc1]] - 1) * (p[sz[rc2]] - 1) % Mod; res %= Mod;
LL A = solve2(d - 1, lc1, rc2), B = solve2(d - 1, lc2, rc1);
(res += (A * B % Mod + A * (p[sz[lc2]] + p[sz[rc1]] - 1) % Mod + B * (p[sz[lc1]] + p[sz[rc2]] - 1) % Mod)%Mod) %= Mod;
return res;
}
else return (solve2(d - 1, lc1, lc2) + solve2(d - 1, rc1, rc2)) % Mod;
}
LL solve1(int d, int u) {
if (d < 0 || !u) return p[sz[u]] - 1;
int lc = ch[u][0], rc = ch[u][1];
if ((x >> d) & 1)
return (p[sz[lc]] + p[sz[rc]] - 2 + solve2(d - 1, lc, rc)) % Mod;
else return (solve1(d - 1, lc) + solve1(d - 1, rc)) % Mod;
}
int main() {
// freopen("t.in", "r", stdin);
scanf("%d%d",&n, &x); ++tot; p[0] = 1;
for(int i = 1; i <= n; ++i) p[i] = p[i - 1] * 2 % Mod;
for(int i = 1, v; i <= n; ++i) {
scanf("%d", &v);
insert(v);
}
LL ans = solve1(29, 1);
printf("%lld\n", solve1(29, 1));
return 0;
}
/*
*/