bzoj2243[SDOI2011]染色 树链剖分+线段树

2243: [SDOI2011]染色

Time Limit: 20 Sec  Memory Limit: 512 MB
Submit: 9012  Solved: 3375
[Submit][Status][Discuss]

Description

给定一棵有n个节点的无根树和m个操作,操作有2类:

1、将节点a到节点b路径上所有点都染成颜色c;

2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),

如“112221”由3段组成:“11”、“222”和“1”。

请你写一个程序依次完成这m个操作。

Input

第一行包含2个整数n和m,分别表示节点数和操作数;

第二行包含n个正整数表示n个节点的初始颜色

下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。

下面 行每行描述一个操作:

“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;

“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。

Output

对于每个询问操作,输出一行答案。

Sample Input

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

Sample Output

3
1
2

HINT

数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。

Source

第一轮day1

 

树链剖分+线段树
操作还是常规的区间修改 但是要注意区间合并的时候,颜色段数量的合并
在树的路径上合并区间时,有点不好处理区间左右端点,于
是我们把u->v拆成u->lca lca->v
这样区间更新端点就是同向的,即每次查询可以把一端统一地作为更新端点
由于u->lca lca->v左右端点必定重合 所以 最后答案减一


sb的我update忘了pushup调了1h 

 

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#define ls u<<1
#define rs ls|1
#define ll long long
#define N 100050
using namespace std;
int n,m,tot,cnt,a[N],fa[N][20],tp[N],siz[N],hd[N];
int dep[N],tid[N],son[N],v[N];
struct edge{int v,next;}e[N<<1];
struct node{int l,r,sum,lz;}t[N<<2];
void adde(int u,int v){
    e[++tot].v=v;
    e[tot].next=hd[u];
    hd[u]=tot;
}
void dfs1(int u,int pre){
    dep[u]=dep[pre]+1;fa[u][0]=pre;siz[u]=1;
    for(int j=1;(1<<j)<dep[u];j++)
    fa[u][j]=fa[fa[u][j-1]][j-1];
    for(int i=hd[u];i;i=e[i].next){
        int v=e[i].v;
        if(v==pre)continue;
        dfs1(v,u);siz[u]+=siz[v];
        if(siz[v]>siz[son[u]])son[u]=v;
    }
}
void dfs2(int u,int anc){
    if(!u)return;
    tid[u]=++cnt;v[cnt]=a[u];tp[u]=anc;
    dfs2(son[u],anc);
    for(int i=hd[u];i;i=e[i].next){
        int v=e[i].v;
        if(v==fa[u][0]||v==son[u])continue;
        dfs2(v,v);
    }
}
int lca(int x,int y){
    if(x==y)return x;
    if(dep[x]<dep[y])swap(x,y);
    for(int i=18;~i;i--)
    if(dep[fa[x][i]]>=dep[y])x=fa[x][i];
    if(x==y)return x;
    for(int i=18;~i;i--){
        if(fa[x][i]==fa[y][i])continue;
        x=fa[x][i];y=fa[y][i];
    }
    return fa[x][0];
}

void pushup(int u){
    t[u].sum=t[ls].sum+t[rs].sum;
    t[u].l=t[ls].l;t[u].r=t[rs].r;
    if(t[ls].r==t[rs].l)t[u].sum--;
}
void pushdown(int u){
    if(t[u].lz==-1)return;
    t[ls].l=t[ls].r=t[ls].lz=t[u].lz;
    t[rs].l=t[rs].r=t[rs].lz=t[u].lz;
    t[ls].sum=t[rs].sum=1;
    t[u].lz=-1;
}
void build(int u,int l,int r){
    t[u].lz=-1;
    if(l==r){
        t[u].l=t[u].r=v[l];
        t[u].sum=1;return;
    }
    int mid=l+r>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    pushup(u);
}
node query(int u,int L,int R,int l,int r){
    if(l<=L&&R<=r)return t[u];
    pushdown(u);node ret;
    int mid=L+R>>1,fg=0;
    if(l<=mid)ret=query(ls,L,mid,l,r),fg=1;
    if(r>mid){
        node tmp=query(rs,mid+1,R,l,r);
        if(!fg)ret=tmp;
        else{
            ret.sum+=tmp.sum;
            if(ret.r==tmp.l)ret.sum--;
            ret.r=tmp.r;
        }
    }
    return ret;
}
void update(int u,int L,int R,int l,int r,int c){
    if(l<=L&&R<=r){
        t[u].lz=t[u].l=t[u].r=c;
        t[u].sum=1;return;
    }
    pushdown(u);
    int mid=L+R>>1;
    if(l<=mid)update(ls,L,mid,l,r,c);
    if(r>mid)update(rs,mid+1,R,l,r,c);
    pushup(u);
}
node jump(int x,int y,int val,int op){
    int fx=tp[x],fy=tp[y];node ret;
    ret.l=ret.r=ret.sum=0;
    while(fx!=fy){
        if(op)update(1,1,cnt,tid[fx],tid[x],val);
        else{
            node tmp=query(1,1,cnt,tid[fx],tid[x]);
            ret.sum+=tmp.sum;
            if(tmp.r==ret.l)ret.sum--;
            ret.l=tmp.l;
        }
        x=fa[fx][0];fx=tp[x];
    }
    if(dep[x]>dep[y])swap(x,y);
    if(op)update(1,1,cnt,tid[x],tid[y],val);
    else{
        node tmp=query(1,1,cnt,tid[x],tid[y]);
        ret.sum+=tmp.sum;
        if(tmp.r==ret.l)ret.sum--;
        ret.l=tmp.l;
    }
    return ret;
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    scanf("%d",&a[i]);
    for(int i=1;i<n;i++){
        int a,b;
        scanf("%d%d",&a,&b);
        adde(a,b);adde(b,a);
    }
    dfs1(1,0);dfs2(1,1);
    build(1,1,cnt);
    int x,y,c;char s[2];
    while(m--){
        scanf("%s%d%d",s,&x,&y);
        int anc=lca(x,y);
        if(s[0]=='C'){
            scanf("%d",&c);
            jump(x,anc,c,1);
            jump(y,anc,c,1);
        }
        else{
            node t1,t2;
            t1=jump(x,anc,0,0);
            t2=jump(y,anc,0,0);
            printf("%d\n",t1.sum+t2.sum-1);
        }
    }
    return 0;
}
posted @ 2017-12-18 10:30  _wsy  阅读(170)  评论(0编辑  收藏  举报