树链剖分
bzoj1036
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作:
I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
第一遍dfs求出树每个结点的深度dep[id],其为根的子树大小size[id]
以及祖先的信息fa[id]表示id的祖先
第二遍dfs
根节点为起点,向下拓展构建重链,选择最大的一个子树的根继承当前重链,其余节点,都以该节点为起点向下重新拉一条重链。
给每个结点分配一个位置编号,每条重链就相当于一段区间,用数据结构去维护。//pos[id]
把所有的重链首尾相接,放到同一个数据结构上,然后维护这一个整体即可。
修改操作:
1、单独修改一个点的权值
根据其编号直接在数据结构中修改就行了。
2、修改点u和点v的路径上的权值
(1)若u和v在同一条重链上
直接用数据结构修改pos[u]至pos[v]间的值。
(2)若u和v不在同一条重链上
一边进行修改,一边将u和v往同一条重链上靠,然后就变成了情况(1)。
查询操作
1.查询操作的分析过程同修改操作
#include <bits/stdc++.h>
#define N 30005
#define M 60005
using namespace std;
int n,q,cnt,sz;//sz 线段树_number
int v[N],dep[N],size[N],head[N],fa[N];
int pos[N],bl[N];
//pos[i] 线段树_number
//bl[i] top[i]
struct node{
int to,nxt;
}e[M];
void insert(int u,int v){
e[++cnt].to=v;e[cnt].nxt=head[u];head[u]=cnt;
}
struct Tnd{
int l,r,mx,sum;
}T[N<<2];
void dfs1(int id){
size[id]=1;
for(int i=head[id];i;i=e[i].nxt){
if(e[i].to == fa[id]) continue;
dep[e[i].to] = dep[id] + 1;
fa[e[i].to] = id;
dfs1(e[i].to);
size[id] += size[e[i].to];
}
}
void dfs2(int id,int chain){
int k=0;sz++;
pos[id] = sz;
bl[id] = chain;
for(int i=head[id];i;i=e[i].nxt)
if(dep[e[i].to] > dep[id] && size[e[i].to] > size[k])
k=e[i].to;
if(k==0)
return;
dfs2(k,chain);
for(int i=head[id];i;i=e[i].nxt)
if(dep[e[i].to] > dep[id] && e[i].to!=k)
dfs2(e[i].to,e[i].to);
}
void build(int tp,int l,int r){
T[tp].l=l;T[tp].r=r;
if(l==r) return;
int mid = (l+r)>>1;
build(tp<<1,l,mid);
build(tp<<1|1,mid+1,r);
}
void change(int tp,int p,int delta){
int l=T[tp].l,r=T[tp].r,mid=(l+r)>>1;
if(l==r){
T[tp].sum = T[tp].mx = delta;
return;
}
if(p<=mid)
change(tp<<1,p,delta);
else
change(tp<<1|1,p,delta);
T[tp].sum = T[tp<<1].sum + T[tp<<1|1].sum;
T[tp].mx = max(T[tp<<1].mx,T[tp<<1|1].mx);
}
int querymx(int tp,int x,int y){
int l = T[tp].l,r = T[tp].r,mid=(l+r)>>1;
if(l==x && r==y)
return T[tp].mx;
else if(y<=mid) return querymx(tp<<1,x,y);
else if(x>mid) return querymx(tp<<1|1,x,y);
else return max(querymx(tp<<1,x,mid),querymx(tp<<1|1,mid+1,y));
}
int querysum(int tp,int x,int y){
int l = T[tp].l,r = T[tp].r,mid=(l+r)>>1;
if(l==x && r==y)
return T[tp].sum;
else if(y<=mid) return querysum(tp<<1,x,y);
else if(x>mid) return querysum(tp<<1|1,x,y);
else return querysum(tp<<1,x,mid)+querysum(tp<<1|1,mid+1,y);
}
int solvemx(int x,int y){
int mx=-INT_MAX;
while(bl[x]!=bl[y]){
if(dep[bl[x]]<dep[bl[y]])
swap(x,y);
mx=max(mx,querymx(1,pos[bl[x]],pos[x]));
x=fa[bl[x]];
}
if(pos[x]>pos[y])
swap(x,y);
mx=max(mx,querymx(1,pos[x],pos[y]));
return mx;
}
int solvesum(int x,int y){
int sum=0;
while(bl[x]!=bl[y]){
if(dep[bl[x]]<dep[bl[y]])
swap(x,y);
sum+=querysum(1,pos[bl[x]],pos[x]);
x=fa[bl[x]];
}
if(pos[x]>pos[y])
swap(x,y);
sum+=querysum(1,pos[x],pos[y]);
return sum;
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
insert(x,y);insert(y,x);
}
for(int i=1;i<=n;i++)
scanf("%d",&v[i]);
dfs1(1);
dfs2(1,1);
build(1,1,n);
for(int i=1;i<=n;i++)
change(1,pos[i],v[i]);
scanf("%d",&q);
char cho[20];
int x,y;
for(int i=1;i<=q;i++){
scanf("%s%d%d",cho,&x,&y);
if(cho[0]=='C'){
v[x]=y;
change(1,pos[x],y);
} else {
if(cho[1]=='M'){
printf("%d\n",solvemx(x,y));
} else {
printf("%d\n",solvesum(x,y));
}
}
}
return 0;
}

浙公网安备 33010602011771号