[LOJ2001] 树点染色

问题描述

Bob 有一棵 n 个点的有根树,其中 1 号点是根节点。Bob 在每个节点上涂了颜色,并且每个点上的颜色不同。

定义一条路径的权值是,这条路径上的点(包括起点和终点)共有多少种不同的颜色。

Bob 可能会进行这几种操作:

  • 1 x,把点 x 到根节点的路径上的所有的点染上一种没有用过的新颜色;
  • 2 x y,求 x 到 y 的路径的权值;
  • 3 x,在以 x 为根的子树中选择一个点,使得这个点到根节点的路径权值最大,求最大权值。

Bob 一共会进行 m 次操作。

输入格式

第一行两个数 n、m。
接下来 n−1 行,每行两个数 a、b 表示 a 和 b 之间有一条边。
接下来 m 行,表示操作,格式见题目描述。

输出格式

每当出现 23 操作,输出一行。

如果是 2 操作,输出一个数表示路径的权值。
如果是 3 操作,输出一个数表示权值的最大值。

样例输入

5 6
1 2
2 3
3 4
3 5
2 4 5
3 3
1 4
2 4 5
1 5
2 4 5

样例输出

3
4
2
2

数据范围

对所有数据,\(1\le n,m \le 10^5\)。。

解析

观察题目所给的操作,我们可以发现,树上的每一种颜色都是深度递增的连续的一段。我们接下来的处理方法都是基于这个性质的。

对于操作1,直接将一个点到根节点的路径染色,我们可以直接联想到LCT的access操作。因此我们只需用LCT维护颜色集合即可。

对于操作2,回答一条路径上的颜色种类,我们可以用LCA求两点之间路径长的方法,在LCT中维护每个点到根节点经过的颜色种类数。答案就是两端点的值之和减去两倍的LCA的值再加上1,因为重复减去了LCA上的颜色。

但如何维护呢?观察access的时候,对于每个被改变颜色的点,都会发生实变虚、虚变实的过程。那么,对于这个点发生改变后的实儿子,其对应的子树中的所有点到根节点经过的颜色种类都要减一;而虚儿子对应的要加一。但是,由于LCT改变了原树的结构,我们不能够用LCT直接维护子树信息,而是要单独用一棵线段树来进行修改与查询。每次LCT时,分别找到他的实儿子和虚儿子在原树中对应的子树根节点,然后利用dfs序进行修改即可。

对于操作3,在线段树中区间查询即可。

代码

