树链剖分

题目

luogu3384

代码

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
#define N 1000005
#define ll long long
using namespace std;

int n,m,root,mod;
int w[N];

struct node
{
	int l,r;
	ll sum,d;
}T[N];

int num,b[N*10],nt[N*10],P[N*10];
void addedge(int x,int y)
{
	b[++num]=y;nt[num]=P[x];P[x]=num;
}

int deep[N],sz[N],son[N],fa[N];bool flag[N];
void dfs1(int k)
{
	flag[k]=1;sz[k]=1;
	for(int e=P[k];e;e=nt[e])
	{
		int kk=b[e];
		if(flag[kk]) continue;//除去父节点
		fa[kk]=k;
		deep[kk]=deep[k]+1;
		dfs1(kk); 
		if(sz[kk]>sz[son[k]]) son[k]=kk;//更新重儿子 
		sz[k]+=sz[kk];
	}
}

int t,top[N],id[N],dfn[N]; 
void dfs2(int k,int tp)
{
	top[k]=tp;dfn[++t]=k;id[k]=t;
	if(!son[k]) return;//若没有儿子,返回
	dfs2(son[k],tp);//先递归重儿子
	for(int e=P[k];e;e=nt[e])
	{
		int kk=b[e];
		if(kk!=son[k]&&kk!=fa[k]) dfs2(kk,kk);
	}
}

void pushup(int p)
{
    T[p].sum=(T[p<<1].sum+T[p<<1|1].sum)%mod;
}

void build(int p,int x,int y)
{
	T[p].l=x;T[p].r=y;
	if(x==y) {T[p].sum=w[dfn[x]];return;}
	int mid=(x+y)>>1;
	build(p<<1,x,mid);
	build(p<<1|1,mid+1,y);
	pushup(p); 
}

void pushdown(int p)
{
	T[p<<1].sum=(T[p<<1].sum+(T[p<<1].r-T[p<<1].l+1)*T[p].d%mod)%mod;
	T[p<<1|1].sum=(T[p<<1|1].sum+(T[p<<1|1].r-T[p<<1|1].l+1)*T[p].d%mod)%mod;
	T[p<<1].d=(T[p<<1].d+T[p].d)%mod;T[p<<1|1].d=(T[p<<1|1].d+T[p].d)%mod;
	T[p].d=0;
}

void update(int p,int x,int y,int v)
{
	int pl=T[p].l,pr=T[p].r;
	if(pl==x&&pr==y)
	{
		T[p].sum=(T[p].sum+(pr-pl+1)*v%mod)%mod;
		T[p].d=(T[p].d+v)%mod;
		return;
	}
	pushdown(p);
	int mid=(pl+pr)>>1;
	if(y<=mid) update(p<<1,x,y,v);
    else if(x>mid) update(p<<1|1,x,y,v);
    else
    {
        update(p<<1,x,mid,v);
        update(p<<1|1,mid+1,y,v);
    }
    pushup(p);
}

ll add(int x,int y,int v)
{
	while(top[x]!=top[y])//先将x,y翻到一条链上 
	{
		if(deep[top[x]]<deep[top[y]]) swap(x,y);
		update(1,id[top[x]],id[x],v);//更新x与其top之间的点 
		x=fa[top[x]];//并将x上翻 
	} 
	if(deep[x]>deep[y]) swap(x,y);
	update(1,id[x],id[y],v);
}

ll query(int p,int x,int y)
{
	int pl=T[p].l,pr=T[p].r;
	if(pl==x&&pr==y) return T[p].sum;
	pushdown(p);
    int mid=(pl+pr)>>1;
    if(y<=mid) return query(p<<1,x,y);
    else if(x>mid) return query(p<<1|1,x,y);
    else return (query(p<<1,x,mid)+query(p<<1|1,mid+1,y))%mod;
}

ll query_(int x,int y)
{
	ll cnt=0;
	while(top[x]!=top[y])
	{
		if(deep[top[x]]<deep[top[y]]) swap(x,y);
		cnt=(cnt+query(1,id[top[x]],id[x]))%mod;
		x=fa[top[x]];
	}
	if(deep[x]>deep[y]) swap(x,y);
	cnt=(cnt+query(1,id[x],id[y]))%mod;
	return cnt;
}

int main()
{
	scanf("%d%d%d%d",&n,&m,&root,&mod);
	for(int i=1;i<=n;i++) scanf("%d",&w[i]);
	for(int i=1;i<n;i++)
	{
		int x,y;scanf("%d%d",&x,&y);
		addedge(x,y);addedge(y,x);//先将树建起来 
	}
	dfs1(root);//第一次dfs得到deep,size,son和fa 
	dfs2(root,root);//第二次dfs得到top,id
	build(1,1,n);//建线段树 
	for(int i=1;i<=m;i++)
	{
		int opt,x,y,z;scanf("%d",&opt);
		if(opt==1)
		{
			scanf("%d%d%d",&x,&y,&z);
			add(x,y,z);
		}
		else if(opt==2)
		{
			scanf("%d%d",&x,&y);
			printf("%lld\n",query_(x,y));
		}
		else if(opt==3)
		{
			scanf("%d%d",&x,&z);
			update(1,id[x],id[x]+sz[x]-1,z);
		}
		else
		{
			scanf("%d",&x);
			printf("%lld\n",query(1,id[x],id[x]+sz[x]-1));
		}
	}
	return 0;
}
posted @ 2017-08-29 10:55  XYZinc  阅读(142)  评论(0)    收藏  举报