树链剖分就是把树拆成一系列链,然后用数据结构对链进行维护。

树链剖分主要变量:

dep[x]表示x节点的深度,size[x]表示以x为根节点的树的大小,son[x]表示x的重儿子(重儿子即x的所有儿子中size最大的儿子),

fa[x]表示x的父亲,top[x]表示x所属重链的头部。

首先,dep,size,son,fa可以简单用一个dfs解决

void dfs(int x)
{
	siz[x]=1;son[x]=0;siz[0]=0;
	for(int j=last[x];j;j=e[j].next)
	{
		int y=e[j].to;
		if(y!=fa[x])
		{
			fa[y]=x;
			dep[y]=dep[x]+1;
			dfs(y);
			if(siz[y]>siz[son[x]])son[x]=y;
			siz[x]+=siz[y];
			
		}
	}
}

 对于top,如果x为fa[x]的重儿子,那么top[x]=top[fa[x]],否则top[x]=x

void dfs_tree(int x,int tp)
{
	w[x]=++z;top[x]=tp;//w[x]为x节点对应的线段树中的叶节点
	if(son[x]!=0)dfs_tree(son[x],tp);else return;
	for(int j=last[x];j;j=e[j].next)
	{
		int y=e[j].to;
		if(y!=son[x]&&y!=fa[x])dfs_tree(y,y);
	}
}

然后我们可以借助一些数据结构维护这些链,一般用线段树

显然一条重链的点,它们的w会构成一段区间[l,r]

所以,直接添加元素

	for(int i=1;i<=n;i++)change(1,w[i],a[i]);//change为普通线段树更改操作

接下来,求值操作,求x到y的树上路径中的最大值

int solvemx(int x,int y)
{
	int mx=-1e9;
	while(top[x]!=top[y])//让它们不停地沿着重链向上爬。
	{
		if(dep[top[x]]<dep[top[y]])swap(x,y);
		mx=max(mx,querymx(1,w[top[x]],w[x]));//查找x所属重链的max
		x=fa[top[x]];
 	}
		if(w[x]>w[y])swap(x,y);
		mx=max(mx,querymx(1,w[x],w[y]));
	return mx;
}

模板题

#include<bits/stdc++.h> 
#define maxn 300005
using namespace std;
int siz[maxn],dep[maxn],top[maxn],fa[maxn],son[maxn],a[maxn];
int w[maxn],n,m,x,y,last[maxn],cnt,z;
struct edge{
	int to,next;
}e[maxn];
struct tree{
	int sum,mx,l,r;
}tr[maxn];
void insert(int x,int y){
	e[++cnt].to=y;e[cnt].next=last[x];last[x]=cnt;
}
void dfs(int x)
{
	siz[x]=1;son[x]=0;siz[0]=0;
	for(int j=last[x];j;j=e[j].next)
	{
		int y=e[j].to;
		if(y!=fa[x])
		{
			fa[y]=x;
			dep[y]=dep[x]+1;
			dfs(y);
			if(siz[y]>siz[son[x]])son[x]=y;
			siz[x]+=siz[y];
			
		}
	}
}
void dfs_tree(int x,int tp)
{
	w[x]=++z;top[x]=tp;
	if(son[x]!=0)dfs_tree(son[x],tp);else return;
	for(int j=last[x];j;j=e[j].next)
	{
		int y=e[j].to;
		if(y!=son[x]&&y!=fa[x])dfs_tree(y,y);
	}
}
void build(int x,int l,int r)
{
	tr[x].l=l;tr[x].r=r;
	if(l==r)return;
	int mid=(l+r)>>1;
	build(x*2,l,mid);
	build(x*2+1,mid+1,r);
}
void change(int now,int x,int y)
{
	int l=tr[now].l,r=tr[now].r,mid=(l+r)>>1;
	if(l==r){
		tr[now].mx=tr[now].sum=y;
		return;
	}
	if(x<=mid)change(now*2,x,y);else change(now*2+1,x,y);
	tr[now].mx=max(tr[now*2].mx,tr[now*2+1].mx);
	tr[now].sum=tr[now*2].sum+tr[now*2+1].sum;
}
int querymx(int now,int x,int y)
{
	int l=tr[now].l,r=tr[now].r,mid=(l+r)>>1;
	if(x<=l&&y>=r)return tr[now].mx;
	if(x>r||y<l)return -1e9;
	return max(querymx(now*2,x,y),querymx(now*2+1,x,y));
}
int querysum(int now,int x,int y)
{
	int l=tr[now].l,r=tr[now].r,mid=(l+r)>>1;
	if(x<=l&&y>=r)return tr[now].sum;
	if(x>r||y<l)return 0;
	return querysum(now*2,x,y)+querysum(now*2+1,x,y);
}
int solvemx(int x,int y)
{
	int mx=-1e9;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])swap(x,y);
		mx=max(mx,querymx(1,w[top[x]],w[x]));
		x=fa[top[x]];
 	}
		if(w[x]>w[y])swap(x,y);
		mx=max(mx,querymx(1,w[x],w[y]));
	return mx;
}
int solvesum(int x,int y)
{
	int sum=0,bo=0;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])swap(x,y);
		sum+=querysum(1,w[top[x]],w[x]);
		if(top[x]==1)bo=1;
		x=fa[top[x]];
	}
		if(w[x]>w[y])swap(x,y);
		if(!bo)sum+=querysum(1,w[x],w[y]);//注意这一判断
	return sum;
}
void solve()
{
	char c[10];int x,y;
	scanf("%d",&m);
	for(int i=1;i<=m;i++)
	{
		scanf("%s%d%d",&c,&x,&y);
		if(c[0]=='C')a[x]=y,change(1,w[x],y);
		else
		{
			if(c[1]=='M')printf("%d\n",solvemx(x,y));
			else printf("%d\n",solvesum(x,y));
		}
	}
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<n;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		insert(x,y);insert(y,x);
	}
	for(int i=1;i<=n;i++)scanf("%d",&a[i]);
	fa[1]=1;
	dfs(1);
	dfs_tree(1,1);
	build(1,1,n);
	for(int i=1;i<=n;i++)change(1,w[i],a[i]);
//	for(int i=1;i<=n;i++)printf("%d %d %d\n",w[i],top[i],fa[i]);
	solve();
	//printf("%d %d\n",querysum(1,1,3),querysum(1,4,4));
	return 0;
}

 

  

 

ac代码

 

posted on 2018-03-09 10:53  geniuschenjj  阅读(114)  评论(1编辑  收藏  举报