『模板』树链剖分

这只是一个模板

【模板】轻重链剖分/树链剖分

题目描述

如题,已知一棵包含 \(N\) 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

  • 1 x y z,表示将树从 \(x\)\(y\) 结点最短路径上所有节点的值都加上 \(z\)

  • 2 x y,表示求树从 \(x\)\(y\) 结点最短路径上所有节点的值之和。

  • 3 x z,表示将以 \(x\) 为根节点的子树内所有节点值都加上 \(z\)

  • 4 x 表示求以 \(x\) 为根节点的子树内所有节点值之和

输入格式

第一行包含 \(4\) 个正整数 \(N,M,R,P\),分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。

接下来一行包含 \(N\) 个非负整数,分别依次表示各个节点上初始的数值。

接下来 \(N-1\) 行每行包含两个整数 \(x,y\),表示点 \(x\) 和点 \(y\) 之间连有一条边(保证无环且连通)。

接下来 \(M\) 行每行包含若干个正整数,每行表示一个操作。

输出格式

输出包含若干行,分别依次表示每个操作 \(2\) 或操作 \(4\) 所得的结果(\(P\) 取模)。

样例 #1

样例输入 #1

5 5 2 24
7 3 7 8 0 
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3

样例输出 #1

2
21

提示

【数据规模】

对于 \(30\%\) 的数据: \(1 \leq N \leq 10\)\(1 \leq M \leq 10\)

对于 \(70\%\) 的数据: \(1 \leq N \leq {10}^3\)\(1 \leq M \leq {10}^3\)

对于 \(100\%\) 的数据: \(1\le N \leq {10}^5\)\(1\le M \leq {10}^5\)\(1\le R\le N\)\(1\le P \le 2^{31}-1\)

【样例说明】

树的结构如下:

各个操作如下:

故输出应依次为 \(2\)\(21\)

\(Code\)

#include<bits/stdc++.h>
#include<cmath>
#include<queue>
#include<cstdio>
#include<cstring>
#include<iostream>
#define gc getchar
#include<algorithm>
#define reg register
#define ll long long
#define ls k<<1
#define rs k<<1|1
#define int long long
using namespace std;
const int M=105;
const int N=1e5+5;
//const int mod=998244353;
const int INF = 0x3f3f3f3f;
inline void print(int x) {if (x < 0) putchar('-'), x = -x; if(x > 9) print(x / 10); putchar(x % 10 + '0');}
inline int read() { int res = 0, f = 0; char ch = gc();for (; !isdigit(ch); ch = gc()) f |= (ch == '-'); for (; isdigit(ch); ch = gc()) res = (res << 1) + (res << 3) + (ch ^ '0'); return f ? -res : res;}

int n,m,r,p,mod,num,cnt;
struct node{int to,next;}e[N<<1];
struct nodee{int lz,sum,len;}tree[N<<2];
int w[N],fa[N],siz[N],dfn[N],son[N],pre[N],top[N],head[N],deep[N];
inline void add(int u,int v){e[++cnt]=(node){v,head[u]};head[u]=cnt;}//链式前向星存图 

void dfs1(int now,int father)
{
	deep[now]=deep[father]+1;//记录深度 
	siz[now]=1;//当前子树的初始节点数为1 
	fa[now]=father;//记录父亲 
	for(reg int i=head[now];i;i=e[i].next)//遍历相连的点 
	{
		int v=e[i].to;
		if(v==father) continue;//是父节点就跳过 
		dfs1(v,now);//dfs当前节点的儿子 
		siz[now]+=siz[v];//加上儿子数 
		if(siz[v]>siz[son[now]]) son[now]=v;//后代数多的点为重儿子 
	}
 } 

void dfs2(int now,int topp)//now为当前节点,topp为最顶端的节点 
{
	dfn[now]=++num;//记录dfs序 
	top[now]=topp;//指向所在链的顶端 
	pre[num]=now;//记录dfs序所指向的节点 
	if(!son[now]) return ;//没有重儿子就返回 
	dfs2(son[now],topp);//否则处理重儿子 
	for(reg int i=head[now];i;i=e[i].next)//遍历 
	{
		int v=e[i].to;
		if(v==fa[now] || v==son[now]) continue;//不是父节点或重儿子 
		dfs2(v,v);//以轻儿子为端点dfs下去 
	}
}
 
