【BZOJ】2243 [SDOI2011]染色

【算法】树链剖分+线段树

【题解】

树链剖分算法:http://www.cnblogs.com/onioncyc/p/6207462.html

定义线段树结构体有l,r,lc,rc,sum,data。

lc表示左端颜色,rc表示右端颜色,sum表示颜色种类,data表示区间置为同一个数的标记。

修改的时候要上推和下传,查询的时候要下传。

我的写法是打lazy标记的时候顺便把子树的其它参数都修改完毕,方便直接调用。

访问到有lazy标记的子树时把标记下传给左右子树并修改左右子树的其他参数。

左右端颜色相同的处理方法见:http://blog.csdn.net/u011645923/article/details/43087133

还是注意树链剖分后操作要使用新编号pos[i]。

#include<cstdio>
#include<cctype>
#include<algorithm>
using namespace std;
int read()
{
    char c;int s=0,t=1;
    while(!isdigit(c=getchar()))if(c=='-')t=-1;
    do{s=s*10+c-'0';}while(isdigit(c=getchar()));
    return s*t;
}
const int maxn=300010;
int first[maxn],size[maxn],deep[maxn],f[maxn],top[maxn],pos[maxn],LC,RC,n,m,tot,dfsnum,a[maxn];
struct edge{int u,v,from;}e[maxn*3];
struct tree{int lc,rc,sum,l,r,data;}t[maxn*3];
void insert(int u,int v)
{tot++;e[tot].u=u;e[tot].v=v;e[tot].from=first[u];first[u]=tot;}
void dfs1(int x,int fa)
{
    size[x]=1;
    for(int i=first[x];i;i=e[i].from)
     if(e[i].v!=fa)
      {
        int y=e[i].v;
        f[y]=x;
        deep[y]=deep[x]+1;
        dfs1(y,x);
        size[x]+=size[y];
      }
}
void dfs2(int x,int tp,int fa)
{
    int k=0;
    pos[x]=++dfsnum;
    top[x]=tp;
    for(int i=first[x];i;i=e[i].from)
     if(e[i].v!=fa&&size[e[i].v]>size[k])k=e[i].v;
    if(k==0)return;
    dfs2(k,tp,x);
    for(int i=first[x];i;i=e[i].from)
     if(e[i].v!=fa&&e[i].v!=k)dfs2(e[i].v,e[i].v,x);
}
void build(int k,int l,int r)
{
    t[k].l=l;t[k].r=r;
    if(l==r){t[k].lc=0;t[k].rc=0;t[k].sum=0;t[k].data=0;return;}
     else
      {
          int mid=(l+r)>>1;
          build(k<<1,l,mid);
          build(k<<1|1,mid+1,r);
      }
}
void pushdown(int k)
{
    if(t[k].data)
     {
         t[k<<1].sum=t[k<<1|1].sum=1;
         t[k<<1].data=t[k<<1|1].data=t[k].data;
         t[k<<1].lc=t[k<<1].rc=t[k].data;
         t[k<<1|1].lc=t[k<<1|1].rc=t[k].data;
     }
    t[k].data=0;
}
void pushup(int k)
{
    t[k].lc=t[k<<1].lc;
    t[k].rc=t[k<<1|1].rc;
    t[k].sum=t[k<<1].sum+t[k<<1|1].sum;
    if(t[k<<1].rc==t[k<<1|1].lc)t[k].sum--;
}
void change(int k,int l,int r,int x)//区间修改需要上推&&下传 
{
    pushdown(k);
    int left=t[k].l,right=t[k].r;
    if(l<=left&&right<=r){t[k].data=x;t[k].sum=1;t[k].lc=t[k].rc=x;}
     else
      {
          int mid=(left+right)>>1;
          if(l<=mid)change(k<<1,l,r,x);
          if(r>mid)change(k<<1|1,l,r,x);
          pushup(k);
      }
}
int ask(int k,int l,int r)//区间查询只需要下传 
{
    pushdown(k);
    int left=t[k].l,right=t[k].r;
    if(l==left)LC=t[k].lc;
    if(r==right)RC=t[k].rc;
    if(l<=left&&right<=r){return t[k].sum;}
     else
      {
          int mid=(left+right)>>1,sums=0,ok=0;
          if(l<=mid)sums=ask(k<<1,l,r),ok++;
          if(r>mid)sums+=ask(k<<1|1,l,r),ok++;
          if(ok==2&&t[k<<1].rc==t[k<<1|1].lc)sums--;//只取一边的话就不需要判断了 
          return sums;
      }
}
void update(int x,int y,int z)
{
    while(top[x]!=top[y])
     {
         if(deep[top[x]]<deep[top[y]])swap(x,y);
         change(1,pos[top[x]],pos[x],z);//!!!
         x=f[top[x]];
     }
    if(pos[x]>pos[y])swap(x,y);
    change(1,pos[x],pos[y],z);
}
int solve(int x,int y)
{
    int sums=0,ansx=0,ansy=0;//分别表示x和y的左端点颜色 
    while(top[x]!=top[y])
     {
         if(deep[top[x]]<deep[top[y]])swap(x,y),swap(ansx,ansy);
         sums+=ask(1,pos[top[x]],pos[x]);
         if(RC==ansx)sums--;
         ansx=LC; 
         x=f[top[x]];
     }
    if(pos[x]>pos[y])swap(x,y),swap(ansx,ansy);
    sums+=ask(1,pos[x],pos[y]);
    if(ansx==LC)sums--;
    if(ansy==RC)sums--; 
    return sums;
}
int main()
{
    n=read();m=read();
    for(int i=1;i<=n;i++)a[i]=read();
    for(int i=1;i<n;i++)
     {
        int u=read(),v=read();
        insert(u,v);
        insert(v,u);
     }
    dfs1(1,-1);dfs2(1,1,-1);
    build(1,1,n);
    for(int i=1;i<=n;i++)update(i,i,a[i]+1);//颜色+1避免0的问题 
    for(int i=1;i<=m;i++)
     {
         char c=getchar();
         while(!(c=='C'||c=='Q'))c=getchar();
         if(c=='C')
          {
              int x=read(),y=read(),z=read();
            update(x,y,z+1);
          }
         else
          {
              int x=read(),y=read();
              printf("%d\n",solve(x,y));
          }
     }
    return 0;
}
View Code
posted @ 2016-12-22 19:11  ONION_CYC  阅读(215)  评论(0编辑  收藏  举报