3.22 树链剖分

3.22 树链剖分好题分享

P4374 [USACO18OPEN] Disruption P

这道题我做的时候还是紫的,现在蓝了。

题意

给一棵树,再给若干条新边,每条新边有一个边权。对于每条树上的边,求它被砍掉之后使得树重新联通的最小代价。

题解

枚举每条树边直接求显然是很困难的,所以我们考虑计算每条非树边对树边的贡献。

不难发现,每条非树边会跟树上的若干条边形成唯一的一个环,那么这条边就可能在环上任意一条树边被砍掉之后成为替代它的边,也就是会对这些边产生贡献。

于是做法便明晰了:对树进行树链剖分,建线段树维护最小值,对于每条非树边,更新其两个端点之间的树链的答案。

#include <cstdio>
#define N 100005
#define M 800005
int min(int x,int y) {return x<y?x:y;}
int max(int x,int y) {return x>y?x:y;}
int n,m;
int hed[N],tal[N],nxt[N],cnte;
void adde(int u,int v)
{
	tal[++cnte]=v;
	nxt[cnte]=hed[u];
	hed[u]=cnte;
}
struct sgt
{
	int d[M],tg[M],ls[M],rs[M],idx;
	#define mid (lb+rb>>1)
	#define pushup(x) d[x]=max(d[ls[x]],d[rs[x]])
	int upd(int x,int y)
	{
		if(x==-1) return y;
		if(y==-1) return x;
		return min(x,y);
	}
	int newnode()
	{
		int nx=++idx;
		d[nx]=tg[nx]=-1;
		return nx;
	}
	void pushdown(int x)
	{
		if(tg[x]==-1) return;
		if(!ls[x]) ls[x]=newnode();
		if(!rs[x]) rs[x]=newnode();
		d[ls[x]]=upd(d[ls[x]],tg[x]);
		d[rs[x]]=upd(d[rs[x]],tg[x]);
		tg[ls[x]]=upd(tg[ls[x]],tg[x]);
		tg[rs[x]]=upd(tg[rs[x]],tg[x]);
		tg[x]=-1;
	}
	void modify(int &x,int k,int l,int r,int lb,int rb)
	{
		if(!x) x=newnode();
		if(l<=lb&&rb<=r)
		{
			d[x]=upd(d[x],k);
			tg[x]=upd(tg[x],k);
			return;
		}
		pushdown(x);
		if(l<=mid) modify(ls[x],k,l,r,lb,mid);
		if(r>mid) modify(rs[x],k,l,r,mid+1,rb);
		pushup(x);
	}
	int query(int x,int t,int lb,int rb)
	{
		if(!x) return -1;
		if(lb==rb) return d[x];
		pushdown(x);
		if(t<=mid) return query(ls[x],t,lb,mid);
		else return query(rs[x],t,mid+1,rb);
	}
	#undef mid
	#undef pushup
} tr;
int rt,dfn[N],dep[N];
struct HLD
{
	int fa[N],son[N],siz[N],top[N],idx;
	void dfs1(int x)
	{
		siz[x]=1;
		for(int i=hed[x];i;i=nxt[i])
			if(!siz[tal[i]])
			{
				fa[tal[i]]=x,dep[tal[i]]=dep[x]+1;
				dfs1(tal[i]);
				siz[x]+=siz[tal[i]];
				if(siz[tal[i]]>siz[son[x]]) son[x]=tal[i];
			}
	}
	void dfs2(int x,int tp)
	{
		if(!x) return;
		dfn[x]=++idx;
		top[x]=tp;
		dfs2(son[x],tp);
		for(int i=hed[x];i;i=nxt[i])
			if(!top[tal[i]]) dfs2(tal[i],tal[i]);
	}
	void init()
	{
		dfs1(1);
		dfs2(1,1);
	}
	void modify(int u,int v,int w)
	{
		while(top[u]!=top[v])
		{
			if(dep[top[u]]<dep[top[v]]) {int tmp=u;u=v,v=tmp;}
			tr.modify(rt,w,dfn[top[u]],dfn[u],1,n);
			u=fa[top[u]];
		}
		if(dep[u]>dep[v]) {int tmp=u;u=v,v=tmp;};
		tr.modify(rt,w,dfn[u]+1,dfn[v],1,n);
	}
} hld;
int t1[N],t2[N];
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<n;i++)
	{
		scanf("%d%d",&t1[i],&t2[i]);
		adde(t1[i],t2[i]);
		adde(t2[i],t1[i]);
	}
	hld.init();
	for(int i=1;i<=m;i++)
	{
		int u,v,w;
		scanf("%d%d%d",&u,&v,&w);
		hld.modify(u,v,w);
	}
	for(int i=1;i<n;i++)
	{
		if(dep[t2[i]]>dep[t1[i]])
		{
			int tmp=t2[i];
			t2[i]=t1[i],t1[i]=tmp;
		}
		printf("%d\n",tr.query(rt,dfn[t1[i]],1,n));
	}
}

