BZOJ 2243 [SDOI2011]染色 树链剖分+线段树

题意

给定一棵有n个节点的无根树和m个操作,操作有2类:

1、将节点a到节点b路径上所有点都染成颜色c;

2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),

如“112221”由3段组成:“11”、“222”和“1”。

请你写一个程序依次完成这m个操作。

分析

用线段树维护下区间左右端点的颜色,区间不同颜色段的个数,区间合并时接口的颜色相同要减一,

树剖一下,跑一遍就行了

Code

#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define lson l,mid,p<<1
#define rson mid+1,r,p<<1|1
using namespace std;
typedef long long ll;
const int inf=1e9;
const int maxn=3e5+10;
int n,q;
int a[maxn];
vector<int>g[maxn];
int sz[maxn],f[maxn],d[maxn],top[maxn],son[maxn],id[maxn],p[maxn],tot;
int val[maxn<<2],c1[maxn<<2],c2[maxn<<2],tag[maxn<<2];
struct ppo{
    int val,c1,c2;
};
void pp(int p){
    val[p]=val[p<<1]+val[p<<1|1]-(c2[p<<1]==c1[p<<1|1]);
    c1[p]=c1[p<<1];c2[p]=c2[p<<1|1];
}
void bd(int l,int r,int p){
    tag[p]=-1;
    if(l==r){
        val[p]=1;c1[p]=c2[p]=a[id[l]];return;
    }int mid=l+r>>1;
    bd(lson);bd(rson);pp(p);
}
void pd(int l,int r,int p,int k){val[p]=1;c1[p]=c2[p]=k;tag[p]=k;};
void up(int dl,int dr,int l,int r,int p,int k){
    if(l>=dl&&r<=dr){
        val[p]=1;c1[p]=c2[p]=k;
        tag[p]=k;return;
    }int mid=l+r>>1;
    if(~tag[p]){pd(lson,tag[p]);pd(rson,tag[p]);tag[p]=-1;}
    if(dl<=mid) up(dl,dr,lson,k);
    if(dr>mid) up(dl,dr,rson,k);
    pp(p);
}
ppo mer(ppo a,ppo b){
    a.val=a.val+b.val-(a.c2==b.c1);
    a.c2=b.c2;
    return a;
}
ppo qy(int dl,int dr,int l,int r,int p){
    if(l>=dl&&r<=dr){
        return ppo{val[p],c1[p],c2[p]};
    }int mid=l+r>>1;
    if(~tag[p]){pd(lson,tag[p]);pd(rson,tag[p]);tag[p]=-1;}
    if(dr<=mid){
        return qy(dl,dr,lson);
    }else if(dl>mid){
        return qy(dl,dr,rson);
    }else{
        return mer(qy(dl,dr,lson),qy(dl,dr,rson));
    }
}
void dfs1(int u){
    sz[u]=1;d[u]=d[f[u]]+1;
    for(int i=0;i<g[u].size();i++){
        int x=g[u][i];
        if(x==f[u]) continue;
        f[x]=u;dfs1(x);
        sz[u]+=sz[x];
        if(sz[x]>sz[son[u]]) son[u]=x;
    }
}
void dfs2(int u,int t){
    top[u]=t;p[u]=++tot;id[tot]=u;
    if(son[u]) dfs2(son[u],t);
    for(int i=0;i<g[u].size();i++){
        int x=g[u][i];
        if(x==f[u]||x==son[u]) continue;
        dfs2(x,x);
    }
}
void col(int x,int y,int k){
    while(top[x]!=top[y]){
        if(d[top[x]]<d[top[y]]) swap(x,y);
        up(p[top[x]],p[x],1,n,1,k);
        x=f[top[x]];
    }
    if(d[x]<d[y]) swap(x,y);
    up(p[y],p[x],1,n,1,k);
}
int cal(int x,int y){
    ppo a=ppo{0,-1,-1},b=ppo{0,-1,-1};
    while(top[x]!=top[y]){
        if(d[top[x]]<d[top[y]]){
             ppo res=qy(p[top[y]],p[y],1,n,1);
             if(b.c1==-1) b=res;
             else b=mer(res,b);
             y=f[top[y]];
        }else{
             ppo res=qy(p[top[x]],p[x],1,n,1);
             if(a.c1==-1) a=res;
             else a=mer(res,a);
             x=f[top[x]];
        }
    }
    if(d[x]<d[y]){
        ppo res=qy(p[x],p[y],1,n,1);
        b=mer(res,b);
    }else{
        ppo res=qy(p[y],p[x],1,n,1);
        a=mer(res,a);
    }
    int ret=a.val+b.val-(a.c1==b.c1);
    return ret;
}
int main(){
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
    }
    for(int i=1,a,b;i<n;i++){
        scanf("%d%d",&a,&b);
        g[a].pb(b);g[b].pb(a);
    }
    dfs1(1);dfs2(1,1);bd(1,n,1);
    while(q--){
        int a,b,c;
        char s[3];
        scanf("%s",s);
        if(s[0]=='C'){
            scanf("%d%d%d",&a,&b,&c);
            col(a,b,c);
        }else{
            scanf("%d%d",&a,&b);
            printf("%d\n",cal(a,b));
        }
    }
    return 0;
}
posted @ 2019-05-23 16:40  xyq0220  阅读(124)  评论(0编辑  收藏  举报