BZOJ2243 染色

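树链剖分（点权）+ 线段树区间赋值：操作 C a b c 将路径 a–b 上所有结点染成颜色 c；操作 Q a b 查询路径 a–b 上有多少个颜色段（连续的同色结点算作一段）。线段树每个结点维护区间内的颜色段数以及左、右端点的颜色；合并两个子区间时，若左区间右端颜色与右区间左端颜色相同，则段数减一。沿重链向上跳时同理：若链顶与其父结点颜色相同，答案减一。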

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <algorithm>
  4 using namespace std;
  5 #define lson l,m,rt<<1
  6 #define rson m+1,r,rt<<1|1
  7 const int maxn = 100005;
  8 int siz[maxn],son[maxn],top[maxn],tid[maxn],fa[maxn],dep[maxn];
  9 int sum[maxn<<2],cl[maxn<<2],cr[maxn<<2],lz[maxn<<2];
 10 struct node
 11 {
 12     int v,next;
 13 }e[maxn*2];
 14 int head[maxn],val[maxn],cnt,lable;
 15 int n,m;
 16 void init()
 17 {
 18     memset(head,-1,sizeof(head));
 19     cnt = lable = 0;
 20 }
 21 void add(int u,int v)
 22 {
 23     e[cnt].v = v;
 24     e[cnt].next = head[u];
 25     head[u] = cnt++;
 26 }
 27 void find_heavy(int rt,int father,int depth)
 28 {
 29     int maxsize = 0;
 30     fa[rt] = father;
 31     son[rt] = 0;
 32     siz[rt] = 1;
 33     dep[rt] = depth;
 34     for(int i = head[rt];i!=-1;i = e[i].next)if(e[i].v!=father){
 35         find_heavy(e[i].v,rt,depth+1);
 36         siz[rt]+=siz[e[i].v];
 37         if(siz[e[i].v]>maxsize)
 38             maxsize = siz[e[i].v],son[rt] = e[i].v;
 39     }
 40 }
 41 void connect(int rt,int anc)
 42 {
 43     tid[rt] = ++lable;
 44     top[rt] = anc;
 45     if(son[rt])connect(son[rt],anc);
 46     for(int i = head[rt];i!=-1;i =e[i].next)
 47         if(e[i].v!=fa[rt]&&e[i].v!=son[rt])
 48             connect(e[i].v,e[i].v);
 49 }
 50 void pushup(int rt)
 51 {
 52     sum[rt] = sum[rt<<1]+sum[rt<<1|1];
 53     cl[rt] = cl[rt<<1];
 54     cr[rt] = cr[rt<<1|1];
 55     if(cr[rt<<1]==cl[rt<<1|1])sum[rt]--;
 56 }
 57 void pushdown(int rt)
 58 {
 59     if(lz[rt]!=-1){
 60         sum[rt<<1] = sum[rt<<1|1] = 1;
 61         cl[rt<<1] = cr[rt<<1] = lz[rt];
 62         cl[rt<<1|1] = cr[rt<<1|1] = lz[rt];
 63         lz[rt<<1] = lz[rt<<1|1] = lz[rt];
 64         lz[rt] = -1;
 65     }
 66 }
 67 void update(int L,int R,int c,int l,int r,int rt)
 68 {
 69     if(L<=l&&r<=R){
 70         cl[rt] = c;
 71         cr[rt] = c;
 72         sum[rt] = 1;
 73         lz[rt] = c;
 74         return;
 75     }
 76     pushdown(rt);
 77     int m = (l+r)>>1;
 78     if(L<=m)update(L,R,c,lson);
 79     if(m<R)update(L,R,c,rson);
 80     pushup(rt);
 81 }
 82 int query(int L,int R,int l,int r,int rt)
 83 {
 84     if(L<=l&&r<=R)return sum[rt];
 85     pushdown(rt);
 86     int m = (l+r)>>1,ret = 0,ok = -1;
 87     if(L<=m)ret+=query(L,R,lson),ok++;
 88     if(m<R)ret+=query(L,R,rson),ok++;
 89     if(ok==1){
 90         if(cr[rt<<1]==cl[rt<<1|1])ret--;
 91     }
 92     return ret;
 93 }
 94 
 95 int get(int pos,int l,int r,int rt)
 96 {
 97     if(l==r)return cl[rt];
 98     pushdown(rt);
 99     int m = (l+r)>>1;
100     if(pos<=m)return get(pos,lson);
101     else return get(pos,rson);
102 }
103 int getsum(int x,int y)
104 {
105     int ans = 0;
106     while(top[x]!=top[y])
107     {
108         if(dep[top[x]]<dep[top[y]])swap(x,y);
109         ans+=query(tid[top[x]],tid[x],1,n,1);
110         if(get(tid[top[x]],1,n,1)==get(tid[fa[top[x]]],1,n,1))ans--;
111         x = fa[top[x]];
112     }
113     if(dep[x]>dep[y])swap(x,y);
114     ans+=query(tid[x],tid[y],1,n,1);
115     return ans;
116 }
117 void change(int x,int y,int c)
118 {
119     while(top[x]!=top[y])
120     {
121         if(dep[top[x]]<dep[top[y]])swap(x,y);
122         update(tid[top[x]],tid[x],c,1,n,1);
123         x = fa[top[x]];
124     }
125     if(dep[x]>dep[y])swap(x,y);
126     update(tid[x],tid[y],c,1,n,1);
127 }
128 int main()
129 {
130    // freopen("in.txt","r",stdin);
131     while(~scanf("%d%d",&n,&m))
132     {
133         init();
134         for(int i = 1;i<=n;++i)scanf("%d",val+i);
135         for(int i = 1;i<n;++i){
136             int u,v;scanf("%d%d",&u,&v);
137             add(u,v);add(v,u);
138         }
139         find_heavy(1,1,1);
140         connect(1,1);
141         memset(lz,-1,sizeof(lz));
142         memset(sum,0,sizeof(sum));
143         for(int i = 1;i<=n;++i)update(tid[i],tid[i],val[i],1,n,1);
144         while(m--)
145         {
146             char s[2];int a,b,c;
147             scanf("%s",s);
148             if(s[0]=='C'){
149                 scanf("%d%d%d",&a,&b,&c);
150                 change(a,b,c);
151             }
152             else {
153                 scanf("%d%d",&a,&b);
154                 printf("%d\n",getsum(a,b));
155             }
156         }
157         puts("");
158     }
159     return 0;
160 }

 

posted on 2015-07-28 16:29  round_0  阅读(129)  评论(0)  编辑  收藏  举报

导航