bzoj 1036: [ZJOI2008]树的统计Count——树链剖分

Description

  一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

Input

  输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。 
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Output

  对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

Sample Input

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

Sample Output

4
1
2
2
10
6
5
6
5
16

————————————————————————————

这题就是典型的树上路径取max 路径求和 单点修改了

算法没什么好说的 写了三种写法

1——树链剖分(线段树版)

#include<cstdio>
#include<cstring>
#include<algorithm>
using std::swap;
const int M=50007;
int read(){
    int ans=0,f=1,c=getchar();
    while(c<'0'||c>'9'){if(c=='-') f=-1; c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+(c-'0'); c=getchar();}
    return ans*f;
}
int max(int x,int y){return x>y?x:y;}
char ch[5];
int n,m;
int first[M],cnt=1;
struct node{int to,next;}e[2*M];
void ins(int a,int b){e[++cnt]=(node){b,first[a]}; first[a]=cnt;}
void insert(int a,int b){ins(a,b); ins(b,a);}
int fa[M],sz[M],top[M],son[M],id[M],idp=1,dep[M];
void f1(int x){
    sz[x]=1;
    for(int i=first[x];i;i=e[i].next){
        int now=e[i].to;
        if(now==fa[x]) continue;
        fa[now]=x;  dep[now]=dep[x]+1;
        f1(now);  sz[x]+=sz[now];
        if(sz[now]>sz[son[x]]) son[x]=now;
    }
}
void f2(int x,int tp){
    top[x]=tp; id[x]=idp++;
    if(son[x]) f2(son[x],tp);
    for(int i=first[x];i;i=e[i].next){
        int now=e[i].to;
        if(now!=fa[x]&&now!=son[x]) f2(now,now);
    }
}
struct pos{int l,r,sum,mx;}tr[4*M];
void build(int x,int l,int r){
    tr[x].l=l; tr[x].r=r;
    if(l==r) return ;
    int mid=(l+r)>>1;
    build(x<<1,l,mid); build(x<<1^1,mid+1,r);
}
void up(int x){
    tr[x].sum=tr[x<<1].sum+tr[x<<1^1].sum;
    tr[x].mx=max(tr[x<<1].mx,tr[x<<1^1].mx);
}
void modify(int x,int s,int y){
    if(tr[x].l==tr[x].r) return void(tr[x].sum=tr[x].mx=y);
    int mid=(tr[x].l+tr[x].r)>>1;
    if(mid>=s) modify(x<<1,s,y);
    else modify(x<<1^1,s,y);
    up(x);
}
int push_max(int x,int L,int R){
    if(L<=tr[x].l&&tr[x].r<=R) return tr[x].mx;
    int mid=(tr[x].l+tr[x].r)>>1,ans=-1e8;
    if(L<=mid) ans=max(ans,push_max(x<<1,L,R));
    if(R>mid) ans=max(ans,push_max(x<<1^1,L,R));
    return ans;
}
int Qmax(int x,int y){
    int ans=-1e8;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans=max(ans,push_max(1,id[top[x]],id[x]));
        x=fa[top[x]];
        
    }
    if(id[x]>id[y]) swap(x,y);
    ans=max(ans,push_max(1,id[x],id[y]));
    return ans;
}
int push_sum(int x,int L,int R){
    if(L<=tr[x].l&&tr[x].r<=R) return tr[x].sum;
    int mid=(tr[x].l+tr[x].r)>>1,sum=0;
    if(L<=mid) sum+=push_sum(x<<1,L,R);
    if(R>mid)  sum+=push_sum(x<<1^1,L,R);
    return sum;
}
int Qsum(int x,int y){
    int sum=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        sum+=push_sum(1,id[top[x]],id[x]);
        x=fa[top[x]];
    }
    if(id[x]>id[y]) swap(x,y);
    sum+=push_sum(1,id[x],id[y]);
    return sum;
}
int main(){
    int x,y; 
    n=read(); for(int i=1;i<n;i++) x=read(),y=read(),insert(x,y);
    f1(1); f2(1,1); build(1,1,n);
    for(int i=1;i<=n;i++) x=read(),modify(1,id[i],x);
    m=read();
    for(int i=1;i<=m;i++){
        scanf("%s",ch); x=read(); y=read();
        if(ch[0]=='C') modify(1,id[x],y);
        else if(ch[1]=='M') printf("%d\n",Qmax(x,y));
        else printf("%d\n",Qsum(x,y));
    }
    return 0;
}
View Code

2——树链剖分(zkw线段树版)

