Solution - P3384 【模板】重链剖分/树链剖分
糖丸了,写 Segment Tree 维护区间和,打标记时忘了乘区间长度。
思路
简单树剖板子,没什么好说的。(虽然蒟蒻第一次写树剖)
不是,我是多久没写线段树了?
代码
#include <bits/stdc++.h>
// Integer type aliases used throughout the file. The `register` storage class
// was deprecated in C++11 and removed in C++17, so the aliases are plain types;
// the names are kept so every existing use keeps compiling.
#define rint int
#define rllong long long
#define llong long long
// Upper bound used to size the global arrays.
#define N 100005
using namespace std;
// n: number of vertices, q: number of operations, root: root vertex,
// mod: modulus applied to stored sums and printed answers (all read in main).
int n, q, root, mod;
// Segment tree over positions 1..n: val[x] is the sum kept for node x,
// tag[x] is the pending per-element addition not yet pushed to x's children.
int val[N<<2], tag[N<<2];
// a[v]: initial value of vertex v; raw[p]: value of the vertex whose dfn is p.
int a[N], raw[N];
// Child indices of segment tree node x (node 1 is the root). The argument is
// parenthesised so expressions such as ls(a|b) expand as intended.
#define ls(x) ((x)<<1)
#define rs(x) ((x)<<1|1)
// Midpoint and size of the current node's range; both expect variables named
// l and r to be in scope at the point of use.
#define mid ((l+r)>>1)
#define len (r-l+1)
// Recompute node x from its two children. The sum is formed in long long so
// that two large child values cannot overflow int before the reduction by mod.
#define pushup(x) (val[x] = static_cast<int>((static_cast<long long>(val[ls(x)])+val[rs(x)])%mod))
/**
 * Push the pending addition stored on node x (covering [l, r]) down to its two
 * children, then clear it.
 *
 * Each child's sum grows by tag * (length of the child's range). Both factors
 * are ints whose product can exceed the int range, so the arithmetic is done
 * in long long and reduced by mod before being stored back.
 */
inline void pushdown(rint x, rint l, rint r){
    if(!tag[x]) return;
    const long long add = tag[x];
    const int lc = ls(x);
    const int rc = rs(x);
    // Left child covers [l, mid], right child covers [mid+1, r].
    val[lc] = static_cast<int>((val[lc] + add*(mid-l+1))%mod);
    tag[lc] = static_cast<int>((tag[lc] + add)%mod);
    val[rc] = static_cast<int>((val[rc] + add*(r-mid))%mod);
    tag[rc] = static_cast<int>((tag[rc] + add)%mod);
    tag[x] = 0;
}
/**
 * Build the segment tree node x covering positions [l, r].
 * A leaf takes raw[l] reduced by mod; an inner node is built from its two
 * halves and then recomputed with pushup.
 */
inline void build(rint x, rint l, rint r){
    if(l != r){
        build(ls(x), l, mid);
        build(rs(x), mid+1, r);
        pushup(x);
    }else{
        val[x] = raw[l]%mod;
    }
}
/**
 * Add k to every position in [L, R].
 *
 * x is the current node and [l, r] the range it covers. When the node lies
 * completely inside [L, R] the update is recorded on the node itself (sum and
 * lazy tag) and the recursion stops; otherwise pending tags are pushed down and
 * the overlapping children are visited.
 *
 * k * (r-l+1) is computed in long long and reduced by mod, because the int
 * product can overflow; the tag is kept reduced for the same reason.
 */
inline void modify(rint x, rint L, rint R, rint l, rint r, rint k){
    if(L <= l && R >= r){
        val[x] = static_cast<int>((val[x] + static_cast<long long>(k)*len)%mod);
        tag[x] = static_cast<int>((static_cast<long long>(tag[x]) + k)%mod);
        return;
    }
    pushdown(x, l, r);
    if(L <= mid) modify(ls(x), L, R, l, mid, k);
    if(R > mid) modify(rs(x), L, R, mid+1, r, k);
    pushup(x);
}
/**
 * Return the sum of positions [L, R], reduced by mod.
 *
 * x is the current node and [l, r] the range it covers. A fully covered node
 * answers from its stored sum; otherwise pending tags are pushed down and the
 * overlapping children are combined. Partial results are accumulated in
 * long long so adding two of them cannot overflow int.
 */
inline int query(rint x, rint L, rint R, rint l, rint r){
    if(L <= l && R >= r) return val[x]%mod;
    pushdown(x, l, r);
    long long ans = 0;
    if(L <= mid) ans += query(ls(x), L, R, l, mid);
    if(R > mid) ans += query(rs(x), L, R, mid+1, r);
    return static_cast<int>(ans%mod);
}
// Per-vertex decomposition data:
// dfn[v]: position assigned to v by dfs2, siz[v]: size of v's subtree,
// dep[v]: depth of v; dfcnt: the last dfn value handed out.
int dfn[N], siz[N], dep[N], dfcnt;
// top[v]: first vertex of the chain containing v, son[v]: child of v with the
// largest subtree (0 if v has no child), fa[v]: parent of v.
int top[N], son[N], fa [N];
// Adjacency lists stored as linked arcs: to[e] is the endpoint of arc e,
// nxt[e] the next arc leaving the same vertex, head[u] the most recently added
// arc of u (0 terminates a list). gsiz starts at 1, so the first arc is index 2.
int to[N<<1], nxt[N<<1], head[N], gsiz = 1;
// Prepend the directed arc u -> v to u's adjacency list.
#define mkarc(u,v) (++gsiz,to[gsiz]=v,nxt[gsiz]=head[u],head[u]=gsiz)
/**
 * First traversal: fills dep[], fa[], siz[] and son[] for the subtree of u.
 * fa[u] must already be set by the caller (it stays 0 for the root).
 */