P3401 洛谷树

题意

树,单点修改,查询树链上每个子区间的异或和之和。

题解

首先考虑序列上的这个问题该怎么解决。

注意到题目最后一行写了:

对于 \(100\%\) 的数据,所有边权小于等于 \(1023\)

这使我们联想到按位维护。

对于第 \(k\) 个二进制位,每个子区间的异或和的这一位,要么是 \(1\),要么是 \(0\)。所以要求每个子区间的异或和之和,只需要求有多少个子区间的异或和是 \(1\),然后把这个数量乘以 \(2^k\) 加入答案。

于是问题得到了简化:给一 \(01\) 序列,维护单点修改,求区间有多少个子区间异或和为 \(1\)

这令我们联想到小白逛公园这道题。于是可以采用与这题类似的做法,用线段树维护每个节点对应区间的前后缀信息。

具体地,记录它本身的答案、异或和,以及它有多少个异或和为 \(1/0\) 的前后缀,于是 pushup 就能写出来了。

回到本题,由于是树上问题,所以使用树链剖分。需要注意边权转点权的问题。

注意树剖 query 的时候需要先求两边的链,再 reverse 一下,最后合并答案。

#include <cstdio>
#define N 30005
#define int long long
int n,q,cnte,hed[N],tal[N<<1],wt[N<<1],a[N],nxt[N<<1],rt[15];
void adde(int u,int v,int w) {tal[++cnte]=v,wt[cnte]=w,nxt[cnte]=hed[u],hed[u]=cnte;}
struct seq
{
	int sum,d0,d1,l0,l1,r0,r1;
	seq(int x=0):sum(x),d0(!x),d1(x),l0(!x),l1(x),r0(!x),r1(x) {}
	seq rev()
	{
		seq ret;
		ret.sum=sum,ret.d0=d0,ret.d1=d1,ret.l0=r0,ret.l1=r1,ret.r0=l0,ret.r1=l1;
		return ret;
	}
} ans[2][15];
seq merge(seq x,seq y)
{
	if(x.sum==-1) return y;
	if(y.sum==-1) return x;
	seq ret;
	ret.sum=x.sum^y.sum;
	ret.d0=x.d0+y.d0+x.r0*y.l0+x.r1*y.l1;
	ret.d1=x.d1+y.d1+x.r0*y.l1+x.r1*y.l0;
	ret.l0=x.l0,ret.l1=x.l1;
	if(x.sum) ret.l0+=y.l1,ret.l1+=y.l0;
	else ret.l0+=y.l0,ret.l1+=y.l1;
	ret.r0=y.r0,ret.r1=y.r1;
	if(y.sum) ret.r0+=x.r1,ret.r1+=x.r0;
	else ret.r0+=x.r0,ret.r1+=x.r1;
	return ret;
}
int dfn[N],li[N],dep[N],fa[N],son[N],siz[N],top[N],idx;
struct sgt
{
	seq d[N<<5];
	int ls[N<<5],rs[N<<5],id;
	#define mid (lb+rb>>1)
	int build(int t,int lb,int rb)
	{
		int x=++id;
		if(lb==rb) {d[x]=a[li[lb]]>>t&1;return x;}
		ls[x]=build(t,lb,mid),rs[x]=build(t,mid+1,rb);
		d[x]=merge(d[ls[x]],d[rs[x]]);
		return x;
	}
	void modify(int x,int t,int lb,int rb)
	{
		if(lb==rb)
		{
			d[x]=d[x].sum^1;
			return;
		}
		(t<=mid)?modify(ls[x],t,lb,mid):modify(rs[x],t,mid+1,rb);
		d[x]=merge(d[ls[x]],d[rs[x]]);
	}
	seq query(int x,int l,int r,int lb,int rb)
	{
		if(l>r) return -1;
		if(l<=lb&&rb<=r) return d[x];
		if(r<=mid) return query(ls[x],l,r,lb,mid);
		if(l>mid) return query(rs[x],l,r,mid+1,rb);
		return merge(query(ls[x],l,r,lb,mid),query(rs[x],l,r,mid+1,rb));
	}
	#undef mid
} tr;
void dfs1(int x)
{
	siz[x]=1;
	for(int i=hed[x];i;i=nxt[i]) if(!siz[tal[i]])
	{
		fa[tal[i]]=x,dep[tal[i]]=dep[x]+1,a[tal[i]]=wt[i];
		dfs1(tal[i]);
		siz[x]+=siz[tal[i]];
		if(siz[tal[i]]>siz[son[x]]) son[x]=tal[i];
	}
}
void dfs2(int x,int tp)
{
	if(!x) return;
	li[dfn[x]=++idx]=x;
	dfs2(son[x],top[x]=tp);
	for(int i=hed[x];i;i=nxt[i]) if(!top[tal[i]]) dfs2(tal[i],tal[i]);
}
int query(int x,int y)
{
	for(int i=0;i<15;i++) ans[0][i]=ans[1][i]=-1;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]>dep[top[y]])
		{
			for(int i=0;i<15;i++) 
				ans[0][i]=merge(tr.query(rt[i],dfn[top[x]],dfn[x],1,n),ans[0][i]);
			x=fa[top[x]];
		}
		else
		{
			for(int i=0;i<15;i++) 
				ans[1][i]=merge(tr.query(rt[i],dfn[top[y]],dfn[y],1,n),ans[1][i]);
			y=fa[top[y]];
		}
	}
	if(dep[x]>dep[y]) for(int i=0;i<15;i++)
		ans[0][i]=merge(tr.query(rt[i],dfn[y]+1,dfn[x],1,n),ans[0][i]);
	else for(int i=0;i<15;i++)
		ans[1][i]=merge(tr.query(rt[i],dfn[x]+1,dfn[y],1,n),ans[1][i]);
	for(int i=0;i<15;i++)
		ans[0][i]=merge(ans[0][i].rev(),ans[1][i]);
	int ret=0;
	for(int i=0;i<15;i++) if(ans[0][i].sum!=-1) ret+=ans[0][i].d1<<i;
	return ret;
}
main()
{
	scanf("%lld%lld",&n,&q);
	for(int i=1,u,v,w;i<n;i++) scanf("%lld%lld%lld",&u,&v,&w),adde(u,v,w),adde(v,u,w);
	dfs1(1),dfs2(1,1);
	for(int i=0;i<15;i++) rt[i]=tr.build(i,1,n);
	while(q--)
	{
		int op,u,v,w;
		scanf("%lld%lld%lld",&op,&u,&v);
		if(op==1) printf("%lld\n",query(u,v));
		if(op==2)
		{
			scanf("%lld",&w);
			if(fa[u]==v) {int tmp=u;u=v,v=tmp;}
			int y=a[v]^w;
			for(int i=0;i<15;i++) if(y>>i&1) tr.modify(rt[i],dfn[v],1,n);
			a[v]=w;
		}
	}
}
posted @ 2025-04-08 15:09  整齐的艾萨克  阅读(12)  评论(0)    收藏  举报