void build(int k,int l,int r)//建树 
{
	tree[k].len=r-l+1;
	if(l==r)
	{
		tree[k].sum=w[pre[l]];
		tree[k].lz=0;
		return ;
	}
	int mid=(l+r)>>1;
	build(ls,l,mid);
	build(rs,mid+1,r);
	tree[k].sum=(tree[ls].sum+tree[rs].sum)%mod;
}

void push_down(int k)
{
	if(!tree[k].lz) return ;
	(tree[ls].lz+=tree[k].lz)%=mod;
	(tree[ls].sum+=tree[ls].len*tree[k].lz)%=mod;
	(tree[rs].lz+=tree[k].lz)%=mod;
	(tree[rs].sum+=tree[rs].len*tree[k].lz)%=mod;
	tree[k].lz=0;
}

int query(int k,int l,int r,int L,int R)
{
	int res=0;
	if(l>=L && r<=R) return tree[k].sum;
	push_down(k);
	int mid=(l+r)>>1;
	if(L<=mid)
	(res+=query(ls,l,mid,L,R))%=mod;
	if(R>mid)
	(res+=query(rs,mid+1,r,L,R))%=mod;
	return res;
}

void update(int k,int l,int r,int L,int R,int v)
{
	if(l>=L&&r<=R)
	{
		(tree[k].lz+=v)%=mod;
		(tree[k].sum+=v*tree[k].len)%=mod;
		return ;
	}
	push_down(k);
	int mid=(l+r)>>1;
	if(L<=mid) update(ls,l,mid,L,R,v);
	if(R>mid) update(rs,mid+1,r,L,R,v);
	tree[k].sum=(tree[ls].sum+tree[rs].sum)%mod;
}

int find(int x,int y)
{
	int ans=0;
	int top1=top[x],top2=top[y]; //取链顶 
	while(top1!=top2)//不在同一条链上	 
	{
		if(deep[top1]<deep[top2])//保证top1的深度更深 
		{
			swap(top1,top2);
			swap(x,y);
		}
		(ans+=query(1,1,n,dfn[top1],dfn[x]))%=mod;//求区间内所有节点值的和 
		x=fa[top1],top1=top[x];//往上继续搜 
	}
	if(deep[x]>deep[y]) swap(x,y);
	(ans+=query(1,1,n,dfn[x],dfn[y]))%=mod;//求区间和 
	return ans;//返回 
}

void change(int x,int y,int v)//同上 
{
	int top1=top[x],top2=top[y];
	while(top1!=top2)
	{
		if(deep[top1]<deep[top2])
		{
			swap(top1,top2);
			swap(x,y);
		}
		update(1,1,n,dfn[top1],dfn[x],v);
		x=fa[top1],top1=top[x];
	} 
	if(deep[x]>deep[y]) swap(x,y);
	update(1,1,n,dfn[x],dfn[y],v);
}

signed main()
{
	n=read(),m=read(),r=read(),mod=read();//输入节点个数,操作次数,根节点序号和模数 
	for(reg int i=1;i<=n;i++) w[i]=read();//各节点的初值 
	for(reg int i=1;i<n;i++)
	{
		int u,v;
		scanf("%d%d",&u,&v);//u,v之间连边 
		add(u,v);add(v,u);
	}
	dfs1(r,0);//得到重儿子编号,父节点,节点深度,子树大小 
	dfs2(r,r);//处理dfs序,每条链,他的顶端 
	build(1,1,n);//建树 
	for(reg int i=1;i<=m;i++)//m次操作 
	{
		int op=read();
		if(op==1)//从 X 到 Y 的最短路径上所有节点都加上x 
		{
			int x=read(),y=read(),z=read();
			change(x,y,z);
		}
		else if(op==2)//求 X 到 Y 的最短路径上所有节点的值的和 
		{
			int x,y;
			x=read(),y=read();
			printf("%lld\n",find(x,y));
		}
		else if(op==3)//以 x 为跟的所有子树的节点值都加上z 
		{
			int x,z;
			x=read(),z=read();
			update(1,1,n,dfn[x],dfn[x]+siz[x]-1,z);
		}
		if(op==4)//求以 x 为根节点的所有子树的值的和 
		{
			int x=read();
			printf("%lld\n",query(1,1,n,dfn[x],dfn[x]+siz[x]-1));
		}
	}
	return 0;
}
posted @ 2022-07-25 19:27  Always_maxx  阅读(44)  评论(1编辑  收藏  举报