树链剖分

将一颗树分为若干条链,再用线段树维护,修改,查询

例题:洛谷[树链剖分(模板)]

dep[x]:x点的深度

fa[x]:x点的父亲节点

size[x]:以x点为根节点的子树节点个数(包括本身)

id[x]:x点在线段树中的新编号

now[x]:照应在线段树中id[x]的原值

top[x]:x点所在重链的起始位置

son[x]:x点的重儿子

重链:起始位置为根节点或者是轻儿子的链,例外的,每个轻儿子单独成一条重链

重儿子,轻儿子:每个根节点都有重儿子和轻儿子,重儿子为子节点size[u]最多的节点,除了重儿子以外的儿子都是轻儿子

求size[],dep[],fa[],son[]很简单,dfs*1

inline void dfs(int x,int father,int deep)
{
	dep[x]=deep;
	fa[x]=father;
	size[x]=1;
	int fond=-1;
	for(re i=h[x];i;i=w[i].up)
	{
		re v=w[i].zd;
		if(v==father)
		{
			continue;
		}
		dfs(v,x,deep+1);
		size[x]+=size[v];
		if(fond<size[v])
		{
			son[x]=v;
			fond=size[v];
		}
	}
}
dfs(r,0,1)

求 top[],now[],id[]也很简单,dfs*2

inline void dfs1(int x,int topp)
{
	top[x]=topp;
	id[x]=++cnt;
	now[cnt]=vis[x];
	if(!son[x])
	{
		return ;
	}
	dfs1(son[x],topp);
	for(re i=h[x];i;i=w[i].up)
	{
		re v=w[i].zd;
		if(v==fa[x]||v==son[x])
		continue;
		dfs1(v,v);
	}
}

那问题就剩下怎么用了,等下啊,我去隔壁题解盗一张图,我画的灵魂,怕看不懂 

 从要处理的节点一次往上跳,更新线段树中的值,跳的过程有点像倍增LCA,先将查询的两个点跳到同一条链上,一路跳,一路更新,最后再将链的内容更新

inline void updan(int x,int y,int k)
{
	k%=mod;
	while(top[x]!=top[y])
	{
		if(dep[top[y]]>dep[top[x]])//选深度大的跳
		swap(x,y);
		ad(1,id[top[x]],id[x],k);//将当前x在的重链更新
		x=fa[top[x]];//重链的更新包括头节点,所以将x跳到头节点的父亲
	}
    //x,y已经在同一条链上了
	if(dep[x]>dep[y])
	swap(y,x);
	ad(1,id[x],id[y],k);
}

查询也是同理

inline int qge(int x,int y)
{
	int ans=0;
	while(top[x]!=top[y])
	{
		if(dep[top[y]]>dep[top[x]])
		swap(x,y);
		ans+=solve(1,id[top[x]],id[x]);
		ans%=mod;
		x=fa[top[x]];
	} 
	if(dep[x]>dep[y])
	swap(y,x);
	ans+=solve(1,id[x],id[y]);
	return ans%mod;
}

因为先遍历重节点,(等下啊,我去盗张图),如下图,蓝色的数字是dfs1的遍历顺序,可以发现非常优秀的一些性质

因为顺序是先重再轻,所以每一条重链的新编号是连续的

因为是dfs,所以每一个子树的新编号也是连续的

所以修改和查询子树的话非常的简单,线段树模板更新即可

AC代码:

