洛谷P4211 [LNOI2014] LCA 题解 树链剖分+树上差分
题目链接:https://www.luogu.com.cn/problem/P4211
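主要是 树上差分：把每个询问 [l, r] 拆成"前缀 r"减去"前缀 l-1"两个前缀询问并离线按前缀长度排序；依次把点 1..p 到根的路径各 +1，则 z 到根路径上的权值和就是前缀内所有点与 z 的 LCA 深度之和（根深度为 1）。路径加与路径求和用树链剖分 + 线段树实现。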
思路来自 紫钦大佬的博客
示例程序:
#include <bits/stdc++.h>
using namespace std;
// Array capacity (node / query count limit plus slack) and the answer modulus.
constexpr int maxn = 50005;
constexpr int mod = 201314;

// n: number of nodes, m: number of queries.
int n, m;
// fa[v]: parent of node v (1-indexed; the root's parent stays 0).
int fa[maxn];
// ans[i]: accumulated answer of the i-th query, kept in [0, mod).
int ans[maxn];
// g[u]: children of node u.
vector<int> g[maxn];

// One prefix half of an original query: adds flag * f(p, z) to ans[id].
// Field order matters: the queries are built with aggregate initialization {id, flag, p, z}.
struct Query {
    int id;    // index of the original query
    int flag;  // +1 for the prefix ending at r, -1 for the prefix ending at l-1
    int p;     // prefix length: nodes 1..p must be inserted before evaluation
    int z;     // node whose root path is summed
};

// At most two prefix halves per original query.
Query qry[maxn * 2];
// Number of entries currently stored in qry.
int qid;
// Segment tree storage. Node u's children are 2u and 2u+1.
// tr[u]: sum (mod `mod`) of node u's interval; lazy[u]: per-position addition not yet pushed down.
int tr[maxn * 4];
int lazy[maxn * 4];
// (l, r, u) argument packs for the left / right child; they expect l, r, mid, u in scope.
#define lson l, mid, (u << 1)
#define rson (mid + 1), r, (u << 1 | 1)
// Recomputes node u's sum (mod `mod`) from its two children.
void push_up(int u) {
    const int left_sum = tr[u << 1];
    const int right_sum = tr[u << 1 | 1];
    tr[u] = (left_sum + right_sum) % mod;
}
// Applies "add lz to every position" to node u, which covers [l, r].
// The node sum grows by lz * (r - l + 1) and lz is accumulated into the node's pending tag.
// The product is formed in 64-bit to avoid int overflow before reduction.
void t_add(int lz, int l, int r, int u) {
    const long long length = r - l + 1;
    const long long delta = length * lz % mod;
    tr[u] = static_cast<int>((tr[u] + delta) % mod);
    lazy[u] = (lazy[u] + lz) % mod;
}
// Moves node u's pending addition (if any) onto both children, then clears it.
// [l, r] is node u's interval and is split at the same midpoint used by add/query.
void push_down(int l, int r, int u) {
    const int tag = lazy[u];
    if (tag == 0)
        return;
    const int mid = (l + r) / 2;
    t_add(tag, l, mid, u << 1);
    t_add(tag, mid + 1, r, u << 1 | 1);
    lazy[u] = 0;
}
// Adds 1 to every position in [L, R]; the current node u covers [l, r].
void add(int L, int R, int l, int r, int u) {
    const bool fully_covered = L <= l && r <= R;
    if (fully_covered) {
        t_add(1, l, r, u);
        return;
    }
    push_down(l, r, u);
    const int mid = (l + r) / 2;
    if (L <= mid)
        add(L, R, l, mid, u << 1);
    if (mid < R)
        add(L, R, mid + 1, r, u << 1 | 1);
    push_up(u);
}
// Returns the sum (mod `mod`) of positions in [L, R] that lie inside node u's interval [l, r].
int query(int L, int R, int l, int r, int u) {
    if (L <= l && r <= R)
        return tr[u];
    push_down(l, r, u);
    const int mid = (l + r) / 2;
    const int left_part = (L <= mid) ? query(L, R, l, mid, u << 1) : 0;
    const int right_part = (mid < R) ? query(L, R, mid + 1, r, u << 1 | 1) : 0;
    return (left_part + right_part) % mod;
}
// Heavy-light decomposition state (nodes are 1-indexed, 0 means "none").
int dfn[maxn];  // dfn[u]: DFS-order position of node u
int id[maxn];   // id[pos]: node placed at DFS position pos
int idx;        // last DFS position handed out
int sz[maxn];   // sz[u]: size of u's subtree
int dep[maxn];  // dep[u]: depth with the root at 0 (written by dfs1; not read by the code shown)
int tp[maxn];   // tp[u]: top node of the heavy chain containing u
int son[maxn];  // son[u]: heavy child of u (0 if u is a leaf)
// First decomposition pass over u's subtree.
// Fills subtree sizes, depths, and the heavy child of each node.
// On ties the earliest child with the maximal size is kept, because the comparison is strict.
void dfs1(int u) {
    sz[u] = 1;
    for (size_t k = 0; k < g[u].size(); ++k) {
        const int child = g[u][k];
        dep[child] = dep[u] + 1;
        dfs1(child);
        sz[u] += sz[child];
        if (sz[son[u]] < sz[child])
            son[u] = child;
    }
}
// Second decomposition pass: assigns DFS positions, descending into the heavy child first.
// This makes every heavy chain occupy a contiguous range of positions.
// `top` is the head of the chain that u belongs to.
void dfs2(int u, int top) {
    tp[u] = top;
    dfn[u] = ++idx;
    id[dfn[u]] = u;
    if (son[u] != 0)
        dfs2(son[u], top);
    for (const int v : g[u]) {
        if (v == son[u])
            continue;
        dfs2(v, v);  // every light child starts its own chain
    }
}
// Adds 1 to every node on the path from u up to the root, one heavy chain per iteration.
// The walk stops at 0, the root's parent.
void op_add(int u) {
    for (int cur = u; cur != 0; cur = fa[tp[cur]])
        add(dfn[tp[cur]], dfn[cur], 1, n, 1);
}
int op_query(int u) {
int res = 0;
while (u) {
res += query(dfn[tp[u]], dfn[u], 1, n, 1);
res %= mod;
u = fa[ tp[u] ];
}
return res;
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 2, p; i <= n; i++) {
scanf("%d", &p);
p++;
fa[i] = p;
g[p].push_back(i);
}
for (int i = 1; i <= m; i++) {
int l, r, z;
scanf("%d%d%d", &l, &r, &z);
l++, r++, z++;
if (l > 1) qry[qid++] = { i, -1, l-1, z };
qry[qid++] = { i, 1, r, z };
}
sort(qry, qry+qid, [](Query a, Query b) {
return a.p < b.p;
});
dfs1(1);
dfs2(1, 1);
for (int i = 0, q = 1; i < qid; i++) {
int id = qry[i].id, flag = qry[i].flag, p = qry[i].p, z = qry[i].z;
while (q <= p)
op_add(q++);
ans[id] = (ans[id] + flag * op_query(z) + mod) % mod;
}
for (int i = 1; i <= m; i++)
printf("%d\n", ans[i]);
return 0;
}
浙公网安备 33010602011771号