ABC451G. Minimum XOR Walk 题解 线性基 + 01 trie

题目链接:https://atcoder.jp/contests/abc451/tasks/abc451_g

首先,你需要解决

这道 “线性基” 的题:P14994 异或最短路和

其次,你需要会使用 01trie 实现以下功能:

  • 插入一个数字;
  • 给定 \(x\),统计 trie 里有多少个数字 \(y\) 满足 \(y \oplus x \lt K+1\)(即 \(y \oplus x \le K\))。

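关键一步:从 1 出发遍历图,记 \(d_u\) 为生成树上从 1 到 \(u\) 的路径异或和,并把每条非树边形成的环的异或值插入线性基。则 \(u\) 到 \(v\) 的最小异或路径等于 \(\operatorname{reduce}(d_u \oplus d_v)\),其中 \(\operatorname{reduce}(x)\) 表示用线性基把 \(x\) 的所有主元位消成 0 得到的值,也就是 \(x\) 所在陪集中的最小值。由于 \(\operatorname{reduce}\) 是线性的,有 \(\operatorname{reduce}(d_u \oplus d_v) = \operatorname{reduce}(d_u) \oplus \operatorname{reduce}(d_v)\)。因此先把每个 \(d_u\) 替换为 \(\operatorname{reduce}(d_u)\),问题就变成统计满足 \(d_u \oplus d_v \le K\) 的点对 \(u \lt v\) 的个数:按编号依次"先查询、再插入"01 trie 即可。然后这题你就会了。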

示例程序:

#include <bits/stdc++.h>
using namespace std;
// Capacity of the per-vertex arrays (vertex ids are 1-based, a little slack is kept).
constexpr int maxn = 200005;

int T;          // number of test cases (counted down by main's loop)
int n;          // vertex count of the current test case
int m;          // edge count of the current test case
int K;          // threshold; query() counts values whose XOR with x is < K + 1
int dis[maxn];  // dis[v]: XOR of weights on the traversal-tree path from the root to v
                // (main later replaces it with its basis-reduced form)
int a[33];      // linear basis; a[i] != 0 has bit i as its highest set bit (i in 0..29)
int idx;        // index of the most recently allocated trie node
bool vis[maxn]; // visited flags for the graph traversal

struct Edge {
    int v, w;
};
vector<Edge> g[maxn];

// A node of the binary (01) trie. son[b] is the index of the child reached by bit b
// (0 means "no child"); sz counts how many inserted values pass through this node.
struct Node {
    int son[2];
    int sz;
};
// Node pool: index 0 is the null sentinel, index 1 the root. ins() allocates at most
// 31 nodes per value (bits 30..0), so 31 * maxn slots suffice for up to ~2e5 values.
Node tr[maxn * 31];

// Insert x into the trie, walking bits 30..0 from most to least significant.
// Missing nodes are allocated from the pool, and every node on the path (the root
// excluded) has its subtree counter incremented.
void ins(int x) {
    int cur = 1;  // node 1 is the root
    for (int bit = 30; bit >= 0; --bit) {
        int& child = tr[cur].son[(x >> bit) & 1];
        if (child == 0) child = ++idx;  // allocate a fresh node on demand
        cur = child;
        ++tr[cur].sz;
    }
}

// Count the values y currently stored in the trie with (y ^ x) < K + 1.
// Walks bits 30..0 along the path on which y ^ x equals the bound, bit for bit.
int query(int x) {
    const int bound = K + 1;  // exclusive upper bound on y ^ x
    int node = 1;
    int count = 0;
    for (int bit = 30; bit >= 0 && node != 0; --bit) {
        const int xb = (x >> bit) & 1;
        const int bb = (bound >> bit) & 1;
        if (bb == 1) {
            // Values that agree with x on this bit make y ^ x have a 0 where the bound
            // has a 1. Every such value is therefore strictly below the bound.
            count += tr[tr[node].son[xb]].sz;
        }
        // Keep following the values whose XOR with x still matches the bound's prefix.
        node = tr[node].son[xb ^ bb];
    }
    return count;
}

// Reset the state for a new test case. This clears trie nodes 0..idx, the adjacency
// lists and visited flags of vertices 1..n, and the linear basis.
// It reads the global n, so it must be called after n has been read.
void init() {
    // Wipe every trie node the previous test case could have touched.
    memset(tr, 0, sizeof(tr[0]) * (idx + 1));
    idx = 1;  // node 1 is the root; new nodes are numbered from 2
    for (int v = 1; v <= n; ++v) {
        g[v].clear();
        vis[v] = false;
    }
    memset(a, 0, sizeof(a));
}

// Insert x into the XOR linear basis a[] (bits 29..0).
// x is reduced from the top bit down: each set bit that already has a basis vector is
// cancelled with it. The first set bit without one stores the reduced x there, so
// a[i] always has bit i as its highest set bit.
void add(int x) {
    for (int bit = 29; bit >= 0; --bit) {
        if (((x >> bit) & 1) == 0) continue;
        if (a[bit] == 0) {
            a[bit] = x;
            return;
        }
        x ^= a[bit];
    }
    // Reaching here means bits 29..0 of x were all cancelled: x is already in the span.
}

// Traverse the component containing u. Sets dis[u] = sum and
// dis[x] = dis[parent] ^ w along tree edges, and inserts the XOR of every cycle closed
// by a non-tree edge into the linear basis via add().
//
// An explicit stack is used instead of recursion. The recursive version could nest n
// calls deep on a path-like graph, which can overflow a default-sized call stack.
// Nodes are marked when pushed, so dis[v] is final as soon as any later edge inspects
// v. The spanning tree may differ from a recursive DFS's, but the cycle space, and
// hence the span of the basis, does not depend on the tree chosen. dis values from
// different trees differ only by an element of that span.
void dfs(int u, int sum) {
    std::vector<int> stk;
    vis[u] = true;
    dis[u] = sum;
    stk.push_back(u);
    while (!stk.empty()) {
        const int x = stk.back();
        stk.pop_back();
        for (const auto& [v, w] : g[x]) {
            if (!vis[v]) {
                // Tree edge: fix v's path XOR now so later edges into v can use it.
                vis[v] = true;
                dis[v] = dis[x] ^ w;
                stk.push_back(v);
            } else {
                // Edge to an already-labelled vertex: it closes a cycle whose XOR is
                // dis[x] ^ dis[v] ^ w (zero for the edge back to x's own parent).
                const int cyc = dis[x] ^ dis[v] ^ w;
                if (cyc != 0) {
                    add(cyc);
                }
            }
        }
    }
}

// Reduce x by the linear basis. From the top bit down, x is XORed with a basis vector
// whenever that makes it smaller. The result is the minimum value in x's coset
// x ^ span(a).
int cal(int x) {
    int best = x;
    for (int bit = 29; bit >= 0; --bit) {
        if (a[bit] == 0) continue;
        const int flipped = best ^ a[bit];
        if (flipped < best) best = flipped;
    }
    return best;
}

int main() {
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d%d", &n, &m, &K);
        init();
        for (int i = 0, u, v, w; i < m; i++) {
            scanf("%d%d%d", &u, &v, &w);
            g[u].push_back({v, w});
            g[v].push_back({u, w});
        }
        dfs(1, 0);
        long long ans = 0;
        for (int i = 1; i <= n; i++) {
            dis[i] = cal(dis[i]);
        }
        for (int i = 1; i <= n; i++) {
            if (i > 1) {
                ans += query(dis[i]);
            }
            ins(dis[i]);
        }
        printf("%lld\n", ans);
    }
    return 0;
}
posted @ 2026-03-29 17:00  quanjun  阅读(0)  评论(0)    收藏  举报