#include<bits/stdc++.h>
using namespace std;
#define re register int
#define ll long long
inline void FRE()
{
	freopen(".in","r",stdin);
	freopen(".out","w",stdout);
}
inline void FCL()
{
	fclose(stdin);
	fclose(stdout);
}
priority_queue<int>mp;
vector<int>v;
const int N=1e5+5;
int mod=1e9+7;
const int inf=0x3fffffff;
inline ll read()
{
	ll s=0,f=1;
	char a=getchar();
	while(a<'0'||a>'9')
	{
		if(a=='-')
		f=-1;
		a=getchar();
	}
	while(a>='0'&&a<='9')
	{
		s=(s<<3)+(s<<1)+a-48;
		a=getchar();
	}
	return s*f;
}
inline void output(ll x)
{
	ll y=10,len=1;
	while(y<=x)
	{
		y*=10;
		len++;
	}
	while(len--)
	{
		y/=10;
		putchar(x/y+48);
		x%=y;
	}
	output('\n');
}
int n,mm,r,p,vis[N],h[N],tp;
struct fuck
{
	int up;
	int zd;
}w[N*2];
inline void add(int ne,int ed)
{
	w[++tp].up=h[ne];
	w[tp].zd=ed;
	h[ne]=tp;
}
int dep[N],size[N],son[N],top[N],fa[N],id[N],now[N],cnt;
struct fuck2
{
	int flag;
	int l;
	int r;
	int vis;
}m[4*N];
inline void build(int x,int y,int z)
{
	m[x].l=y;
	m[x].r=z;
	m[x].flag=0;
	if(y==z)
	{
		m[x].vis=now[z];
		m[x].vis%=mod;
		return;
	}
	re mid=(y+z)>>1;
	build((x<<1),y,mid);
	build((x<<1)|1,mid+1,z);
	m[x].vis=m[(x<<1)].vis+m[(x<<1)|1].vis;
	m[x].vis%=mod;
}
inline void bedown(int x)
{
	if(m[x].flag)
	{
		m[x<<1].flag+=m[x].flag;
		m[(x<<1)|1].flag+=m[x].flag;
		m[x<<1].vis+=m[x].flag*(m[x<<1].r-m[x<<1].l+1);
		m[(x<<1)|1].vis+=m[x].flag*(m[(x<<1)|1].r-m[(x<<1)|1].l+1);
		m[x].flag=0;
		m[(x<<1)].vis%=mod;
		m[(x<<1)|1].vis%=mod;
	}
}
inline void ad(int x,int y,int z,int k)
{
	if(m[x].l>=y&&m[x].r<=z)
	{
		m[x].flag+=k;
		m[x].flag%=mod;
		m[x].vis+=k*(m[x].r-m[x].l+1);
		m[x].vis%=mod;
		return; 
	}
	bedown(x);
	re mid=(m[x].l+m[x].r)>>1;
	if(y<=mid)
	ad(x<<1,y,z,k);
	if(z>mid)
	ad((x<<1)|1,y,z,k);
	m[x].vis=m[x<<1].vis+m[(x<<1)|1].vis;
	m[x].vis%=mod;
}
inline int solve(int x,int y,int z)
{
	if(m[x].l>=y&&m[x].r<=z)
	{
		return m[x].vis;
	}
	bedown(x);
	re mid=(m[x].l+m[x].r)>>1;
	int ans=0;
	if(y<=mid)
	{
		ans+=solve(x<<1,y,z);
		ans%=mod;
	}
	if(z>mid)
	{
		ans+=solve((x<<1)|1,y,z);
		ans%=mod;
	}
	return ans%mod;
}
inline void dfs(int x,int father,int deep)
{
	dep[x]=deep;
	fa[x]=father;
	size[x]=1;
	int fond=-1;
	for(re i=h[x];i;i=w[i].up)
	{
		re v=w[i].zd;
		if(v==father)
		{
			continue;
		}
		dfs(v,x,deep+1);
		size[x]+=size[v];
		if(fond<size[v])
		{
			son[x]=v;
			fond=size[v];
		}
	}
}
inline void dfs1(int x,int topp)
{
	top[x]=topp;
	id[x]=++cnt;
	now[cnt]=vis[x];
	if(!son[x])
	{
		return ;
	}
	dfs1(son[x],topp);
	for(re i=h[x];i;i=w[i].up)
	{
		re v=w[i].zd;
		if(v==fa[x]||v==son[x])
		continue;
		dfs1(v,v);
	}
}
inline void updan(int x,int y,int k)
{
	k%=mod;
	while(top[x]!=top[y])
	{
		if(dep[top[y]]>dep[top[x]])
		swap(x,y);
		ad(1,id[top[x]],id[x],k);
		x=fa[top[x]];
	} 
	if(dep[x]>dep[y])
	swap(y,x);
	ad(1,id[x],id[y],k);
}
inline int qge(int x,int y)
{
	int ans=0;
	while(top[x]!=top[y])
	{
		if(dep[top[y]]>dep[top[x]])
		swap(x,y);
		ans+=solve(1,id[top[x]],id[x]);
		ans%=mod;
		x=fa[top[x]];
	} 
	if(dep[x]>dep[y])
	swap(y,x);
	ans+=solve(1,id[x],id[y]);
	return ans%mod;
}
inline void work(int x,int y)
{
	ad(1,id[x],id[x]+size[x]-1,y);
}
inline int doit(int x)
{
	return solve(1,id[x],id[x]+size[x]-1);
}
int main()
{
	//FRE();
	n=read();
	mm=read();
	r=read();
	p=read();
	mod=p;
	for(re i=1;i<=n;i++)
	{
		vis[i]=read();
	}
	for(re i=1;i<=n-1;i++)
	{
		re x,y;
		x=read();
		y=read();
		add(x,y);
		add(y,x);
	}
	dfs(r,0,1);
	dfs1(r,r);
	build(1,1,n);
	
	while(mm--)
	{
		re x;
		x=read();
		if(x==1)
		{
			re y,z,k;
			y=read();
			z=read();
			k=read();
			updan(y,z,k);
		}
		else
		if(x==2)
		{
			re y,z;
			y=read();
			z=read();
			printf("%d\n",qge(y,z));
		}
		else
		if(x==3)
		{
			re y,z;
			y=read();
			z=read();
			work(y,z);
		}
		else
		{
			re y;
			y=read();
			printf("%d\n",doit(y));
		}
	}
	//FCL();
	return 0;
}

一篇更好的题解

posted @ 2018-10-02 20:04  Tarjan_fjz  阅读(137)  评论(0编辑  收藏  举报