『模板』树链剖分
这只是一个模板
【模板】轻重链剖分/树链剖分
题目描述
如题,已知一棵包含 \(N\) 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
-
1 x y z
,表示将树从 \(x\) 到 \(y\) 结点最短路径上所有节点的值都加上 \(z\)。 -
2 x y
,表示求树从 \(x\) 到 \(y\) 结点最短路径上所有节点的值之和。 -
3 x z
,表示将以 \(x\) 为根节点的子树内所有节点值都加上 \(z\)。 -
4 x
表示求以 \(x\) 为根节点的子树内所有节点值之和
输入格式
第一行包含 \(4\) 个正整数 \(N,M,R,P\),分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。
接下来一行包含 \(N\) 个非负整数,分别依次表示各个节点上初始的数值。
接下来 \(N-1\) 行每行包含两个整数 \(x,y\),表示点 \(x\) 和点 \(y\) 之间连有一条边(保证无环且连通)。
接下来 \(M\) 行每行包含若干个正整数,每行表示一个操作。
输出格式
输出包含若干行,分别依次表示每个操作 \(2\) 或操作 \(4\) 所得的结果(对 \(P\) 取模)。
样例 #1
样例输入 #1
5 5 2 24
7 3 7 8 0
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3
样例输出 #1
2
21
提示
【数据规模】
对于 \(30\%\) 的数据: \(1 \leq N \leq 10\),\(1 \leq M \leq 10\);
对于 \(70\%\) 的数据: \(1 \leq N \leq {10}^3\),\(1 \leq M \leq {10}^3\);
对于 \(100\%\) 的数据: \(1\le N \leq {10}^5\),\(1\le M \leq {10}^5\),\(1\le R\le N\),\(1\le P \le 2^{31}-1\)。
【样例说明】
树的结构如下:
各个操作如下:
故输出应依次为 \(2\) 和 \(21\)。
\(Code\)
#include<bits/stdc++.h>
#include<cmath>
#include<queue>
#include<cstdio>
#include<cstring>
#include<iostream>
#define gc getchar
#include<algorithm>
#define reg register
#define ll long long
#define ls k<<1
#define rs k<<1|1
#define int long long
using namespace std;
const int M=105;
const int N=1e5+5;
//const int mod=998244353;
const int INF = 0x3f3f3f3f;
inline void print(int x) {if (x < 0) putchar('-'), x = -x; if(x > 9) print(x / 10); putchar(x % 10 + '0');}
inline int read() { int res = 0, f = 0; char ch = gc();for (; !isdigit(ch); ch = gc()) f |= (ch == '-'); for (; isdigit(ch); ch = gc()) res = (res << 1) + (res << 3) + (ch ^ '0'); return f ? -res : res;}
int n,m,r,p,mod,num,cnt;
struct node{int to,next;}e[N<<1];
struct nodee{int lz,sum,len;}tree[N<<2];
int w[N],fa[N],siz[N],dfn[N],son[N],pre[N],top[N],head[N],deep[N];
inline void add(int u,int v){e[++cnt]=(node){v,head[u]};head[u]=cnt;}//链式前向星存图
void dfs1(int now,int father)
{
deep[now]=deep[father]+1;//记录深度
siz[now]=1;//当前子树的初始节点数为1
fa[now]=father;//记录父亲
for(reg int i=head[now];i;i=e[i].next)//遍历相连的点
{
int v=e[i].to;
if(v==father) continue;//是父节点就跳过
dfs1(v,now);//dfs当前节点的儿子
siz[now]+=siz[v];//加上儿子数
if(siz[v]>siz[son[now]]) son[now]=v;//后代数多的点为重儿子
}
}
void dfs2(int now,int topp)//now为当前节点,topp为最顶端的节点
{
dfn[now]=++num;//记录dfs序
top[now]=topp;//指向所在链的顶端
pre[num]=now;//记录dfs序所指向的节点
if(!son[now]) return ;//没有重儿子就返回
dfs2(son[now],topp);//否则处理重儿子
for(reg int i=head[now];i;i=e[i].next)//遍历
{
int v=e[i].to;
if(v==fa[now] || v==son[now]) continue;//不是父节点或重儿子
dfs2(v,v);//以轻儿子为端点dfs下去
}
}
void build(int k,int l,int r)//建树
{
tree[k].len=r-l+1;
if(l==r)
{
tree[k].sum=w[pre[l]];
tree[k].lz=0;
return ;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
tree[k].sum=(tree[ls].sum+tree[rs].sum)%mod;
}
void push_down(int k)
{
if(!tree[k].lz) return ;
(tree[ls].lz+=tree[k].lz)%=mod;
(tree[ls].sum+=tree[ls].len*tree[k].lz)%=mod;
(tree[rs].lz+=tree[k].lz)%=mod;
(tree[rs].sum+=tree[rs].len*tree[k].lz)%=mod;
tree[k].lz=0;
}
int query(int k,int l,int r,int L,int R)
{
int res=0;
if(l>=L && r<=R) return tree[k].sum;
push_down(k);
int mid=(l+r)>>1;
if(L<=mid)
(res+=query(ls,l,mid,L,R))%=mod;
if(R>mid)
(res+=query(rs,mid+1,r,L,R))%=mod;
return res;
}
void update(int k,int l,int r,int L,int R,int v)
{
if(l>=L&&r<=R)
{
(tree[k].lz+=v)%=mod;
(tree[k].sum+=v*tree[k].len)%=mod;
return ;
}
push_down(k);
int mid=(l+r)>>1;
if(L<=mid) update(ls,l,mid,L,R,v);
if(R>mid) update(rs,mid+1,r,L,R,v);
tree[k].sum=(tree[ls].sum+tree[rs].sum)%mod;
}
int find(int x,int y)
{
int ans=0;
int top1=top[x],top2=top[y]; //取链顶
while(top1!=top2)//不在同一条链上
{
if(deep[top1]<deep[top2])//保证top1的深度更深
{
swap(top1,top2);
swap(x,y);
}
(ans+=query(1,1,n,dfn[top1],dfn[x]))%=mod;//求区间内所有节点值的和
x=fa[top1],top1=top[x];//往上继续搜
}
if(deep[x]>deep[y]) swap(x,y);
(ans+=query(1,1,n,dfn[x],dfn[y]))%=mod;//求区间和
return ans;//返回
}
void change(int x,int y,int v)//同上
{
int top1=top[x],top2=top[y];
while(top1!=top2)
{
if(deep[top1]<deep[top2])
{
swap(top1,top2);
swap(x,y);
}
update(1,1,n,dfn[top1],dfn[x],v);
x=fa[top1],top1=top[x];
}
if(deep[x]>deep[y]) swap(x,y);
update(1,1,n,dfn[x],dfn[y],v);
}
signed main()
{
n=read(),m=read(),r=read(),mod=read();//输入节点个数,操作次数,根节点序号和模数
for(reg int i=1;i<=n;i++) w[i]=read();//各节点的初值
for(reg int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);//u,v之间连边
add(u,v);add(v,u);
}
dfs1(r,0);//得到重儿子编号,父节点,节点深度,子树大小
dfs2(r,r);//处理dfs序,每条链,他的顶端
build(1,1,n);//建树
for(reg int i=1;i<=m;i++)//m次操作
{
int op=read();
if(op==1)//从 X 到 Y 的最短路径上所有节点都加上x
{
int x=read(),y=read(),z=read();
change(x,y,z);
}
else if(op==2)//求 X 到 Y 的最短路径上所有节点的值的和
{
int x,y;
x=read(),y=read();
printf("%lld\n",find(x,y));
}
else if(op==3)//以 x 为跟的所有子树的节点值都加上z
{
int x,z;
x=read(),z=read();
update(1,1,n,dfn[x],dfn[x]+siz[x]-1,z);
}
if(op==4)//求以 x 为根节点的所有子树的值的和
{
int x=read();
printf("%lld\n",query(1,1,n,dfn[x],dfn[x]+siz[x]-1));
}
}
return 0;
}