inline void dfs1(rint u){
    siz[u] = 1;
    dep[u] = dep[fa[u]]+1;
    for(int e = head[u]; e != 0; e = nxt[e]){
        const int child = to[e];
        if(child != fa[u]){
            fa[child] = u;
            dfs1(child);
            siz[u] += siz[child];
            // Keep the child with the strictly largest subtree as the heavy one.
            if(siz[child] > siz[son[u]]) son[u] = child;
        }
    }
}
/**
 * Second traversal: hands out dfn[] positions, copies each vertex value into
 * raw[] at that position and assigns top[]. The heavy child is visited first
 * and inherits top[u]; every other child starts a chain of its own.
 * top[u] must already be set by the caller.
 */
inline void dfs2(rint u){
    ++dfcnt;
    dfn[u] = dfcnt;
    raw[dfcnt] = a[u];
    const int heavy = son[u];
    if(heavy == 0) return;
    top[heavy] = top[u];
    dfs2(heavy);
    for(int e = head[u]; e != 0; e = nxt[e]){
        const int child = to[e];
        if(child != fa[u] && child != heavy){
            top[child] = child;
            dfs2(child);
        }
    }
}
/**
 * Add k to every vertex on the path between u and v.
 * While the endpoints sit on different chains, the one whose chain head is
 * deeper is updated from that head down to the endpoint and then jumps to the
 * head's parent; the final segment lies inside a single chain.
 */
inline void modifylink(rint u, rint v, rint k){
    for(; top[u] != top[v]; v = fa[top[v]]){
        // Make v the endpoint whose chain head is at least as deep.
        if(dep[top[v]] < dep[top[u]]) swap(u, v);
        modify(1, dfn[top[v]], dfn[v], 1, n, k);
    }
    if(dep[v] < dep[u]) swap(u, v);
    modify(1, dfn[u], dfn[v], 1, n, k);
}
/**
 * Return the sum of the vertex values on the path between u and v, reduced by
 * mod. The walk mirrors modifylink: whole chain prefixes are summed until both
 * endpoints share a chain, then the remaining segment is added.
 */
inline int querylink(rint u, rint v){
    int total = 0;
    for(; top[u] != top[v]; v = fa[top[v]]){
        // Make v the endpoint whose chain head is at least as deep.
        if(dep[top[v]] < dep[top[u]]) swap(u, v);
        total += query(1, dfn[top[v]], dfn[v], 1, n);
        total %= mod;
    }
    if(dep[v] < dep[u]) swap(u, v);
    total += query(1, dfn[u], dfn[v], 1, n);
    return total%mod;
}
/**
 * Reads the tree and the operations from stdin and answers the queries.
 *
 * Input: n q root mod, then n vertex values, then n-1 edges, then q operations:
 *   1 u v k  add k to every vertex on the path u..v
 *   2 u v    print the path sum modulo mod
 *   3 u k    add k to every vertex in the subtree of u
 *   other u  print the subtree sum of u modulo mod
 * A subtree occupies the contiguous dfn range [dfn[u], dfn[u]+siz[u]-1].
 */
int main(){
    scanf("%d %d %d %d", &n, &q, &root, &mod);
    for(int i = 1; i <= n; ++i) scanf("%d", &a[i]);
    for(int i = 1; i < n; ++i){
        int u, v;
        scanf("%d %d", &u, &v);
        mkarc(u, v), mkarc(v, u);
    }
    // The root heads its own chain; dfs2 derives every other top[] from it.
    top[root] = root;
    dfs1(root);
    dfs2(root);
    build(1, 1, n);
    while(q--){
        int op, u, v, k;
        scanf("%d", &op);
        if(op == 1){
            scanf("%d %d %d", &u, &v, &k);
            k %= mod;
            modifylink(u, v, k);
        }
        else if(op == 2){
            scanf("%d %d", &u, &v);
            printf("%d\n", (querylink(u, v))%mod);
        }
        else if(op == 3){
            scanf("%d %d", &u, &k);
            // Reduce the increment exactly as operation 1 does, so the segment
            // tree only ever receives values already reduced by mod.
            k %= mod;
            modify(1, dfn[u], dfn[u]+siz[u]-1, 1, n, k);
        }
        else{
            scanf("%d", &u);
            printf("%d\n", query(1, dfn[u], dfn[u]+siz[u]-1, 1, n)%mod);
        }
    }
    return 0;
}