BZOJ 2243 [SDOI2011] 染色 (树剖+线段树)

题意

一颗1e5的数,1e5的询问,两种操作:
1.给x到y的所有点染成一种颜色
2.问x到y的路径上有多少段相同的颜色

思路

先考虑序列上,12操作均可以使用线段树来处理
需要注意的是,询问时只有在线段树上往两边同时下传的时候,才需要判断是否x需要答案-1
所以写的时候分三段写比较方便

if(mid>=y)return ask(x,y,lson);
else if(mid<x)return ask(x,y,rson);
else{
    int tmp = 0;
    if(R[lc]==L[rc])tmp=1;
    return ask(x,y,lson)+ask(x,y,rson)-tmp;
}

而修改时要统一pushup,不要忘记lazy标记下传(因为这个改了好久)
同时,题目中颜色可以为0,所以初试lazy标记要设置成-1

而在树上的操作,需要再加个lca和树剖
求x到y的答案,即链[lca(x,y),x]的答案+链[lca(x,y),y]的答案-1
因为改成了树剖,所以需要再写一个单点查询颜色的函数,将几条链的答案拼起来
修改同理

代码

int n;
int b[maxn];
int a[maxn<<2],lz[maxn<<2];
int L[maxn<<2],R[maxn<<2];

int dep[maxn],fa[maxn][18];
int lg[maxn];

int sz[maxn],son[maxn];
int rk[maxn],top[maxn];
int id[maxn];
int cnt;
struct LCA{
    int n;
    vector<int>g[maxn];
    
    void dfs(int x, int lst){
        dep[x] = dep[lst] + 1;
        fa[x][0] = lst;
        sz[x]=1;
        for(int i = 1; (1<<i) <= dep[x]; i++){
            fa[x][i] = fa[fa[x][i-1]][i-1];
        }
        for(int i = 0; i < (int)g[x].size(); i++){
            int y = g[x][i];
            if(y==lst) continue;
            dfs(y, x);
            sz[x]+=sz[y];
            if(sz[y]>sz[son[x]])son[x]=y;
        }
        return;
    }
    void dfs1(int x, int t){
        top[x]=t;id[x]=++cnt;rk[cnt]=x;
        if(!son[x])return;
        dfs1(son[x],t);
        for(int i = 0; i < (int)g[x].size(); i++){
            int y = g[x][i];
            if(y!=fa[x][0]&&y!=son[x])dfs1(y,y);
        }
    }
    void init(int root, int N){
        n=N;cnt=0;
        for(int i = 1; i <= n; i++){
            lg[i] = lg[i-1]+(1<<lg[i-1]==i);
        }
        dfs(root,0);
        dfs1(root,root);
    }
    void add(int x, int y){
        g[x].pb(y);g[y].pb(x);
    }

    int lca(int x, int y){
        if(dep[x] > dep[y])swap(x,y);
        while(dep[x] != dep[y]){
            if(lg[dep[y]-dep[x]]-1>=0)y = fa[y][lg[dep[y]-dep[x]]-1];
            else y = fa[y][0];
        }
        if(x==y) return x;
        for(int i = lg[dep[y]]; i >= 0; i--){
            if(fa[x][i] != fa[y][i]){
                x = fa[x][i];
                y = fa[y][i];
            }
        }
        return fa[x][0];
    }
}lca;
void pushup(int root){
    a[root]=a[lc]+a[rc];
    if(R[lc]==L[rc])a[root]--;
    L[root]=L[lc];R[root]=R[rc];
}
void build(int l, int r, int root){
    int mid = l+r>>1;
    lz[root]=-1;
    if(l==r){
        a[root]=1;
        L[root]=R[root]=b[rk[l]];
        return;
    }
    build(lson);build(rson);
    pushup(root);
}
void pushdown(int l, int r, int root){
    int mid = l+r>>1;
    if(lz[root]!=-1){
        a[lc]=1;lz[lc]=L[lc]=R[lc]=lz[root];
        a[rc]=1;lz[rc]=L[rc]=R[rc]=lz[root];
        lz[root]=-1;
    }
}
void update(int x, int y, int c, int l, int r, int root){
    int mid = l+r>>1;
    if(x<=l&&r<=y){
        a[root]=1;L[root]=R[root]=c;lz[root]=c;
        return;
    }
    pushdown(l,r,root);
    if(x<=mid)update(x,y,c,lson);
    if(mid<y)update(x,y,c,rson);
    pushup(root);
    return;
}
//int ans;
PI tmp;
PI TT;
int ask(int x, int y, int l, int r, int root){
    
    int mid = l+r>>1;
    int ans = 0;
    if(x<=l&&r<=y){return a[root];}
    pushdown(l,r,root);
    if(mid>=y)return ask(x,y,lson);
    else if(mid<x)return ask(x,y,rson);
    else{
        int tmp = 0;
        if(R[lc]==L[rc])tmp=1;
        return ask(x,y,lson)+ask(x,y,rson)-tmp;
    }
    pushup(root);
}
int color(int x, int l, int r, int root){
    int mid = l+r>>1;
    if(l==r)return L[root];
    pushdown(l,r,root);
    if(x<=mid)return color(x,lson);
    else return color(x,rson);
    pushup(root);
}

int q;
///======
int sum(int x, int y){
    int sum = 0;
    while(top[x]!=top[y]){
        sum += ask(id[top[x]],id[x],1,n,1);
        if(color(id[top[x]],1,n,1)==color(id[fa[top[x]][0]],1,n,1))sum--;
        x=fa[top[x]][0];
    }
    sum+=ask(id[y],id[x],1,n,1);
    return sum;
}
void change(int x, int y, int w){
    while(top[x]!=top[y]){
        update(id[top[x]],id[x],w,1,n,1);
        x=fa[top[x]][0];
    }
    update(id[y],id[x],w,1,n,1);
}
///======
int main() {
    mem(lz,-1);
    scanf("%d %d", &n,&q);
    for(int i = 1; i <= n; i++){
        scanf("%d", &b[i]);
    }
    for(int i = 1; i < n; i++){
        int x,y;
        scanf("%d %d", &x, &y);
        lca.add(x,y);
    }
    lca.init(1,n);
    build(1,n,1);
    while(q--){
        char op[5];
        int x,y,w;
        scanf("%s",op+1);
        if(op[1]=='Q'){
            scanf("%d %d", &x, &y);
            int L = lca.lca(x,y);
            printf("%d\n",sum(x,L)+sum(y,L)-1);
        }
        else{
            scanf("%d %d %d", &x, &y, &w);
            int L = lca.lca(x,y);
            change(x,L,w);change(y,L,w);
        }
    }
    return 0;
}
posted @ 2020-05-14 12:34  wrjlinkkkkkk  阅读(181)  评论(0编辑  收藏  举报