20260411模拟赛

20260411模拟赛

相等树链

题面:

给你两棵树,问多少点集在两棵树上均为链。\(1\leq n\leq 2\times 10^5\)

题解:

\(p_t(x,y)\) 表示第 \(t\) 棵树上路径 \((x,y)\) 的点集。

对一个树点分治,对于当前分治重心 \(u\),记 \(s_t(x)\) 表示 \(p_t(x,u)/\{u\}\)。考虑所有经过 \(u\) 的路径 \((x,y)\),其在另一棵树上的路径是 \((z,w)\),满足 \(s_1(x)\oplus s_1(y)=s_2(z)\oplus s_2(w)\),这可以异或哈希。尝试哈希表计数,是否能将等式移项成分别只和 \(x,y\) 有关的形式。

考虑对于每个 \(x\) 求出 \(s_1(x)\)\((z,w)\) 中两个方向上离 \(u\) 最远的点分别是哪个,具体的对于每个 \(x\) 求出一个集合 \(T(x)\) 表示 \(s_1(x)\) 的点在第二棵树中离 \(u\) 最远的互相之间没有祖先后代关系的那些点。

如果这样的点大于两个,则肯定不能成链。否则考虑 \(z,w\) 这两个点在 \(s_1(x)\) 中还是 \(s_1(y)\) 中。

  • 都在 \(s_1(x)\) 中,那么 \(T(x)=\{z,w\}\)\(s_1(x)\oplus s_2(z)\oplus s_2(w)=s_1(y)\),左右分别只跟 \(x,y\) 有关。
  • 都在 \(s_1(y)\) 中,同理 \(T(y)=\{z,w\}\)\(s_1(x)=s_2(z)\oplus s_2(w)\oplus s_1(y)\)
  • 一边一个,不妨设 \(z\in T(x),w\in T(y)\)\(s_1(x)\oplus s_2(z)=s_2(w)\oplus s_1(y)\)

发现会算错一种 \(z,w\) 在第二棵树属于 \(u\) 的同一子树的情况,所以可以对 \(u\) 的不同子树染不同颜色进行哈希表查询即可。

代码
#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
#define ll long long
#define fir first
#define sec second
#define ump gp_hash_table<ll,int>
using namespace std;
using namespace __gnu_pbds;

inline int read(){
	int s=0,k=1;
	char c=getchar();
	while(c>'9'||c<'0'){
		if(c=='-') k=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9'){
		s=(s<<3)+(s<<1)+(c^48);
		c=getchar();
	}
	return s*k;
}

mt19937_64 rnd(time(0));
const int N=2e5+5;
int n,dep[N],col[N],C[N];
ll w[N],val[N],ans;
bool nok[N],exi[N];

namespace B{
	int head[N],cnt;
	struct edge{
		int v,nxt;
	}e[N<<1];
	
	void add(int u,int v){
		e[++cnt].v=v;
		e[cnt].nxt=head[u];
		head[u]=cnt;
	}
	
	void dfs(int x,int fa,int c){
		val[x]=val[fa]^w[x];
		dep[x]=dep[fa]+1;
		col[x]=c; exi[x]=1;
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(v==fa||!nok[v]) continue;
			dfs(v,x,c);
		}
	}
	
	void del(int x,int fa){
		exi[x]=0;
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(v==fa||!nok[v]) continue;
			del(v,x);
		}
	}
	
	void clear(int x){
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(nok[v]) del(v,x);
		}
	}
	
	void sol(int x){
		dep[x]=1; val[x]=0;
		int tot=0;
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(nok[v]) dfs(v,x,++tot);
		}
	}
}

namespace A{
	int head[N],cnt,tot,rt,mx,siz[N];
	pair<int,int>a[N];
	ump X,Y,Z[N];
	ll dis[N];
	bool vis[N],rev[N];
	struct edge{
		int v,nxt;
	}e[N<<1];
	
	void add(int u,int v){
		e[++cnt].v=v;
		e[cnt].nxt=head[u];
		head[u]=cnt;
	}
	
