[SDOI2011]染色 BZOJ2243 树链剖分+线段树

分析:

区间合并,lcol是左端点的颜色编号,rcol是右端点的颜色编号,那么我们向上合并的时候,如果左儿子的rcol等于右儿子的lcol那么区间的sum--。

另外,如果重链顶的颜色等于重链顶的父节点的颜色,那么ans--;

附上代码:

#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <queue>
#include <iostream>
using namespace std;
#define N 100005
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
int sum[N<<2],lcol[N<<2],rcol[N<<2],cov[N<<2];
int head[N],cnt,dep[N],anc[N],fa[N],siz[N],son[N];
int idx[N],a[N],p[N],tims,n,Q;
struct node
{
	int to,next;
}e[N<<1];
void add(int x,int y)
{
	e[cnt].to=y;
	e[cnt].next=head[x];
	head[x]=cnt++;
	return ;
}
void dfs1(int x,int from)
{
	fa[x]=from,siz[x]=1,dep[x]=dep[from]+1;
	for(int i=head[x];i!=-1;i=e[i].next)
	{
		int to1=e[i].to;
		if(to1!=from)
		{
			dfs1(to1,x);
			siz[x]+=siz[to1];
			if(siz[son[x]]<siz[to1])son[x]=to1;
		}
	}
}
void dfs2(int x,int top)
{
	idx[x]=++tims;
	p[tims]=x;
	anc[x]=top;
	if(son[x])dfs2(son[x],top);
	for(int i=head[x];i!=-1;i=e[i].next)
	{
		int to1=e[i].to;
		if(to1!=fa[x]&&to1!=son[x])
		{
			dfs2(to1,to1);
		}
	}
}
void PushUp(int rt)
{
	lcol[rt]=lcol[rt<<1];rcol[rt]=rcol[rt<<1|1];
	sum[rt]=sum[rt<<1|1]+sum[rt<<1];
	if(lcol[rt<<1|1]==rcol[rt<<1])sum[rt]--;
	return ;
}
void PushDown(int rt)
{
	if(cov[rt])
	{
		cov[rt<<1]=cov[rt<<1|1]=lcol[rt<<1]=rcol[rt<<1]=lcol[rt<<1|1]=rcol[rt<<1|1]=cov[rt];
		sum[rt<<1]=sum[rt<<1|1]=1;
		cov[rt]=0;
	}
}
void build(int l,int r,int rt)
{
	if(l==r)
	{
		sum[rt]=1;
		lcol[rt]=rcol[rt]=a[p[l]];
		return ;
	}
	int m=(l+r)>>1;
	build(lson);
	build(rson);
	PushUp(rt);
}
void Update(int L,int R,int c,int l,int r,int rt)
{
	if(L<=l&&r<=R)
	{
		lcol[rt]=rcol[rt]=cov[rt]=c;
		sum[rt]=1;
		return ;
	}
	PushDown(rt);
	int m=(l+r)>>1;
	if(m>=L)Update(L,R,c,lson);
	if(m<R)Update(L,R,c,rson);
	PushUp(rt);
}
int query(int L,int R,int l,int r,int rt)
{
	if(L<=l&&r<=R)return sum[rt];
	PushDown(rt);
	int m=(l+r)>>1;
	int ret=0,vis=0;
	if(m>=L)
	{
		vis++;
		ret+=query(L,R,lson);
	}
	if(m<R)
	{
		vis++;
		ret+=query(L,R,rson);
	}
	if(rcol[rt<<1]==lcol[rt<<1|1]&&vis==2)ret--;
	return ret;
}
int query_col(int x,int l,int r,int rt)
{
	if(l==r)
	{
		return lcol[rt];
	}
	PushDown(rt);
	int m=(l+r)>>1;
	if(m>=x)return query_col(x,lson);
	else return query_col(x,rson);
}
int get_lca_query(int x,int y)
{
	int ret=0;
	while(anc[x]!=anc[y])
	{
		if(dep[anc[x]]<dep[anc[y]])swap(x,y);
		ret+=query(idx[anc[x]],idx[x],1,n,1);
		int l=query_col(idx[anc[x]],1,n,1);
		int r=query_col(idx[fa[anc[x]]],1,n,1);
		if(l==r)ret--;
		x=fa[anc[x]];
	}
	if(dep[x]>dep[y])swap(x,y);
	ret+=query(idx[x],idx[y],1,n,1);
	return ret;
}
void get_lca_Update(int x,int y,int c)
{
	while(anc[x]!=anc[y])
	{
		if(dep[anc[x]]<dep[anc[y]])swap(x,y);
		Update(idx[anc[x]],idx[x],c,1,n,1);
		x=fa[anc[x]];
	}
	if(dep[x]>dep[y])swap(x,y);
	Update(idx[x],idx[y],c,1,n,1);
	return ;
}
char s[20];
int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&Q);
	for(int i=1;i<=n;i++)
	{
		scanf("%d",&a[i]);
	}
	for(int i=1;i<n;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs1(1,0);
	dfs2(1,1);
	build(1,n,1);
	while(Q--)
	{
		int x,y,z;
		scanf("%s%d%d",s,&x,&y);
		if(s[0]=='Q')
		{
			printf("%d\n",get_lca_query(x,y));
		}else
		{
			scanf("%d",&z);
			get_lca_Update(x,y,z);
		}
	}
	return 0;
}

  

posted @ 2018-05-15 20:14  Winniechen  阅读(182)  评论(0编辑  收藏  举报