#include<cstdio>
#include<cstring>
#include<algorithm>
using std::swap;
const int M=50007;
int read(){
    int ans=0,f=1,c=getchar();
    while(c<'0'||c>'9'){if(c=='-') f=-1; c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+(c-'0'); c=getchar();}
    return ans*f;
}
int max(int x,int y){return x>y?x:y;}
char ch[5];
int n,m;
int first[M],cnt=1;
struct node{int to,next;}e[2*M];
void ins(int a,int b){e[++cnt]=(node){b,first[a]}; first[a]=cnt;}
void insert(int a,int b){ins(a,b); ins(b,a);}
int fa[M],sz[M],top[M],son[M],id[M],idp=1,dep[M];
void f1(int x){
    sz[x]=1;
    for(int i=first[x];i;i=e[i].next){
        int now=e[i].to;
        if(now==fa[x]) continue;
        fa[now]=x;  dep[now]=dep[x]+1;
        f1(now);  sz[x]+=sz[now];
        if(sz[now]>sz[son[x]]) son[x]=now;
    }
}
void f2(int x,int tp){
    top[x]=tp; id[x]=idp++;
    if(son[x]) f2(son[x],tp);
    for(int i=first[x];i;i=e[i].next){
        int now=e[i].to;
        if(now!=fa[x]&&now!=son[x]) f2(now,now);
    }
}
int ly,s[3*M],mx[3*M];
void modify(int x,int w){
    s[x+ly]=w; mx[x+ly]=w;
    for(x=(x+ly)>>1;x;x>>=1) s[x]=s[x<<1]+s[x<<1^1],mx[x]=max(mx[x<<1],mx[x<<1^1]);
}
int push_max(int l,int r){
    int ans=-1e8;
    for(l=l+ly-1,r=r+ly+1;r-l!=1;l>>=1,r>>=1){
        if(~l&1) ans=max(ans,mx[l^1]); 
        if(r&1)  ans=max(ans,mx[r^1]);
    }
    return ans;
}
int Qmax(int x,int y){
    int ans=-1e8;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans=max(ans,push_max(id[top[x]],id[x]));
        x=fa[top[x]];
    }
    if(id[x]>id[y]) swap(x,y);
    ans=max(ans,push_max(id[x],id[y]));
    return ans;
}
int push_sum(int l,int r){
    int sum=0;
    for(l=l+ly-1,r=r+ly+1;r-l!=1;l>>=1,r>>=1){
        if(~l&1) sum+=s[l^1];
        if(r&1)  sum+=s[r^1];
    }
    return sum;
}
int Qsum(int x,int y){
    int sum=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        sum+=push_sum(id[top[x]],id[x]);
        x=fa[top[x]];
    }
    if(id[x]>id[y]) swap(x,y);
    sum+=push_sum(id[x],id[y]);
    return sum;
}
int main(){
    int x,y; 
    n=read(); ly=1; while(ly<=n+2) ly<<=1;
    for(int i=1;i<n;i++) x=read(),y=read(),insert(x,y);
    f1(1); f2(1,1);
    for(int i=1;i<=n;i++) x=read(),modify(id[i],x);
    m=read();
    for(int i=1;i<=m;i++){
        scanf("%s",ch); x=read(); y=read();
        if(ch[0]=='C') modify(id[x],y);
        else if(ch[1]=='M') printf("%d\n",Qmax(x,y));
        else printf("%d\n",Qsum(x,y));
    }
    return 0;
}
View Code

3——lct(link-cut-tree)

#include<cstdio>
#include<cstring>
#include<algorithm>
#define LL long long
using namespace std;
const int M=50007;
int read(){
    int ans=0,f=1,c=getchar();
    while(c<'0'||c>'9'){if(c=='-') f=-1; c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+(c-'0'); c=getchar();}
    return ans*f;
}
int n,m,c[M][2],fa[M],a[M],b[M];
LL v[M],sum[M],mx[M];
bool rev[M];
bool isrt(int x){return c[fa[x]][0]!=x&&c[fa[x]][1]!=x;}
void up(int x){
    if(!x) return ;
    mx[x]=v[x]; sum[x]=v[x];
    int l=c[x][0],r=c[x][1];
    if(l) mx[x]=max(mx[x],mx[l]),sum[x]+=sum[l];
    if(r) mx[x]=max(mx[x],mx[r]),sum[x]+=sum[r];
}
void down(int x){
    if(!rev[x]) return ;
    rev[x]=0;
    int l=c[x][0],r=c[x][1];
    if(l) swap(c[l][0],c[l][1]),rev[l]^=1;
    if(r) swap(c[r][0],c[r][1]),rev[r]^=1;
}
void rotate(int x){
    int y=fa[x],z=fa[y],l=0,r=1;
    if(c[y][1]==x) l=1,r=0;
    if(!isrt(y)) c[z][c[z][1]==y]=x;
    fa[y]=x; fa[x]=z; fa[c[x][r]]=y;
    c[y][l]=c[x][r]; c[x][r]=y;
    up(y); up(x);
}
int st[M],top=0,S;
void splay(int x){
    st[++top]=x; for(int i=x;!isrt(i);i=fa[i]) st[++top]=fa[i];
    while(top) down(st[top--]);
    while(!isrt(x)){
        int y=fa[x],z=fa[y];
        if(!isrt(y)){
            if(c[z][0]==y^c[y][0]==x) rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
}
void acs(int x0){
    for(int x=x0,y=0;x;splay(x),c[x][1]=y,up(x),y=x,x=fa[x]);
    splay(x0);
}
void mrt(int x){acs(x); swap(c[x][0],c[x][1]); rev[x]^=1;}
void link(int x,int y){mrt(x); fa[x]=y;}
void spl(int x,int y){mrt(x); acs(y);}
int main()
{
    int x,y;
    n=read();
    for(int i=1;i<n;i++) a[i]=read(),b[i]=read();
    for(int i=1;i<=n;i++) v[i]=read();
    for(int i=1;i<n;i++) link(a[i],b[i]);
    m=read();
    char ch[15];
    for(int i=1;i<=m;i++){
        scanf("%s",ch); x=read(); y=read();
        if(ch[1]=='H') acs(x),v[x]=y;
        if(ch[1]=='S') spl(x,y),printf("%lld\n",sum[y]);
        if(ch[1]=='M') spl(x,y),printf("%lld\n",mx[y]);
    }
    return 0;
}
View Code

 

posted @ 2017-10-03 08:09  友人Aqwq  阅读(164)  评论(0编辑  收藏  举报