	void dfz(int x,int fa){
		siz[x]=1;nok[x]=1;
		int num=0;
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(v==fa||vis[v]) continue;
			dfz(v,x);
			siz[x]+=siz[v];
			num=max(num,siz[v]);
		}
		num=max(num,tot-siz[x]);
		if(num<mx){
			mx=num;
			rt=x;
		}
	}
	
	int cmax(int x,int y){
		if(dep[x]>dep[y]) return x;
		else return y;
	}
	
	ll Xor(int x){
		return dis[x]^val[a[x].fir]^val[a[x].sec];
	}
	
	void dfs(int x,int fa,int A,int B){
		if(!exi[x]) return ;
		if(A&&B){
			if(col[x]!=col[A]&&col[x]!=col[B]) return ;
			if(col[x]==col[A]) A=cmax(A,x);
			if(col[x]==col[B]) B=cmax(B,x);
		}
		else if(A){
			if(col[x]==col[A]) A=cmax(A,x);
			else B=x;
		}
		else A=x;
		rev[x]=1;
		a[x]={A,B};
		dis[x]=dis[fa]^w[x];
		ans+=X[Xor(x)];
		ans+=Y[dis[x]];
		if(A){
			ans+=Z[0][dis[x]^val[A]];
			ans-=Z[col[A]][dis[x]^val[A]];	
		}
		if(B){
			ans+=Z[0][dis[x]^val[B]];
			ans-=Z[col[B]][dis[x]^val[B]];	
		}
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(v==fa||vis[v]) continue;
			dfs(v,x,A,B);
		}
	}
	
	void calc(int x,int fa){
		if(!rev[x]) return ;
		X[dis[x]]++;
		Y[Xor(x)]++;
		int A,B;
		tie(A,B)=a[x];
		if(A){
			Z[0][dis[x]^val[A]]++;
			Z[col[A]][dis[x]^val[A]]++;	
		}
		if(B){
			Z[0][dis[x]^val[B]]++;
			Z[col[B]][dis[x]^val[B]]++;
		}
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(v==fa||vis[v]) continue;
			calc(v,x);
		}
	}
	
	void del(int x,int fa){
		if(!rev[x]) return ;
		rev[x]=0;
		int A,B;
		tie(A,B)=a[x];
		if(A) Z[col[A]].clear();
		if(B) Z[col[B]].clear();
		a[x]={0,0};
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(v==fa||vis[v]) continue;
			del(v,x);
		}
	}
	
	void getsz(int x,int fa){
		siz[x]=1;nok[x]=0;
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(v==fa||vis[v]) continue;
			getsz(v,x);
			siz[x]+=siz[v];
		}
	}
	
	void sol(int x){
		vis[x]=1;
		B::sol(x);
		dis[x]=0; X[0]++;
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(vis[v]) continue;
			dfs(v,x,0,0);
			calc(v,x);
		}
		B::clear(x);
		X.clear();Y.clear();Z[0].clear();
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(vis[v]) continue;
			del(v,x);
		}
		getsz(x,0);
		for(int i=head[x],v;i;i=e[i].nxt){
			v=e[i].v;
			if(vis[v]) continue;
			tot=siz[v];mx=n+1,rt=0;
			dfz(v,x); sol(rt);
		}
	}
	
	void solve(){
		rt=0;mx=n+1,tot=n;
		dfz(1,0); sol(rt);
		printf("%lld\n",ans+n);
	}
}

int main(){
	// freopen("c.in","r",stdin);
	// freopen("c.out","w",stdout);
	n=read();
	for(int i=2;i<=n;i++){
		int x=read();
		A::add(x,i);A::add(i,x);
	}
	for(int i=2;i<=n;i++){
		int x=read();
		B::add(x,i);B::add(i,x);
	}
	for(int i=1;i<=n;i++) w[i]=rnd();
	A::solve();
	return 0;
}

posted @ 2026-04-13 15:43  programmingysx  阅读(10)  评论(0)    收藏  举报
Title