[洛谷P3384]【模板】树链剖分

题目大意:已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

解题思路:树链剖分。

剖完进行dfs遍历,并记录每个节点的dfs序(优先遍历重链)。

可以发现任何一条重链的dfs序都是连续的,并且任何一棵子树中所有节点的dfs序也是连续的。

我们用线段树来维护每个dfs序对应的节点的信息。

对于操作1和2,让两个节点往链顶跳,每条一次在线段树中更新或查询链顶到原来节点的dfs序的信息。

对于操作3和4,由于一棵子树中所有节点dfs序有序,直接修改或查询即可。若一棵子树的根节点的dfs序是x,子树大小是sz,那子树最大的一个dfs序是x+sz-1。

由于树剖重链和轻链数量都是log级的,加上线段树时间复杂度,总时间复杂度$O(m\log^2 n)$。

C++ Code:

#include<cstdio>
#include<cctype>
#include<cstring>
#define ll long long
#define N 120500
int n,m,rt,p,a[N],head[N],cnt,sz[N],fa[N],dep[N],son[N],dfn[N],idx;
int aa[N],L,R,c,top[N];
ll ans;
struct SegmentTreeNode{
	ll sum,add;
}d[N<<2];
struct edge{
	int to,nxt;
}e[N<<1];
inline int readint(){
	char c=getchar();
	for(;!isdigit(c);c=getchar());
	int d=0;
	for(;isdigit(c);c=getchar())
	d=(d<<3)+(d<<1)+(c^'0');
	return d;
}
void dfs(int now){
	sz[now]=1;
	for(int i=head[now];i;i=e[i].nxt)
	if(!dep[e[i].to]){
		dep[e[i].to]=dep[now]+1;
		fa[e[i].to]=now;
		dfs(e[i].to);
		sz[now]+=sz[e[i].to];
		if(son[now]==0||sz[e[i].to]>sz[son[now]])son[now]=e[i].to;
	}
}
void dfs2(int now){
	dfn[now]=++idx;
	if(son[now])top[son[now]]=top[now],dfs2(son[now]);
	for(int i=head[now];i;i=e[i].nxt)
	if(dep[now]<dep[e[i].to]&&e[i].to!=son[now])
	dfs2(top[e[i].to]=e[i].to);
}
inline void update(int l,int o){
	int lft=o<<1;
	int rgt=lft|1;
	d[lft].add=(d[lft].add+d[o].add)%p;
	d[rgt].add=(d[rgt].add+d[o].add)%p;
	d[lft].sum=(d[lft].sum+d[o].add*((l+1)>>1))%p;
	d[rgt].sum=(d[rgt].sum+d[o].add*(l>>1))%p;
	d[o].add=0;
}
void build(int l,int r,int o){
	if(l==r){
		d[o]=(SegmentTreeNode){aa[l],0};
		return;
	}
	int mid=(l+r)>>1;
	build(l,mid,o<<1);
	build(mid+1,r,o<<1|1);
	d[o].sum=(d[o<<1].sum+d[o<<1|1].sum)%p;
	d[o].add=0;
}
void add_T(int l,int r,int o){
	if(L<=l&&r<=R){
		d[o].add=(d[o].add+c)%p;
		d[o].sum=(d[o].sum+c*(r-l+1))%p;
		return;
	}
	int mid=(l+r)>>1;
	update(r-l+1,o);
	if(L<=mid)add_T(l,mid,o<<1);
	if(mid<R)add_T(mid+1,r,o<<1|1);
	d[o].sum=(d[o<<1].sum+d[o<<1|1].sum)%p;
}
void que_T(int l,int r,int o){
	if(L<=l&&r<=R){
		ans=(ans+d[o].sum)%p;
		return;
	}
	int mid=(l+r)>>1;
	update(r-l+1,o);
	if(L<=mid)que_T(l,mid,o<<1);
	if(mid<R)que_T(mid+1,r,o<<1|1);
}
void add_1(int x,int y){
	for(;top[x]!=top[y];)
	if(dep[top[x]]>=dep[top[y]]){
		L=dfn[top[x]],R=dfn[x];
		add_T(1,n,1);
		x=fa[top[x]];
	}else{
		L=dfn[top[y]],R=dfn[y];
		add_T(1,n,1);
		y=fa[top[y]];
	}
	if(dep[x]<=dep[y]){
		L=dfn[x],R=dfn[y];
		add_T(1,n,1);
	}else{
		L=dfn[y],R=dfn[x];
		add_T(1,n,1);
	}
}
void que_1(int x,int y){
	for(;top[x]!=top[y];)
	if(dep[top[x]]>=dep[top[y]]){
		L=dfn[top[x]],R=dfn[x];
		que_T(1,n,1);
		x=fa[top[x]];
	}else{
		L=dfn[top[y]],R=dfn[y];
		que_T(1,n,1);
		y=fa[top[y]];
	}
	if(dep[x]<=dep[y]){
		L=dfn[x],R=dfn[y];
		que_T(1,n,1);
	}else{
		L=dfn[y],R=dfn[x];
		que_T(1,n,1);
	}
}
int main(){
	memset(dep,0,sizeof dep);
	memset(head,0,sizeof head);
	memset(son,0,sizeof son);
	cnt=idx=0;
	n=readint(),m=readint(),rt=readint(),p=readint();
	for(int i=1;i<=n;++i)a[i]=readint()%p;
	for(int i=1;i<n;++i){
		int u=readint(),v=readint();
		e[++cnt]=(edge){v,head[u]};
		head[u]=cnt;
		e[++cnt]=(edge){u,head[v]};
		head[v]=cnt;
	}
	dep[top[rt]=rt]=1;
	fa[rt]=rt;
	dfs(rt);
	dfs2(rt);
	for(int i=1;i<=n;++i)aa[dfn[i]]=a[i];
	build(1,n,1);
	while(m--){
		int x=readint(),l,r;
		switch(x){
			case 1:
				l=readint(),r=readint(),c=readint();
				add_1(l,r);
				break;
			case 2:
				ans=0;
				l=readint(),r=readint();
				que_1(l,r);
				printf("%d\n",(int)(ans%p));
				break;
			case 3:
				L=readint(),c=readint();
				R=dfn[L]+sz[L]-1;
				L=dfn[L];
				add_T(1,n,1);
				break;
			case 4:
				ans=0;
				L=readint();
				R=dfn[L]+sz[L]-1;
				L=dfn[L];
				que_T(1,n,1);
				printf("%d\n",(int)(ans%p));
				break;
		}
	}
	return 0;
}

 

posted @ 2017-11-03 15:11  Mrsrz  阅读(118)  评论(0编辑  收藏  举报