【洛谷P7735】轻重边

题目

题目链接:https://www.luogu.com.cn/problem/P7735
小 W 有一棵 \(n\) 个结点的树,树上的每一条边可能是轻边或者重边。接下来你需要对树进行 \(m\) 次操作,在所有操作开始前,树上所有边都是轻边。操作有以下两种:

  1. 给定两个点 \(a\)\(b\),首先对于 \(a\)\(b\) 路径上的所有点 \(x\)(包含 \(a\)\(b\)),你要将与 \(x\) 相连的所有边变为轻边。然后再将 \(a\)\(b\) 路径上包含的所有边变为重边。
  2. 给定两个点 \(a\)\(b\),你需要计算当前 \(a\)\(b\) 的路径上一共包含多少条重边。

\(n,m\leq 10^5\)

思路

修改操作的话可以先把所有与这条链相连的边都覆盖为 \(0\),然后再把这一条链上的点覆盖为 \(1\)
把边的颜色扔到点上,重链剖分,然后考虑一次修改操作中其中一条重链 \((x,y)\) 应该如何处理。
\(z\)\(x\) 的重儿子,那么需要覆盖为 \(0\) 的部分有:

  • \((z,y)\) 这条重链的所有点。
  • \((x,y)\) 这条重链的所有点的轻儿子。
  • \(x\) 的重儿子。

所以考虑把重链和轻儿子的贡献分开算。维护两棵线段树,第一棵就是重链剖分常规的线段树,把每一条重链放在一个区间中,第二棵线段树则是把每一个点的所有轻儿子放到一个区间中,且每一条重链的轻儿子也要在区间中。这样重新编号应该不难实现。
剩余的部分就很裸了。时间复杂度 \(O(n\log^2 n)\)


但是这样写常数很大,oisdoaiu 大爷给我讲了一种常数似乎小很多的做法。

代码

#include <bits/stdc++.h>
using namespace std;

const int N=100010;
int Q,n,m,tot,head[N],id1[N],id2[N],id3[N],siz[N],fa[N],son[N],cnt[N],dep[N],top[N];

int read()
{
	int d=0; char ch=getchar();
	while (!isdigit(ch)) ch=getchar();
	while (isdigit(ch)) d=(d<<3)+(d<<1)+ch-48,ch=getchar();
	return d;
}

struct edge
{
	int next,to;
}e[N*2];

void add(int from,int to)
{
	e[++tot]=(edge){head[from],to};
	head[from]=tot;
}

void dfs1(int x,int fat)
{
	fa[x]=fat; dep[x]=dep[fat]+1; siz[x]=1; son[x]=0;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fat)
		{
			dfs1(v,x); siz[x]+=siz[v];
			if (siz[v]>siz[son[x]]) son[x]=v;
		}
	}
}

void dfs2(int x,int tp)
{
	top[x]=tp; id1[x]=++tot;
	if (son[x]) dfs2(son[x],tp);
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa[x] && v!=son[x]) dfs2(v,v);
	}
}

void dfs3(int x)
{
	id3[x]=tot+1; cnt[x]=0;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=son[x] && v!=fa[x])
			id2[v]=++tot,cnt[x]++;
	}
	if (son[x]) dfs3(son[x]);
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=son[x] && v!=fa[x]) dfs3(v);
	}
}

struct SegTree
{
	int sum[N*4],lazy[N*4];
	
	void pushdown(int x,int l,int r)
	{
		if (lazy[x]!=-1)
		{
			int mid=(l+r)>>1;
			sum[x*2]=lazy[x]*(mid-l+1); lazy[x*2]=lazy[x];
			sum[x*2+1]=lazy[x]*(r-mid); lazy[x*2+1]=lazy[x];
			lazy[x]=-1;
		}
	}
	