#include <iostream>
#include <cstdio>
#define N 100002
using namespace std;
struct SegmentTree{
   int dat,add;
}t[N*4];
int head[N],ver[N*2],nxt[N*2],l;
int n,m,i,son[N][2],fa[N],in[N],out[N],dep[N],pos[N],f[N][21],cnt;
int read()
{
   char c=getchar();
   int w=0;
   while(c<'0'||c>'9') c=getchar();
   while(c<='9'&&c>='0'){
   	w=w*10+c-'0';
   	c=getchar();
   }
   return w;
}
void insert(int x,int y)
{
   l++;
   ver[l]=y;
   nxt[l]=head[x];
   head[x]=l;
}
void dfs(int x,int pre)
{
   in[x]=++cnt;pos[cnt]=x;
   f[x][0]=pre;
   dep[x]=dep[pre]+1;
   for(int i=head[x];i;i=nxt[i]){
   	int y=ver[i];
   	if(y!=pre) dfs(y,x);
   }
   out[x]=cnt;
}
void update(int p)
{
   t[p].dat=max(t[p*2].dat,t[p*2+1].dat);
}
void spread(int p)
{
   if(t[p].add){
   	t[p*2].dat+=t[p].add;t[p*2].add+=t[p].add;
   	t[p*2+1].dat+=t[p].add;t[p*2+1].add+=t[p].add;
   	t[p].add=0;
   }
}
void build(int p,int l,int r)
{
   if(l==r){
   	t[p].dat=dep[pos[l]];
   	return;
   }
   int mid=(l+r)/2;
   build(p*2,l,mid);
   build(p*2+1,mid+1,r);
   update(p);
}
void change(int p,int l,int r,int ql,int qr,int x)
{
   if(ql<=l&&r<=qr){
   	t[p].dat+=x;t[p].add+=x;
   	return;
   }
   int mid=(l+r)/2;
   spread(p);
   if(ql<=mid) change(p*2,l,mid,ql,qr,x);
   if(qr>mid) change(p*2+1,mid+1,r,ql,qr,x);
   update(p);
}
int ask1(int p,int l,int r,int x)
{
   if(l==r) return t[p].dat;
   int mid=(l+r)/2;
   spread(p);
   if(x<=mid) return ask1(p*2,l,mid,x);
   else return ask1(p*2+1,mid+1,r,x);
}
int ask2(int p,int l,int r,int ql,int qr)
{
   if(ql<=l&&r<=qr) return t[p].dat;
   int mid=(l+r)/2,ans=0;
   spread(p);
   if(ql<=mid) ans=max(ans,ask2(p*2,l,mid,ql,qr));
   if(qr>mid) ans=max(ans,ask2(p*2+1,mid+1,r,ql,qr));
   return ans;
}
void init()
{
   dfs(1,0);
   for(int i=1;i<=n;i++) fa[i]=f[i][0];
   for(int j=0;j<=19;j++){
   	for(int i=1;i<=n;i++) f[i][j+1]=f[f[i][j]][j];
   }
   build(1,1,n);
}
int LCA(int u,int v)
{
   if(dep[u]>dep[v]) swap(u,v);
   int tmp=dep[v]-dep[u];
   for(int i=0;(1<<i)<=tmp;i++){
   	if((1<<i)&tmp) v=f[v][i];
   }
   if(u==v) return u;
   for(int i=19;i>=0;i--){
   	if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
   }
   return f[u][0];
}
bool unroot(int x)
{
   return son[fa[x]][0]==x||son[fa[x]][1]==x;
}
void rotate(int x)
{
   int y=fa[x],z=fa[y],p=(son[y][1]==x),w=son[x][p^1];
   if(unroot(y)) son[z][son[z][1]==y]=x;
   son[x][p^1]=y;son[y][p]=w;
   if(w) fa[w]=y;
   fa[y]=x;fa[x]=z;
}
void splay(int x)
{
   int y=x,z;
   while(unroot(x)){
   	y=fa[x];z=fa[y];
   	if(unroot(y)) rotate((son[z][1]==y)^(son[y][1]==x)?x:y);
   	rotate(x);
   }
}
int findroot(int x)
{
   while(son[x][0]) x=son[x][0];
   return x;
}
void access(int x)
{
   for(int y=0;x;y=x,x=fa[x]){
   	splay(x);
   	if(son[x][1]){
   		int p=findroot(son[x][1]);
   		change(1,1,n,in[p],out[p],1);
   	}
   	if(y){
   		int p=findroot(y);
   		change(1,1,n,in[p],out[p],-1);
   	}
   	son[x][1]=y;
   }
}
int main()
{
   n=read();m=read();
   for(i=1;i<n;i++){
   	int u=read(),v=read();
   	insert(u,v);
   	insert(v,u);
   }
   init();
   for(i=1;i<=m;i++){
   	int op=read();
   	if(op==1){
   		int x=read();
   		access(x);
   	}
   	else if(op==2){
   		int x=read(),y=read(),lca=LCA(x,y);
   		printf("%d\n",ask1(1,1,n,in[x])+ask1(1,1,n,in[y])-2*ask1(1,1,n,in[lca])+1);
   	}
   	else{
   		int x=read();
   		printf("%d\n",ask2(1,1,n,in[x],out[x]));
   	}
   }
   return 0;
}
posted @ 2020-06-07 13:33  CJlzf  阅读(252)  评论(0)    收藏  举报