「题解」洛谷 P3401 洛谷树
描述
思路
预处理出每个点到根结点路径上经过的所有边权的异或和 \(s_i\),因为 \(a\operatorname{xor}a\operatorname{xor}b=b\),询问操作可以转化为 \(u\) 到 \(v\) 的路径上经过的所有点的 \(s_i\) 中任选两个异或产生的所有结果的和。
因为异或运算位与位之间是独立的,考虑建 \(\log_2 \max\{w\}=10\) 棵线段树,第 \(i\) 棵线段树的每个结点维护一段 \(\text{dfn}\) 值连续的结点区间中 \(s\) 值第 \(i\) 位为 \(1\) 的结点有多少个。
对于每次询问,树剖统计 \(u,v\) 路径上所有结点 \(s\) 值第 \(i\) 位为 \(1\) 的个数 \(w_i\),记 \(u,v\) 路径上经过的结点个数为 \(z\)。然后按位处理,依次考虑每一位的贡献。第 \(i\) 位带来的贡献即为 \(w_i\cdot (x-w_i)\cdot 2^i\),意思是每一个该位为 \(1\) 的结点 \(s\) 值与一个该位为 \(0\) 的结点的 \(s\) 值异或后该位会变成 \(1\)。
对于每次修改,记录这条边原来的边权 \(x\),记新边权为 \(y\),若 \(x\) 与 \(y\) 第 \(i\) 位不相同,则将深度较深的结点的子树所有结点的 \(s\) 值第 \(i\) 位反转(即 \(1\) 变 \(0\),\(0\) 变 \(1\))。
代码
#include <bits/stdc++.h>
using namespace std;
#define LOGW 12
#define MAXN 30010
int n, q, a[MAXN], tim, wid[MAXN], fa[MAXN], dis[MAXN], siz[MAXN], son[MAXN], dfn[MAXN], top[MAXN];
vector<pair<int, int> > g[MAXN];
inline void init(const int &u, const int &f) {
siz[u] = 1;
for (register int i = 0; i < g[u].size(); i++) {
const int v = g[u][i].first, w = g[u][i].second;
if (v == f) continue;
fa[v] = u;
wid[v] = w;
dis[v] = dis[u] + 1;
init(v, u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]]) son[u] = v;
}
}
inline void dfs(const int &u, const int &p, const int &s) {
top[u] = p;
dfn[u] = ++tim;
a[dfn[u]] = s;
if (!son[u]) return;
dfs(son[u], p, s ^ wid[son[u]]);
for (register int i = 0; i < g[u].size(); i++) {
const int v = g[u][i].first, w = g[u][i].second;
if (v == fa[u] || v == son[u]) continue;
dfs(v, v, s ^ w);
}
}
struct Segment_Tree {
int sum[MAXN << 2][LOGW]; // 区间内所有数中第 i 位为 1 的数的个数
bool lzy[MAXN << 2][LOGW];
inline bool GetIdx(const int &x, const int &k) { return x & (1ll << k); }
inline bool InRange(const int &l, const int &r, const int &L, const int &R) { return L <= l && R >= r; }
inline bool OutoRange(const int &l, const int &r, const int &L, const int &R) { return r < L || R < l; }
inline void pushup(const int &u) { for (register int i = 0; i <= 10; i++) sum[u][i] = sum[u << 1][i] + sum[u << 1 | 1][i]; }
inline void pushup_k(const int &u, const int &k) { sum[u][k] = sum[u << 1][k] + sum[u << 1 | 1][k]; }
inline void maketag(const int &u, const int &l, const int &r, const int &k) {
lzy[u][k] = !lzy[u][k]; sum[u][k] = r - l + 1 - sum[u][k];
}
inline void pushdown(const int &u, const int &l, const int &r, const int &k) {
if (!lzy[u][k]) return;
lzy[u][k] = 0;
const int mid = (l + r) >> 1;
maketag(u << 1, l, mid, k), maketag(u << 1 | 1, mid + 1, r, k);
}
inline void build(const int &u, const int &l, const int &r) {
if (l == r) { for (register int i = 0; i <= 10; i++) sum[u][i] = GetIdx(a[l], i); return; }
const int mid = (l + r) >> 1; build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r); pushup(u);
}
inline void update(const int &u, const int &l, const int &r, const int &L, const int &R, const int &k) {
// 区间内所有数的第 k 位取反
if (InRange(l, r, L, R)) { maketag(u, l, r, k); return; }
else if (!OutoRange(l, r, L, R)) {
pushdown(u, l, r, k);
const int mid = (l + r) >> 1;
update(u << 1, l, mid, L, R, k), update(u << 1 | 1, mid + 1, r, L, R, k);
pushup_k(u, k);
}
}
inline int query(const int &u, const int &l, const int &r, const int &L, const int &R, const int &k) {
// 区间内第 k 位为 1 的数的个数
if (InRange(l, r, L, R)) return sum[u][k];
else if (!OutoRange(l, r, L, R)) {
const int mid = (l + r) >> 1;
pushdown(u, l, r, k);
return query(u << 1, l, mid, L, R, k) + query(u << 1 | 1, mid + 1, r, L, R, k);
} else return 0;
}
} tree;
int tmp[12]; // 路径上第 i 位为 1 的数的个数
inline long long query(int u, int v) {
memset(tmp, 0, sizeof tmp);
int cnt = 0;
while (top[u] != top[v]) {
if (dis[top[u]] < dis[top[v]]) swap(u, v);
for (register int i = 0; i <= 10; i++) tmp[i] += tree.query(1, 1, n, dfn[top[u]], dfn[u], i);
cnt += dfn[u] - dfn[top[u]] + 1; u = fa[top[u]];
}
if (dis[u] < dis[v]) swap(u, v);
cnt += dfn[u] - dfn[v] + 1;
for (register int i = 0; i <= 10; i++) tmp[i] += tree.query(1, 1, n, dfn[v], dfn[u], i);
long long ans = 0;
for (register int i = 0; i <= 10; i++) ans += (1ll << i) * tmp[i] * (cnt - tmp[i]);
return ans;
}
inline void update(int u, int v, const int &w) {
if (dis[u] > dis[v]) swap(u, v);
for (register int i = 0; i <= 10; i++) {
bool x = tree.GetIdx(w, i), y = tree.GetIdx(wid[v], i);
if (x != y) tree.update(1, 1, n, dfn[v], dfn[v] + siz[v] - 1, i);
}
wid[v] = w;
}
signed main() {
cin >> n >> q;
for (register int i = 1; i < n; i++) {
int u, v, w;
cin >> u >> v >> w;
g[u].push_back({v, w}), g[v].push_back({u, w});
}
dis[1] = 1; init(1, 0); dfs(1, 1, 0); tree.build(1, 1, n);
while (q--) {
int opt, u, v, w;
cin >> opt >> u >> v;
if (opt == 1) cout << query(u, v) << endl;
else {
cin >> w;
update(u, v, w);
}
}
return 0;
}