[SDOI2011]染色 BZOJ2243 树链剖分+线段树
分析:
区间合并,lcol是左端点的颜色编号,rcol是右端点的颜色编号,那么我们向上合并的时候,如果左儿子的rcol等于右儿子的lcol那么区间的sum--。
另外,如果重链顶的颜色等于重链顶的父节点的颜色,那么ans--;
附上代码:
#include <cstdio> #include <algorithm> #include <cmath> #include <cstring> #include <cstdlib> #include <queue> #include <iostream> using namespace std; #define N 100005 #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 int sum[N<<2],lcol[N<<2],rcol[N<<2],cov[N<<2]; int head[N],cnt,dep[N],anc[N],fa[N],siz[N],son[N]; int idx[N],a[N],p[N],tims,n,Q; struct node { int to,next; }e[N<<1]; void add(int x,int y) { e[cnt].to=y; e[cnt].next=head[x]; head[x]=cnt++; return ; } void dfs1(int x,int from) { fa[x]=from,siz[x]=1,dep[x]=dep[from]+1; for(int i=head[x];i!=-1;i=e[i].next) { int to1=e[i].to; if(to1!=from) { dfs1(to1,x); siz[x]+=siz[to1]; if(siz[son[x]]<siz[to1])son[x]=to1; } } } void dfs2(int x,int top) { idx[x]=++tims; p[tims]=x; anc[x]=top; if(son[x])dfs2(son[x],top); for(int i=head[x];i!=-1;i=e[i].next) { int to1=e[i].to; if(to1!=fa[x]&&to1!=son[x]) { dfs2(to1,to1); } } } void PushUp(int rt) { lcol[rt]=lcol[rt<<1];rcol[rt]=rcol[rt<<1|1]; sum[rt]=sum[rt<<1|1]+sum[rt<<1]; if(lcol[rt<<1|1]==rcol[rt<<1])sum[rt]--; return ; } void PushDown(int rt) { if(cov[rt]) { cov[rt<<1]=cov[rt<<1|1]=lcol[rt<<1]=rcol[rt<<1]=lcol[rt<<1|1]=rcol[rt<<1|1]=cov[rt]; sum[rt<<1]=sum[rt<<1|1]=1; cov[rt]=0; } } void build(int l,int r,int rt) { if(l==r) { sum[rt]=1; lcol[rt]=rcol[rt]=a[p[l]]; return ; } int m=(l+r)>>1; build(lson); build(rson); PushUp(rt); } void Update(int L,int R,int c,int l,int r,int rt) { if(L<=l&&r<=R) { lcol[rt]=rcol[rt]=cov[rt]=c; sum[rt]=1; return ; } PushDown(rt); int m=(l+r)>>1; if(m>=L)Update(L,R,c,lson); if(m<R)Update(L,R,c,rson); PushUp(rt); } int query(int L,int R,int l,int r,int rt) { if(L<=l&&r<=R)return sum[rt]; PushDown(rt); int m=(l+r)>>1; int ret=0,vis=0; if(m>=L) { vis++; ret+=query(L,R,lson); } if(m<R) { vis++; ret+=query(L,R,rson); } if(rcol[rt<<1]==lcol[rt<<1|1]&&vis==2)ret--; return ret; } int query_col(int x,int l,int r,int rt) { if(l==r) { return lcol[rt]; } PushDown(rt); int m=(l+r)>>1; if(m>=x)return query_col(x,lson); else return query_col(x,rson); } int get_lca_query(int x,int y) { int ret=0; while(anc[x]!=anc[y]) { if(dep[anc[x]]<dep[anc[y]])swap(x,y); ret+=query(idx[anc[x]],idx[x],1,n,1); int l=query_col(idx[anc[x]],1,n,1); int r=query_col(idx[fa[anc[x]]],1,n,1); if(l==r)ret--; x=fa[anc[x]]; } if(dep[x]>dep[y])swap(x,y); ret+=query(idx[x],idx[y],1,n,1); return ret; } void get_lca_Update(int x,int y,int c) { while(anc[x]!=anc[y]) { if(dep[anc[x]]<dep[anc[y]])swap(x,y); Update(idx[anc[x]],idx[x],c,1,n,1); x=fa[anc[x]]; } if(dep[x]>dep[y])swap(x,y); Update(idx[x],idx[y],c,1,n,1); return ; } char s[20]; int main() { memset(head,-1,sizeof(head)); scanf("%d%d",&n,&Q); for(int i=1;i<=n;i++) { scanf("%d",&a[i]); } for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs1(1,0); dfs2(1,1); build(1,n,1); while(Q--) { int x,y,z; scanf("%s%d%d",s,&x,&y); if(s[0]=='Q') { printf("%d\n",get_lca_query(x,y)); }else { scanf("%d",&z); get_lca_Update(x,y,z); } } return 0; }