树链剖分就是把树拆成一系列链,然后用数据结构对链进行维护。
树链剖分主要变量:
dep[x]表示x节点的深度,size[x]表示以x为根节点的树的大小,son[x]表示x的重儿子(重儿子即x的所有儿子中size最大的儿子),
fa[x]表示x的父亲,top[x]表示x所属重链的头部。
首先,dep,size,son,fa可以简单用一个dfs解决
void dfs(int x)
{
siz[x]=1;son[x]=0;siz[0]=0;
for(int j=last[x];j;j=e[j].next)
{
int y=e[j].to;
if(y!=fa[x])
{
fa[y]=x;
dep[y]=dep[x]+1;
dfs(y);
if(siz[y]>siz[son[x]])son[x]=y;
siz[x]+=siz[y];
}
}
}
对于top,如果x为fa[x]的重儿子,那么top[x]=top[fa[x]],否则top[x]=x
void dfs_tree(int x,int tp)
{
w[x]=++z;top[x]=tp;//w[x]为x节点对应的线段树中的叶节点
if(son[x]!=0)dfs_tree(son[x],tp);else return;
for(int j=last[x];j;j=e[j].next)
{
int y=e[j].to;
if(y!=son[x]&&y!=fa[x])dfs_tree(y,y);
}
}
然后我们可以借助一些数据结构维护这些链,一般用线段树
显然一条重链的点,它们的w会构成一段区间[l,r]
所以,直接添加元素
for(int i=1;i<=n;i++)change(1,w[i],a[i]);//change为普通线段树更改操作
接下来,求值操作,求x到y的树上路径中的最大值
int solvemx(int x,int y)
{
int mx=-1e9;
while(top[x]!=top[y])//让它们不停地沿着重链向上爬。
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
mx=max(mx,querymx(1,w[top[x]],w[x]));//查找x所属重链的max
x=fa[top[x]];
}
if(w[x]>w[y])swap(x,y);
mx=max(mx,querymx(1,w[x],w[y]));
return mx;
}
#include<bits/stdc++.h>
#define maxn 300005
using namespace std;
int siz[maxn],dep[maxn],top[maxn],fa[maxn],son[maxn],a[maxn];
int w[maxn],n,m,x,y,last[maxn],cnt,z;
struct edge{
int to,next;
}e[maxn];
struct tree{
int sum,mx,l,r;
}tr[maxn];
void insert(int x,int y){
e[++cnt].to=y;e[cnt].next=last[x];last[x]=cnt;
}
void dfs(int x)
{
siz[x]=1;son[x]=0;siz[0]=0;
for(int j=last[x];j;j=e[j].next)
{
int y=e[j].to;
if(y!=fa[x])
{
fa[y]=x;
dep[y]=dep[x]+1;
dfs(y);
if(siz[y]>siz[son[x]])son[x]=y;
siz[x]+=siz[y];
}
}
}
void dfs_tree(int x,int tp)
{
w[x]=++z;top[x]=tp;
if(son[x]!=0)dfs_tree(son[x],tp);else return;
for(int j=last[x];j;j=e[j].next)
{
int y=e[j].to;
if(y!=son[x]&&y!=fa[x])dfs_tree(y,y);
}
}
void build(int x,int l,int r)
{
tr[x].l=l;tr[x].r=r;
if(l==r)return;
int mid=(l+r)>>1;
build(x*2,l,mid);
build(x*2+1,mid+1,r);
}
void change(int now,int x,int y)
{
int l=tr[now].l,r=tr[now].r,mid=(l+r)>>1;
if(l==r){
tr[now].mx=tr[now].sum=y;
return;
}
if(x<=mid)change(now*2,x,y);else change(now*2+1,x,y);
tr[now].mx=max(tr[now*2].mx,tr[now*2+1].mx);
tr[now].sum=tr[now*2].sum+tr[now*2+1].sum;
}
int querymx(int now,int x,int y)
{
int l=tr[now].l,r=tr[now].r,mid=(l+r)>>1;
if(x<=l&&y>=r)return tr[now].mx;
if(x>r||y<l)return -1e9;
return max(querymx(now*2,x,y),querymx(now*2+1,x,y));
}
int querysum(int now,int x,int y)
{
int l=tr[now].l,r=tr[now].r,mid=(l+r)>>1;
if(x<=l&&y>=r)return tr[now].sum;
if(x>r||y<l)return 0;
return querysum(now*2,x,y)+querysum(now*2+1,x,y);
}
int solvemx(int x,int y)
{
int mx=-1e9;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
mx=max(mx,querymx(1,w[top[x]],w[x]));
x=fa[top[x]];
}
if(w[x]>w[y])swap(x,y);
mx=max(mx,querymx(1,w[x],w[y]));
return mx;
}
int solvesum(int x,int y)
{
int sum=0,bo=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
sum+=querysum(1,w[top[x]],w[x]);
if(top[x]==1)bo=1;
x=fa[top[x]];
}
if(w[x]>w[y])swap(x,y);
if(!bo)sum+=querysum(1,w[x],w[y]);//注意这一判断
return sum;
}
void solve()
{
char c[10];int x,y;
scanf("%d",&m);
for(int i=1;i<=m;i++)
{
scanf("%s%d%d",&c,&x,&y);
if(c[0]=='C')a[x]=y,change(1,w[x],y);
else
{
if(c[1]=='M')printf("%d\n",solvemx(x,y));
else printf("%d\n",solvesum(x,y));
}
}
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
insert(x,y);insert(y,x);
}
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
fa[1]=1;
dfs(1);
dfs_tree(1,1);
build(1,1,n);
for(int i=1;i<=n;i++)change(1,w[i],a[i]);
// for(int i=1;i<=n;i++)printf("%d %d %d\n",w[i],top[i],fa[i]);
solve();
//printf("%d %d\n",querysum(1,1,3),querysum(1,4,4));
return 0;
}
ac代码
浙公网安备 33010602011771号