P14509 树上求值 tree 题解
Description
对于两个权值序列 \(A(0)_0,A(0)_1,\dots,A(0)_{20}\) 与 \(A(1)_0,A(1)_1,\dots,A(1)_{20}\),记 \(x=\sum_{i=0}^{20} b_i2^{i}\),其中 \(b_i\in \{0,1\}\),定义 \(f(x)=\prod_{i=0}^{20} A(b_i)_i\)。
给定一颗树。有 \(T\) 组数据,每组数据给定树的根结点编号 \(r_i\),模数 \(m\) 与权值序列 \(A(0),A(1)\),你需要对每组数据都求出答案。
对于一颗根为 \(r_i\) 的有根树,每个结点 \(x\) 的深度 \(d_x\) 定义为 \(x\) 到 \(r_i\) 的简单路径上的结点数量。记 \(\operatorname{LCA*}(x,y)\) 为结点 \(x,y\) 最近公共祖先的编号。对于每个树上的结点 \(x\),你需要求出 \(s_x=\sum_{i=1}^n f(i+d_{\operatorname{LCA*}(i,x)})\) 对 \(m\) 取模后的结果。
\(1\le T\le 10\),\(1\le r_i\le n\le 2\times 10^5\),\(0\le A(0)_i,A(1)_i<m\),\(2\le m\le 10^9\)。
Solution
首先显然要枚举 \(\text{LCA}\),设枚举的 \(\text{LCA}\) 是 \(x\),则对于 \(x\) 的每个儿子 \(y\),都要求 \(\sum_{i\in\text{subtree}(y)}{f(i+dep_x)}\)。
考虑链怎么做。
按照深度从深到浅枚举 \(x\) 的话,我们需要支持:动态插入一个元素,全局减一,查询当前集合中 \(\sum f(x)\) 的值。
有一个比较容易得到的想法是从高位到低位建 01 trie,但是这样的话做全局减一时很难模拟退位操作。
注意到退位操作是从低位到高位,每次判断当前位是否为 \(0\),如果不是 \(0\),则给当前位减一,并停止。如果是 \(0\),则给当前位赋值为 \(1\),并往更高位枚举。
这启发我们从低位到高位建 trie,现在减一操作就等价于从根节点开始,每次交换当前节点的左右子树,并且往交换后的右子树(也就是原来的左子树)递归,在操作的过程中维护子树信息即可。
所以链就被做到了 \(O(n\log n)\),一般情况就使用线段树合并来合并每个子树的信息即可。
时间复杂度:\(O(n\log n)\)。
Code
#include <bits/stdc++.h>
// #define int int64_t
using i64 = int64_t;
const int kMaxN = 2e5 + 5;
int n, rt, mod, sgt_cnt;
int v[2][21];
int rrt[kMaxN], dep[kMaxN], res[kMaxN], ss[kMaxN];
std::vector<int> G[kMaxN];
int qpow(int bs, int64_t idx = mod - 2) {
int ret = 1;
for (; idx; idx >>= 1, bs = (int64_t)bs * bs % mod)
if (idx & 1)
ret = (int64_t)ret * bs % mod;
return ret;
}
inline int add(int x, int y) { return (x + y >= mod ? x + y - mod : x + y); }
inline int sub(int x, int y) { return (x >= y ? x - y : x - y + mod); }
inline void inc(int &x, int y) { (x += y) >= mod ? x -= mod : x; }
inline void dec(int &x, int y) { (x -= y) < 0 ? x += mod : x; }
struct Node {
int ls, rs, sum;
} t[kMaxN * 21];
void build() {
std::cin >> n;
for (int i = 1; i < n; ++i) {
int u, v;
std::cin >> u >> v;
G[u].emplace_back(v), G[v].emplace_back(u);
}
}
void pushup(int d, int x) {
t[x].sum = (1ll * v[0][d] * t[t[x].ls].sum + 1ll * v[1][d] * t[t[x].rs].sum) % mod;
}
void update(int d, int &x, int ql) {
if (!x) x = ++sgt_cnt;
if (d == 21) {
++t[x].sum;
} else {
update(d + 1, (~ql >> d & 1) ? t[x].ls : t[x].rs, ql);
pushup(d, x);
}
}
int merge(int d, int x, int y) {
if (!x || !y) return x + y;
if (d == 21) {
inc(t[x].sum, t[y].sum);
return x;
} else {
t[x].ls = merge(d + 1, t[x].ls, t[y].ls);
t[x].rs = merge(d + 1, t[x].rs, t[y].rs);
pushup(d, x);
return x;
}
}
void work(int d, int x) {
if (!x || d == 21) return;
std::swap(t[x].ls, t[x].rs), work(d + 1, t[x].rs);
pushup(d, x);
}
void dfs1(int u, int fa) {
dep[u] = dep[fa] + 1;
int val = 1;
for (int i = 0; i <= 20; ++i) val = 1ll * val * v[(u + dep[u]) >> i & 1][i] % mod;
inc(res[u], val);
for (auto v : G[u]) {
if (v == fa) continue;
dfs1(v, u);
work(0, rrt[v]), ss[v] = t[rrt[v]].sum;
rrt[u] = merge(0, rrt[u], rrt[v]);
}
update(0, rrt[u], u + dep[u]);
}
void dfs2(int u, int fa) {
for (auto v : G[u]) {
if (v != fa) inc(res[u], ss[v]);
}
for (auto v : G[u]) {
if (v == fa) continue;
dec(res[v], ss[v]);
dfs2(v, u);
}
}
void dfs3(int u, int fa) {
if (fa) inc(res[u], res[fa]);
for (auto v : G[u]) {
if (v == fa) continue;
dfs3(v, u);
}
}
void dickdreamer() {
std::cin >> rt >> mod;
for (int i = 1; i <= sgt_cnt; ++i) t[i] = {0, 0, 0};
sgt_cnt = 0;
for (int i = 0; i <= n; ++i) res[i] = ss[i] = rrt[i] = 0;
for (int i = 0; i <= 20; ++i) std::cin >> v[0][i];
for (int i = 0; i <= 20; ++i) std::cin >> v[1][i];
dfs1(rt, 0), dfs2(rt, 0), dfs3(rt, 0);
int64_t ans = 0;
for (int i = 1; i <= n; ++i) ans ^= 1ll * i * res[i];
std::cout << ans << '\n';
}
int32_t main() {
#ifdef ORZXKR
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
#endif
std::ios::sync_with_stdio(0), std::cin.tie(0), std::cout.tie(0);
int T = 1;
build();
std::cin >> T;
while (T--) dickdreamer();
// std::cerr << 1.0 * clock() / CLOCKS_PER_SEC << "s\n";
return 0;
}

浙公网安备 33010602011771号