树链剖分
给出一棵 n 个节点的树,初始每个节点有一个点权,要求维护三种操作:
1 u w:将顶点 u 的权值修改为 w。
2 u v:询问从 u 到 v 的路径上所有顶点的权值和。
3 u v:询问从 u 到 v 的路径上最大的权值是多少。
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define ll long long
#define il inline
#define db double
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
using namespace std;
il int gi()
{
int x=0,y=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
y=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=x*10+ch-'0';
ch=getchar();
}
return x*y;
}
il ll gl()
{
ll x=0,y=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
y=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=x*10+ch-'0';
ch=getchar();
}
return x*y;
}
ll point[100045];//权值
int size[100045];//子树大小
int top[100045];//该链顶
int fa[100045];//爸爸
int son[100045];//重儿子
int deep[100045];//节点深度
int num[100045],tot;//编号
int pos[100045];//编号对应的点
int head[200045],cnt;
struct edge
{
int next,to;
}e[200045];
il void add(int from,int to)
{
e[++cnt].next=head[from];
e[cnt].to=to;
head[from]=cnt;
}
//bool vis[100045];
void dfs1(int x)
{
int r=head[x];
size[x]=1;
while(r!=-1)
{
int now=e[r].to;
if(now!=fa[x])
{
deep[now]=deep[x]+1;
fa[now]=x;
dfs1(now);
size[x]+=size[now];
if(son[x]==-1||size[now]>size[son[x]])
son[x]=now;
}
r=e[r].next;
}
}
void dfs2(int x,int anc)
{
top[x]=anc;
num[x]=++tot;
pos[tot]=x;
if(son[x]==-1)
return;
dfs2(son[x],anc);//找重链
int r=head[x];
while(r!=-1)
{
int now=e[r].to;
if(now!=fa[x]&&now!=son[x])
dfs2(now,now);//轻链
r=e[r].next;
}
}
struct node
{
ll sum,maxn;
}c[1000045];
void build(int rt,int l,int r)
{
if(l==r)
{
c[rt].sum=point[pos[l]];
c[rt].maxn=point[pos[l]];
return;
}
if(l>r)
return;
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
c[rt].sum=c[rt<<1].sum+c[rt<<1|1].sum;
c[rt].maxn=max(c[rt<<1].maxn,c[rt<<1|1].maxn);
}
void update(int rt,int l,int r,int pos,ll NUM)
{
if(l==r)
{
c[rt].sum=NUM;
c[rt].maxn=NUM;
return;
}
if(l>r)return;
int mid=(l+r)>>1;
if(pos<=mid)
update(rt<<1,l,mid,pos,NUM);
else
update(rt<<1|1,mid+1,r,pos,NUM);
c[rt].sum=c[rt<<1].sum+c[rt<<1|1].sum;
c[rt].maxn=max(c[rt<<1].maxn,c[rt<<1|1].maxn);
}
ll query(int rt,int l,int r,int L,int R)
{
//if(l>r)
//return 0;
if(L<=l&&R>=r)
return c[rt].sum;
if(L>r||R<l)return 0;
int mid=(l+r)/2;
ll sum=0;
if(L<=mid)
sum+=query(rt<<1,l,mid,L,R);
if(R>mid)
sum+=query(rt<<1|1,mid+1,r,L,R);
return sum;
}
ll queryy(int rt,int l,int r,int L,int R)
{
if(L<=l&&R>=r)
return c[rt].maxn;
int mid=(l+r)>>1;
if(L>r||R<l)return -2e9;
ll r1=-2e9,r2=-2e9;
if(L<=mid)
r1=queryy(rt<<1,l,mid,L,R);
if(R>mid)
r2=queryy(rt<<1|1,mid+1,r,L,R);
return max(r1,r2);
}
int main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
memset(head,-1,sizeof(head));
memset(son,-1,sizeof(son));
int n=gi();
for(int i=1;i<=n;i++)
point[i]=gl();
int x,y;
for(int i=1;i<n;i++)
{
x=gi(),y=gi();
add(x,y);
add(y,x);
}
deep[1]=1;
fa[1]=1;
dfs1(1);
dfs2(1,1);
build(1,1,n);
int m=gi();
int p;
for(int i=1;i<=m;i++)
{
//printf("c[1].sum=%d\n",c[1].sum);
p=gi();
if(p==1)
{
x=gi(),y=gi();
update(1,1,n,num[x],y);//点更新
}
if(p==2)
{
x=gi(),y=gi();
ll sum=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])//需要x链顶更深
swap(x,y);
sum+=query(1,1,n,num[top[x]],num[x]);//加上这一段区间和
x=fa[top[x]];//x跳到链顶的爸爸上
}
if(num[x]<num[y])
swap(x,y);
sum+=query(1,1,n,num[y],num[x]);//在加上最后一条边
printf("%lld\n",sum);
}
if(p==3)
{
x=gi(),y=gi();
ll ans=-2e9;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])
swap(x,y);
ans=max(ans,queryy(1,1,n,num[top[x]],num[x]));
x=fa[top[x]];
}
if(num[x]<num[y])
swap(x,y);
ans=max(ans,queryy(1,1,n,num[y],num[x]));
printf("%lld\n",ans);
}
}
return 0;
}
PEACE

浙公网安备 33010602011771号