树剖（树链剖分）
先放自己手写的实现：
#include <bits/stdc++.h>
using namespace std;
#define int long long
// Capacity constant; every array declared on the next line is sized maxn << 2.
const int maxn = 2e5;
// Globals (long long via the `#define int long long` above):
//   head/to/nxt/cnt : forward-star adjacency list; cnt = number of stored directed edges
//   siz/fa/dep/son  : subtree size, parent, depth, heavy child (filled by dfs1)
//   top/seg/num/e   : chain head, DFS position of a node, node at a DFS position, DFS counter (filled by dfs2)
//   n, m, r, p      : node count, operation count, root, modulus (read in main)
//   w               : initial node weights, indexed by node
//   res             : accumulator that query() adds into
int head[maxn << 2], to[maxn << 2], nxt[maxn << 2], cnt, siz[maxn << 2], fa[maxn << 2], dep[maxn << 2], son[maxn << 2], e, top[maxn << 2], seg[maxn << 2], num[maxn << 2], n, m, r, p, w[maxn << 2], res;
// Appends the directed edge u -> v to the forward-star adjacency list.
void insert(int u, int v) {
    ++cnt;
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt;
}
// Segment-tree node: covered DFS-position range [l, r], its length len, the range
// sum val, and laz = pending addition that has not yet been pushed to the children.
// Member order matters: build() initializes nodes as tree{l, r, len, val, laz}.
struct tree {
int l, r, len, val, laz;
} tr[maxn << 2];
// First DFS: for the subtree rooted at u (parent f, depth `depth`) record
// fa/dep/siz and make son[u] the child with the largest subtree (the first
// such child wins ties). `maxson` is per-call scratch state; callers should
// leave it at its default.
void dfs1(int u, int f, int depth, int maxson = -1) {
    fa[u] = f;
    dep[u] = depth;
    siz[u] = 1;
    for (int edge = head[u]; edge != 0; edge = nxt[edge]) {
        const int child = to[edge];
        if (child == f) continue;  // do not walk back to the parent
        dfs1(child, u, depth + 1);
        siz[u] += siz[child];
        if (siz[child] > maxson) {
            maxson = siz[child];
            son[u] = child;
        }
    }
}
// Second DFS: assigns DFS positions so every heavy chain is contiguous.
// top[u] = head of u's chain, seg[u] = position of u, num[pos] = node at pos.
void dfs2(int u, int tp) {
    top[u] = tp;
    ++e;
    seg[u] = e;
    num[e] = u;
    if (son[u] == 0) return;   // leaf: no heavy child
    dfs2(son[u], tp);          // the heavy child continues the current chain
    for (int i = head[u]; i; i = nxt[i]) {
        const int v = to[i];
        if (v != son[u] && v != fa[u]) dfs2(v, v);  // each light child heads a new chain
    }
}
// Builds segment-tree node u over DFS positions [l, r]. A leaf takes the weight
// of the tree node stored at that position, reduced modulo p.
void build(int u, int l, int r) {
    if (l == r) {
        tr[u] = tree { l, r, 1, w[num[l]] % p, 0 };
        return;
    }
    const int half = (l + r) >> 1;
    tr[u] = tree { l, r, r - l + 1, 0, 0 };
    build(u << 1, l, half);
    build(u << 1 | 1, half + 1, r);
    tr[u].val = (tr[u << 1].val + tr[u << 1 | 1].val) % p;
}
void pushdown(int u) { tr[u << 1].val = (tr[u << 1].val + tr[u << 1].len * tr[u].laz) % p, tr[u << 1 | 1].val = (tr[u << 1 | 1].val + tr[u << 1 | 1].len * tr[u].laz) % p, tr[u << 1].laz += tr[u].laz, tr[u << 1 | 1].laz += tr[u].laz, tr[u].laz = 0; }
// Adds `val` to every DFS position in the query range [l, r] that lies inside node
// u's range. For a fully covered node, both the stored sum and the lazy tag are
// reduced modulo p. Otherwise a node covered again and again (e.g. the root, when op 3
// targets the tree root) accumulates val * len without bound and overflows long long.
void modify(int u, int l, int r, int val) {
    if (l <= tr[u].l && tr[u].r <= r) {
        const int add = val % p;
        tr[u].val = (tr[u].val + add * tr[u].len) % p;
        tr[u].laz = (tr[u].laz + add) % p;
        return;
    }
    if (tr[u].laz) pushdown(u);
    const int half = (tr[u].l + tr[u].r) >> 1;
    if (l <= half) modify(u << 1, l, r, val);
    if (r > half) modify(u << 1 | 1, l, r, val);
    tr[u].val = (tr[u << 1].val + tr[u << 1 | 1].val) % p;
}
// Adds the sum of DFS positions [l, r] inside node u's range into the global `res`
// (reduced modulo p), pushing lazy tags down on partially covered nodes.
void query(int u, int l, int r) {
    const bool covered = l <= tr[u].l && tr[u].r <= r;
    if (covered) {
        res = (res + tr[u].val) % p;
        return;
    }
    if (tr[u].laz != 0) pushdown(u);
    const int half = (tr[u].l + tr[u].r) >> 1;
    if (l <= half) query(u << 1, l, r);
    if (half < r) query(u << 1 | 1, l, r);
}
// Adds `val` to every node on the tree path x..y by climbing heavy chains.
void treeadd(int x, int y, int val) {
    for (; top[x] != top[y]; x = fa[top[x]]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);  // always lift the deeper chain head
        modify(1, seg[top[x]], seg[x], val);
    }
    // Same chain now: positions run from the shallower endpoint to the deeper one.
    const int shallow = dep[x] > dep[y] ? y : x;
    const int deep = dep[x] > dep[y] ? x : y;
    modify(1, seg[shallow], seg[deep], val);
}
// Returns the sum (mod p) of node weights on the tree path x..y.
int treesum(int x, int y) {
    int total = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);  // lift the deeper chain head
        res = 0;
        query(1, seg[top[x]], seg[x]);
        total = (total + res) % p;
        x = fa[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    res = 0;
    query(1, seg[x], seg[y]);
    return (total + res) % p;
}
int op, x, y, z, a, b;
signed main() {
scanf("%lld %lld %lld %lld", &n, &m, &r, &p);
for (int i = 1; i <= n; i++) scanf("%lld ", &w[i]);
for (int i = 1; i < n; i++) scanf("%lld %lld", &a, &b), insert(a, b), insert(b, a);
dfs1(r, 0, 1), dfs2(r, r), build(1, 1, n);
for (int i = 1; i <= m; i++) {
scanf("%lld", &op);
if (op == 1)
scanf("%lld %lld %lld", &x, &y, &z), res = 0, treeadd(x, y, z % p);
else if (op == 2)
scanf("%lld %lld", &x, &y), res = 0, printf("%lld\n", treesum(x, y) % p);
else if (op == 3)
scanf("%lld %lld", &x, &y), res = 0, modify(1, seg[x], seg[x] + siz[x] - 1, y % p);
else if (op == 4)
scanf("%lld", &x), res = 0, query(1, seg[x], seg[x] + siz[x] - 1), printf("%lld\n", res % p);
}
return 0;
}
以下为题解中的参考代码：
#include<iostream>
#include<cstdio>
#define int long long
using namespace std;
// Capacity for per-node arrays; the edge array and the segment-tree node pool use 2 * maxn.
const int maxn=4e5+10;
// Forward-star adjacency-list edge: `next` = previous edge index from the same tail, `to` = head vertex.
struct edge{
int next,to;
}e[2*maxn];
// Segment-tree node allocated from a pool:
//   sum    = range sum of [l, r], including this node's own lazy tag
//   lazy   = addition applied to the whole range [l, r]; it is never pushed to the children
//   ls, rs = pool indices of the children
struct Node{
int sum,lazy,l,r,ls,rs;
}node[2*maxn];
// rt: pool index of the segment-tree root.
// n, m, r, p: node count, operation count, root, modulus.
// a: initial weights.
// cnt: shared counter, used first for edges, then as the DFS clock, then as the pool allocator.
// f/d/size/son: parent, depth, subtree size, heavy child.
// rk: node at a DFS position; top: chain head; id: DFS position of a node.
// NOTE(review): a global named `size` together with `using namespace std;` can be an
// ambiguous reference against std::size (C++17 and later). Consider renaming it; every
// use in this program would need updating.
int rt,n,m,r,p,a[maxn],cnt,head[maxn],f[maxn],d[maxn],size[maxn],son[maxn],rk[maxn],top[maxn],id[maxn];
// Returns (a + b) reduced modulo the global modulus p.
int mod(int a,int b)
{
    const int s = a + b;
    return s % p;
}
// Appends the directed edge x -> y to the adjacency list; the new edge index is ++cnt.
void add_edge(int x,int y)
{
    const int idx = ++cnt;
    e[idx].to = y;
    e[idx].next = head[x];
    head[x] = idx;
}
// First DFS: records parent, depth and subtree size of u, and makes son[u] the
// child with the largest subtree. son[u] starts at 0 and node 0 is never visited,
// so size[0] stays 0 and the first child always replaces it.
void dfs1(int u,int fa,int depth)
{
    f[u] = fa;
    d[u] = depth;
    size[u] = 1;
    for (int i = head[u]; i != 0; i = e[i].next) {
        const int v = e[i].to;
        if (v == fa) continue;  // skip the parent
        dfs1(v, u, depth + 1);
        size[u] += size[v];
        if (size[son[u]] < size[v]) son[u] = v;
    }
}
// Second DFS: numbers nodes so that each heavy chain occupies consecutive positions.
// id[u] = position of u, rk[pos] = node at pos, top[u] = head of u's chain.
void dfs2(int u,int t)
{
    top[u] = t;
    cnt += 1;
    id[u] = cnt;
    rk[cnt] = u;
    if (son[u] == 0) return;  // leaf
    dfs2(son[u], t);          // heavy child stays on this chain
    for (int i = head[u]; i; i = e[i].next) {
        const int v = e[i].to;
        if (v == son[u] || v == f[u]) continue;
        dfs2(v, v);           // light child starts its own chain
    }
}
void pushup(int x)
{
node[x].sum=(node[node[x].ls].sum+node[node[x].rs].sum+node[x].lazy*(node[x].r-node[x].l+1))%p;
}
void build(int li,int ri,int cur)
{
if(li==ri)
{
node[cur].l=node[cur].r=li;
node[cur].sum=a[rk[li]];
return;
}
int mid=(li+ri)>>1;
node[cur].ls=cnt++;
node[cur].rs=cnt++;
build(li,mid,node[cur].ls);
build(mid+1,ri,node[cur].rs);
node[cur].l=node[node[cur].ls].l;
node[cur].r=node[node[cur].rs].r;
pushup(cur);
}
void update(int li,int ri,int c,int cur)
{
if(li<=node[cur].l&&node[cur].r<=ri)
{
node[cur].sum=mod(node[cur].sum,c*(node[cur].r-node[cur].l+1));
node[cur].lazy=mod(node[cur].lazy,c);
return;
}
int mid=(node[cur].l+node[cur].r)>>1;
if(li<=mid)
update(li,ri,c,node[cur].ls);
if(mid<ri)
update(li,ri,c,node[cur].rs);
pushup(cur);
}
int query(int li,int ri,int cur)
{
if(li<=node[cur].l&&node[cur].r<=ri)
return node[cur].sum;
int tot=node[cur].lazy*(min(node[cur].r,ri)-max(node[cur].l,li)+1)%p;
int mid=(node[cur].l+node[cur].r)>>1;
if(li<=mid)
tot=mod(tot,query(li,ri,node[cur].ls));
if(mid<ri)
tot=mod(tot,query(li,ri,node[cur].rs));
return tot%p;
}
int sum(int x,int y)
{
int ans=0;
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(d[fx]>=d[fy])
{
ans=mod(ans,query(id[fx],id[x],rt));
x=f[fx],fx=top[x];
}
else
{
ans=mod(ans,query(id[fy],id[y],rt));
y=f[fy],fy=top[y];
}
}
if(id[x]<=id[y])
ans=mod(ans,query(id[x],id[y],rt));
else
ans=mod(ans,query(id[y],id[x],rt));
return ans%p;
}
void updates(int x,int y,int c)
{
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(d[fx]>=d[fy])
{
update(id[fx],id[x],c,rt);
x=f[fx],fx=top[x];
}
else
{
update(id[fy],id[y],c,rt);
y=f[fy],fy=top[y];
}
}
if(id[x]<=id[y])
update(id[x],id[y],c,rt);
else
update(id[y],id[x],c,rt);
}
// Reads n, m, root r and modulus p, then the n initial weights (reduced mod p) and
// the n-1 edges. Builds the decomposition and the segment tree, then handles m operations:
//   1 x y z : add z on path x..y        2 x y : print path sum
//   3 x z   : add z in subtree of x     4 x   : print subtree sum
// Every scanf uses %lld because `int` is #defined to long long in this program.
// With %d, passing a long long* is undefined behaviour. For the uninitialised
// edge locals x and y it typically leaves their upper bytes as garbage.
signed main()
{
    scanf("%lld%lld%lld%lld",&n,&m,&r,&p);
    for(int i=1;i<=n;i++)
    {
        scanf("%lld",&a[i]);
        a[i]%=p;
    }
    for(int i=1;i<n;i++)
    {
        int x=0,y=0;
        scanf("%lld%lld",&x,&y);
        add_edge(x,y);
        add_edge(y,x);
    }
    cnt=0;      // edges are stored; cnt is reused as the DFS clock
    dfs1(r,0,1);
    dfs2(r,r);
    cnt=0;      // cnt is reused as the segment-tree pool allocator
    rt=cnt++;
    build(1,n,rt);
    for(int i=1;i<=m;i++)
    {
        int op,x,y,z;
        scanf("%lld",&op);
        if(op==1)
        {
            scanf("%lld%lld%lld",&x,&y,&z);
            updates(x,y,z);
        }
        else if(op==2)
        {
            scanf("%lld%lld",&x,&y);
            printf("%lld\n",sum(x,y));
        }
        else if(op==3)
        {
            scanf("%lld%lld",&x,&z);
            update(id[x],id[x]+size[x]-1,z,rt);
        }
        else if(op==4)
        {
            scanf("%lld",&x);
            printf("%lld\n",query(id[x],id[x]+size[x]-1,rt));
        }
    }
    return 0;
}
// 还有一份题解代码：
#include <bits/stdc++.h>
#include <cstdio>
using namespace std;
using ll = long long;
namespace IO {
// Reads one integer from stdin into x, skipping any non-digit characters before it.
// Any '-' seen while skipping makes the result negative. `c` is an int, not a char,
// so that EOF can be told apart from real characters. With a char, the skip loop
// never terminated at end of input. If input ends before any digit, x is set to 0.
template <class I>
inline void read(I& x) {
    x = 0;
    I f = 1;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    x *= f;
}
// Reads several integers, left to right.
template <typename T, typename... Args>
inline void read(T& tmp, Args&... tmps) {
    read(tmp);
    read(tmps...);
}
// Writes x to stdout in decimal, with no trailing separator.
template <class T>
inline void write(T x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (x >= 10) {
        write(x / 10);
    }
    putchar(x % 10 + '0');
}
}
using namespace IO;
#define int long long
#define mid ((l+r)>>1)
#define lson u<<1,l,mid
#define rson u<<1|1,mid+1,r
#define len (r-l+1)
const int maxn=200000+10;
int n,m,r,mod;
// As in the problem statement: node count, operation count, root, modulus.
int e,head[maxn],nxt[maxn],to[maxn],w[maxn],wt[maxn];
// Forward-star adjacency arrays (e = edge count). w[] = initial weight of each node
// by original index; wt[] = the same weights placed at each node's DFS index.
// NOTE(review): nxt/to hold maxn entries and each tree edge is stored twice, so this
// assumes n is at most about maxn / 2 — confirm against the problem constraints.
int tr[maxn<<2],laz[maxn<<2];
// Segment-tree sums and lazy tags.
int son[maxn],id[maxn],fa[maxn],cnt,dep[maxn],siz[maxn],top[maxn];
// son[] heavy child, id[] new (DFS) index, fa[] parent, cnt DFS clock / DFS order,
// dep[] depth, siz[] subtree size, top[] head of the chain containing the node.
// res: accumulator that query() adds into.
int res=0;
// Appends the directed edge u -> v (edge index ++e) to the adjacency list.
void add_edge(int u, int v) {
    const int k = ++e;
    to[k] = v;
    nxt[k] = head[u];
    head[u] = k;
}
// Pushes u's lazy tag to its children. `lenn` is the length of u's range. The left
// child covers ceil(lenn/2) positions and the right child floor(lenn/2), matching
// mid = (l+r)>>1. The tag is reduced modulo `mod` and the children's tags are stored
// reduced too. Without this, tags accumulated over many updates grow without bound
// and tag * length can overflow.
inline void pushdown(int u,int lenn){
    const int add=laz[u]%mod;
    laz[u<<1]=(laz[u<<1]+add)%mod;
    laz[u<<1|1]=(laz[u<<1|1]+add)%mod;
    tr[u<<1]=(tr[u<<1]+add*(lenn-(lenn>>1)))%mod;
    tr[u<<1|1]=(tr[u<<1|1]+add*(lenn>>1))%mod;
    laz[u]=0;
}
// Builds segment-tree node u over DFS positions [l, r] from wt[]. Leaves are
// normalised into [0, mod). A bare `> mod` check would leave a weight equal to mod,
// or a negative weight (IO::read accepts '-'), unreduced.
inline void build(int u,int l,int r){
    if(l==r){
        tr[u]=(wt[l]%mod+mod)%mod;
        return;
    }
    build(lson);
    build(rson);
    tr[u]=(tr[u<<1]+tr[u<<1|1])%mod;
}
// Adds the sum of positions [L, R] inside node u (covering [l, r]) into the global
// `res`, reduced modulo `mod`. Lazy tags are pushed down on partially covered nodes.
inline void query(int u,int l,int r,int L,int R){
    if(L<=l&&r<=R){
        res=(res+tr[u])%mod;
        return;
    }
    if(laz[u]!=0) pushdown(u,r-l+1);
    const int half=(l+r)>>1;
    if(L<=half) query(u<<1,l,half,L,R);
    if(half<R) query(u<<1|1,half+1,r,L,R);
}
// Adds k to every position in [L, R] inside node u (covering [l, r]). k, the stored
// sum and the lazy tag are all reduced modulo `mod`. A node covered repeatedly (e.g.
// the root, via subtree updates on the tree root, where updSon passes k unreduced)
// would otherwise accumulate k * length without bound and overflow.
inline void update(int u,int l,int r,int L,int R,int k){
    k%=mod;
    if(L<=l&&r<=R){
        laz[u]=(laz[u]+k)%mod;
        tr[u]=(tr[u]+k*len)%mod;
        return;
    }
    if(laz[u])pushdown(u,len);
    if(L<=mid)update(lson,L,R,k);
    if(R>mid)update(rson,L,R,k);
    tr[u]=(tr[u<<1]+tr[u<<1|1])%mod;
}
// Returns the sum (mod `mod`) of weights on the tree path x..y.
inline int qRange(int x,int y){
    int ans=0;
    // While x and y are on different chains, lift the one whose chain head is deeper.
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        res=0;
        query(1,1,n,id[top[x]],id[x]);   // the segment from x up to its chain head
        ans=(ans+res)%mod;
    }
    // Now on one chain: make x the shallower endpoint and add the remaining range.
    if(dep[x]>dep[y]) swap(x,y);
    res=0;
    query(1,1,n,id[x],id[y]);
    return (ans+res)%mod;
}
// Adds k (first reduced modulo `mod`) to every node on the tree path x..y.
inline void updRange(int x,int y,int k){
    k%=mod;
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);  // lift the deeper chain head
        update(1,1,n,id[top[x]],id[x],k);
    }
    const int shallow=(dep[x]>dep[y])?y:x;
    const int deep=(dep[x]>dep[y])?x:y;
    update(1,1,n,id[shallow],id[deep],k);
}
// Returns the weight sum of x's subtree, which occupies DFS positions
// id[x] .. id[x]+siz[x]-1.
inline int qSon(int x){
    const int first=id[x];
    const int last=first+siz[x]-1;
    res=0;
    query(1,1,n,first,last);
    return res;
}
// Adds k to every node in x's subtree (contiguous DFS positions id[x] .. id[x]+siz[x]-1).
inline void updSon(int x,int k){
    const int first=id[x];
    update(1,1,n,first,first+siz[x]-1,k);
}
// First DFS from x (parent f, depth deep). Records depth, parent and subtree size,
// and sets son[x] to the child with the largest subtree (the first one on ties).
inline void dfs1(int x,int f,int deep){
    dep[x]=deep;
    fa[x]=f;
    siz[x]=1;
    int best=-1;  // largest child subtree size seen so far
    for(int i=head[x];i!=0;i=nxt[i]){
        const int y=to[i];
        if(y==f) continue;  // skip the parent
        dfs1(y,x,deep+1);
        siz[x]+=siz[y];
        if(siz[y]>best){
            best=siz[y];
            son[x]=y;
        }
    }
}
// Second DFS: gives x the next DFS index and copies its weight to that index. It
// visits the heavy child first, so each heavy chain gets consecutive indices; every
// light child starts a new chain headed by itself.
inline void dfs2(int x,int topf){
    cnt+=1;
    id[x]=cnt;
    wt[cnt]=w[x];
    top[x]=topf;
    if(son[x]==0) return;  // leaf
    dfs2(son[x],topf);
    for(int i=head[x];i;i=nxt[i]){
        const int y=to[i];
        if(y!=fa[x]&&y!=son[x]) dfs2(y,y);
    }
}
// Reads n, m, root r and the modulus, then the node weights and the n-1 edges.
// Builds the decomposition, then processes m operations:
//   1 x y z : add z on path x..y        2 x y : print path sum
//   3 x y   : add y in subtree of x     4 x   : print subtree sum
// Results are printed with %lld: `int` is #defined to long long in this program,
// so printing them with %d was undefined behaviour.
signed main(){
    read(n);read(m);read(r);read(mod);
    for(int i=1;i<=n;i++)read(w[i]);
    for(int i=1;i<n;i++){
        int a,b;
        read(a);read(b);
        add_edge(a,b);add_edge(b,a);
    }
    dfs1(r,0,1);
    dfs2(r,r);
    build(1,1,n);
    while(m--){
        int k,x,y,z;
        read(k);
        if(k==1){
            read(x);read(y);read(z);
            updRange(x,y,z);
        }
        else if(k==2){
            read(x);read(y);
            printf("%lld\n",qRange(x,y));
        }
        else if(k==3){
            read(x);read(y);
            updSon(x,y);
        }
        else{
            read(x);
            printf("%lld\n",qSon(x));
        }
    }
    return 0;
}

浙公网安备 33010602011771号