2025.06.17 CW 模拟赛 D. 异或
D. 异或
没有意义.
题目描述
给定一个长度为 \(N\) 的非负整数序列 \(a_1, a_2, \ldots, a_N\) 和非负整数 \(x\)。
求有多少个非空子序列 \(1 \leq b_1 < b_2 < \cdots < b_k \leq N\),满足对任意的 \((i, j)\) \((1 \leq i < j \leq k)\) 都有
\(a_{b_i} \oplus a_{b_j} \geq x\)。其中 \(\oplus\) 表示按位异或。
你只需要输出答案对 \(998244353\) 取模后的结果。
思路
key observation: \(\forall a \le b \le c, \min(a \oplus b, b \oplus c) \le a \oplus c\).
-
证明
从高到低考虑第一个不同的位, 即不为
0 0 0或者1 1 1. 由于 \(a \le b \le c\), 我们可以将该位分为 \(2\) 种情况进行讨论0 0 1, 此时 \(a \oplus b\) 的这一位为 \(0\), \(b \oplus c\) 的这一位为 \(1\), \(a \oplus c\) 的这一位为 \(1\), 满足上式.0 1 1, 此时 \(a \oplus b\) 的这一位为 \(1\), \(b \oplus c\) 的这一位为 \(0\), \(a \oplus c\) 的这一位为 \(1\), 同样满足上式.
\(\square\).
有了这个性质, 我们就可以将原数组升序排序, 进而只考虑相邻两个数的异或的值了.
定义 \(f_i\) 表示以 \(a_i\) 结尾的合法方案数. 有一个朴素的 \(\mathcal{O}(n^2)\) 转移
\[f_i = \sum_{j = 1}^{i - 1} [a_i \oplus a_j \ge x] f_j
\]
对于异或, 我们不难想到将所有数放到字典树上进行考虑, 转移的时候我们在字典树上转移即可, 具体实现可以看代码. 复杂度 \(\mathcal{O}(n \log V)\).
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
#define Mod 998244353
void addon(int& x, int y) { (x += y) >= Mod and (x -= Mod); }
int n, f[300001];
ll x, a[300001];
class Trie {
private:
int son[30000001][2], val[30000001], tot = 1;
public:
void insert(ll num, int id) {
int u = 1;
for (int i = 60; ~i; --i) {
int& ch = son[u][num >> i & 1];
!ch and (ch = ++tot);
addon(val[u = ch], f[id]);
}
}
int query(ll num) {
int u = 1, res = 0;
for (int i = 60; ~i; --i) {
if (x >> i & 1) {
if (num >> i & 1) {
u = son[u][0];
}
else {
u = son[u][1];
}
}
else {
if (num >> i & 1) {
addon(res, val[son[u][0]]);
u = son[u][1];
}
else {
addon(res, val[son[u][1]]);
u = son[u][0];
}
}
}
addon(res, val[u]);
return res;
}
} trie;
void init() {
scanf("%d %lld", &n, &x);
for (int i = 1; i <= n; ++i) {
scanf("%lld", a + i), f[i] = 1;
}
sort(a + 1, a + n + 1);
}
void calculate() {
trie.insert(a[1], 1);
for (int i = 2; i <= n; ++i) {
addon(f[i], trie.query(a[i]));
trie.insert(a[i], i);
}
int ans = 0;
for (int i = 1; i <= n; ++i) {
addon(ans, f[i]);
}
printf("%d\n", ans);
}
void solve() {
init(), calculate();
}
int main() {
solve();
return 0;
}

浙公网安备 33010602011771号