[BZOJ2243][SDOI2011]染色(树链剖分)

[传送门]

树链剖分就行了,注意线段树上颜色的合并

Code

 

#include <cstdio>
#include <algorithm>
#define N 100010
#define MID int mid=(l+r)>>1,ls=id<<1,rs=id<<1|1
#define len (r-l+1)
using namespace std;

struct tree{
	int lc,rc,sum,tag;
	tree(){lc=rc=tag=-1;sum=0;}
	friend tree operator +(tree a,tree b){
		if(a.lc==-1) return b;
		if(b.lc==-1) return a;
		tree c;
		c.lc=a.lc,c.rc=b.rc;
		c.sum=a.sum+b.sum-(a.rc==b.lc?1:0);
		return c;
	}
}T[N*4];
struct info{int to,nex;}e[N*2];
int n,m,tot,head[N],cnt,A[N];
int tid[N],dep[N],son[N],fa[N],sz[N],tp[N],tw[N];

inline int read(){
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

inline void Link(int u,int v){
	e[++tot].nex=head[u];head[u]=tot;e[tot].to=v;
}

void dfs(int u,int pre){
	sz[u]=1;
	for(int i=head[u],mx=0;i;i=e[i].nex){
		int v=e[i].to;
		if(v==pre) continue;
		fa[v]=u;
		dep[v]=dep[u]+1;
		dfs(v,u);
		sz[u]+=sz[v];
		if(sz[v]>mx){son[u]=v;mx=sz[v];}
	}
}

void dddfs(int u,int top){
	tp[u]=top;
	tid[u]=++cnt;
	tw[cnt]=A[u];
	if(!son[u]) return;
	
	dddfs(son[u],top);
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].to;
		if(v==fa[u]||v==son[u]) continue;
		dddfs(v,v);
	}
}

void build(int l,int r,int id){
	if(l==r){T[id].sum=1;T[id].lc=T[id].rc=tw[l];return;}
	MID;
	build(l,mid,ls);
	build(mid+1,r,rs);
	T[id]=T[ls]+T[rs];
}

void Init(){
	n=read(),m=read();
	for(int i=1;i<=n;A[i++]=read());
	for(int i=1;i<n;++i){
		int u=read(),v=read();
		Link(u,v),Link(v,u);
	}
	dfs(1,0);
	dddfs(1,1);
	build(1,n,1);
}

inline void pushdown(int l,int r,int id){
	int &tmp=T[id].tag;
	if(tmp==-1) return;
	MID;
	T[ls].lc=T[ls].rc=T[rs].lc=T[rs].rc=tmp;
	T[ls].sum=T[rs].sum=1;
	T[ls].tag=T[rs].tag=tmp;
	tmp=-1;
}

int query(int l,int r,int id,int L,int R){
	if(L<=l&&r<=R) return T[id].sum;
	pushdown(l,r,id);
	MID;
	int res=0;
	if(R<=mid) res+=query(l,mid,ls,L,R);
	else if(L>mid) res+=query(mid+1,r,rs,L,R);
	else res+=query(l,mid,ls,L,R),res+=query(mid+1,r,rs,L,R),res-=(T[ls].rc==T[rs].lc)?1:0;
	return res;
}

int qDot(int l,int r,int id,int x){
	if(l==r&&l==x) return T[id].lc;
	pushdown(l,r,id);
	MID;
	if(x<=mid) return qDot(l,mid,ls,x);
	else return qDot(mid+1,r,rs,x);
}

inline int qRange(int u,int v){
	int res=0;
	while(tp[u]!=tp[v]){
		if(dep[tp[u]]<dep[tp[v]]) swap(u,v);
		res+=query(1,n,1,tid[tp[u]],tid[u]);
		int x=qDot(1,n,1,tid[tp[u]]),y=qDot(1,n,1,tid[fa[tp[u]]]);
		if(x==y) --res;
		u=fa[tp[u]];
	}
	if(dep[u]>dep[v]) swap(u,v);
	res+=query(1,n,1,tid[u],tid[v]);
	return res;
}

void update(int l,int r,int id,int L,int R,int x){
	if(L<=l&&r<=R){
		T[id].sum=1;
		T[id].lc=T[id].rc=T[id].tag=x;
		return;
	}
	pushdown(l,r,id);
	MID;
	if(L<=mid) update(l,mid,ls,L,R,x);
	if(R>mid) update(mid+1,r,rs,L,R,x);
	T[id]=T[ls]+T[rs];
}

void updRange(int u,int v,int x){
	while(tp[u]!=tp[v]){
		if(dep[tp[u]]<dep[tp[v]]) swap(u,v);
		update(1,n,1,tid[tp[u]],tid[u],x);
		u=fa[tp[u]];
	}
	if(dep[u]>dep[v]) swap(u,v);
	update(1,n,1,tid[u],tid[v],x);
}

inline void solve(){
	char ch;
	while(m--){
		for(ch=getchar();ch!='C'&&ch!='Q';ch=getchar());
		if(ch=='Q'){
			int u=read(),v=read();
			printf("%d\n",qRange(u,v));
		}else{
			int u=read(),v=read(),x=read();
			updRange(u,v,x);
		}
	}
}

int main(){Init();solve();}

 

posted @ 2018-05-09 19:49  void_f  阅读(174)  评论(0编辑  收藏  举报