	void update(int x,int l,int r,int ql,int qr,int v)
	{
		if (ql>qr) return;
		if (ql<=l && qr>=r)
			return (void)(sum[x]=v*(r-l+1),lazy[x]=v);
		pushdown(x,l,r);
		int mid=(l+r)>>1;
		if (ql<=mid) update(x*2,l,mid,ql,qr,v);
		if (qr>mid) update(x*2+1,mid+1,r,ql,qr,v);
		sum[x]=sum[x*2]+sum[x*2+1];
	}
	
	int query(int x,int l,int r,int ql,int qr)
	{
		if (ql>qr) return 0;
		if (ql<=l && qr>=r) return sum[x];
		pushdown(x,l,r);
		int mid=(l+r)>>1,res=0;
		if (ql<=mid) res+=query(x*2,l,mid,ql,qr);
		if (qr>mid) res+=query(x*2+1,mid+1,r,ql,qr);
		return res;
	}
}seg1,seg2;

void clear(int x,int y)
{
	for (;top[x]!=top[y];x=fa[top[x]])
	{
		if (dep[top[x]]<dep[top[y]]) swap(x,y);
		seg1.update(1,1,n,id1[top[x]]+1,id1[x],0);
		seg2.update(1,1,n,id3[top[x]],id3[x]+cnt[x]-1,0);
		if (son[x]) seg1.update(1,1,n,id1[son[x]],id1[son[x]],0);
	}
	if (dep[x]<dep[y]) swap(x,y);
	seg1.update(1,1,n,id1[y]+1,id1[x],0);
	seg2.update(1,1,n,id3[y],id3[x]+cnt[x]-1,0);
	if (son[x]) seg1.update(1,1,n,id1[son[x]],id1[son[x]],0);
	if (y!=1 && son[fa[y]]==y) seg1.update(1,1,n,id1[y],id1[y],0);
	if (y!=1 && son[fa[y]]!=y) seg2.update(1,1,n,id2[y],id2[y],0);
}

void update(int x,int y)
{
	for (;top[x]!=top[y];x=fa[top[x]])
	{
		if (dep[top[x]]<dep[top[y]]) swap(x,y);
		seg1.update(1,1,n,id1[top[x]]+1,id1[x],1);
		seg2.update(1,1,n,id2[top[x]],id2[top[x]],1);	
	}
	if (dep[x]<dep[y]) swap(x,y);
	seg1.update(1,1,n,id1[y]+1,id1[x],1);
}

void query(int x,int y)
{
	int ans=0;
	for (;top[x]!=top[y];x=fa[top[x]])
	{
		if (dep[top[x]]<dep[top[y]]) swap(x,y);
		ans+=seg1.query(1,1,n,id1[top[x]]+1,id1[x]);
		ans+=seg2.query(1,1,n,id2[top[x]],id2[top[x]]);
	}
	if (dep[x]<dep[y]) swap(x,y);
	ans+=seg1.query(1,1,n,id1[y]+1,id1[x]);
	cout<<ans<<"\n";
}

void prework()
{
	memset(head,-1,sizeof(head));
	memset(seg1.lazy,-1,sizeof(seg1.lazy));
	memset(seg2.lazy,-1,sizeof(seg2.lazy));
	memset(seg1.sum,0,sizeof(seg1.sum));
	memset(seg2.sum,0,sizeof(seg2.sum));
	tot=0;
}

int main()
{
	Q=read();
	while (Q--)
	{
		prework();
		n=read(); m=read();
		for (int i=1,x,y;i<n;i++)
		{
			x=read(); y=read();
			add(x,y); add(y,x);
		}
		tot=0; dfs1(1,0); dfs2(1,1);
		tot=0; dfs3(1);
		while (m--)
		{
			int opt=read(),x=read(),y=read();
			if (opt==1) clear(x,y),update(x,y);
			if (opt==2) query(x,y);
		}
	}
	return 0;
}
posted @ 2021-07-26 22:43  stoorz  阅读(59)  评论(0编辑  收藏  举报