2025.06.17 CW 模拟赛 D. 异或

D. 异或

没有意义.

题目描述

给定一个长度为 \(N\) 的非负整数序列 \(a_1, a_2, \ldots, a_N\) 和非负整数 \(x\)

求有多少个非空子序列 \(1 \leq b_1 < b_2 < \cdots < b_k \leq N\),满足对任意的 \((i, j)\) \((1 \leq i < j \leq k)\) 都有
\(a_{b_i} \oplus a_{b_j} \geq x\)。其中 \(\oplus\) 表示按位异或。

你只需要输出答案对 \(998244353\) 取模后的结果。


思路

key observation: \(\forall a \le b \le c, \min(a \oplus b, b \oplus c) \le a \oplus c\).

  • 证明

    从高到低考虑第一个不同的位, 即不为 0 0 0 或者 1 1 1. 由于 \(a \le b \le c\), 我们可以将该位分为 \(2\) 种情况进行讨论

    1. 0 0 1, 此时 \(a \oplus b\) 的这一位为 \(0\), \(b \oplus c\) 的这一位为 \(1\), \(a \oplus c\) 的这一位为 \(1\), 满足上式.
    2. 0 1 1, 此时 \(a \oplus b\) 的这一位为 \(1\), \(b \oplus c\) 的这一位为 \(0\), \(a \oplus c\) 的这一位为 \(1\), 同样满足上式.

    \(\square\).

有了这个性质, 我们就可以将原数组升序排序, 进而只考虑相邻两个数的异或的值了.

定义 \(f_i\) 表示以 \(a_i\) 结尾的合法方案数. 有一个朴素的 \(\mathcal{O}(n^2)\) 转移

\[f_i = \sum_{j = 1}^{i - 1} [a_i \oplus a_j \ge x] f_j \]

对于异或, 我们不难想到将所有数放到字典树上进行考虑, 转移的时候我们在字典树上转移即可, 具体实现可以看代码. 复杂度 \(\mathcal{O}(n \log V)\).

#include <iostream>
#include <algorithm>

using namespace std;

typedef long long ll;

#define Mod 998244353

void addon(int& x, int y) { (x += y) >= Mod and (x -= Mod); }

int n, f[300001];
ll x, a[300001];

class Trie {
private:
    int son[30000001][2], val[30000001], tot = 1;

public:
    void insert(ll num, int id) {
        int u = 1;
        for (int i = 60; ~i; --i) {
            int& ch = son[u][num >> i & 1];
            !ch and (ch = ++tot);
            addon(val[u = ch], f[id]);
        }
    }

    int query(ll num) {
        int u = 1, res = 0;
        for (int i = 60; ~i; --i) {
            if (x >> i & 1) {
                if (num >> i & 1) {
                    u = son[u][0];
                }
                else {
                    u = son[u][1];
                }
            }
            else {
                if (num >> i & 1) {
                    addon(res, val[son[u][0]]);
                    u = son[u][1];
                }
                else {
                    addon(res, val[son[u][1]]);
                    u = son[u][0];
                }
            }
        }
        addon(res, val[u]);
        return res;
    }

} trie;

void init() {
    scanf("%d %lld", &n, &x);
    for (int i = 1; i <= n; ++i) {
        scanf("%lld", a + i), f[i] = 1;
    }
    sort(a + 1, a + n + 1);
}

void calculate() {
    trie.insert(a[1], 1);
    for (int i = 2; i <= n; ++i) {
        addon(f[i], trie.query(a[i]));
        trie.insert(a[i], i);
    }

    int ans = 0;
    for (int i = 1; i <= n; ++i) {
        addon(ans, f[i]);
    }
    printf("%d\n", ans);
}

void solve() {
    init(), calculate();
}

int main() {
    solve();
    return 0;
}
posted @ 2025-06-17 18:53  Steven1013  阅读(26)  评论(0)    收藏  举报