20260411模拟赛
20260411模拟赛
相等树链
题面:
给你两棵树,问多少点集在两棵树上均为链。\(1\leq n\leq 2\times 10^5\)
题解:
记 \(p_t(x,y)\) 表示第 \(t\) 棵树上路径 \((x,y)\) 的点集。
对一个树点分治,对于当前分治重心 \(u\),记 \(s_t(x)\) 表示 \(p_t(x,u)/\{u\}\)。考虑所有经过 \(u\) 的路径 \((x,y)\),其在另一棵树上的路径是 \((z,w)\),满足 \(s_1(x)\oplus s_1(y)=s_2(z)\oplus s_2(w)\),这可以异或哈希。尝试哈希表计数,是否能将等式移项成分别只和 \(x,y\) 有关的形式。
考虑对于每个 \(x\) 求出 \(s_1(x)\) 在 \((z,w)\) 中两个方向上离 \(u\) 最远的点分别是哪个,具体的对于每个 \(x\) 求出一个集合 \(T(x)\) 表示 \(s_1(x)\) 的点在第二棵树中离 \(u\) 最远的互相之间没有祖先后代关系的那些点。
如果这样的点大于两个,则肯定不能成链。否则考虑 \(z,w\) 这两个点在 \(s_1(x)\) 中还是 \(s_1(y)\) 中。
- 都在 \(s_1(x)\) 中,那么 \(T(x)=\{z,w\}\),\(s_1(x)\oplus s_2(z)\oplus s_2(w)=s_1(y)\),左右分别只跟 \(x,y\) 有关。
- 都在 \(s_1(y)\) 中,同理 \(T(y)=\{z,w\}\),\(s_1(x)=s_2(z)\oplus s_2(w)\oplus s_1(y)\)。
- 一边一个,不妨设 \(z\in T(x),w\in T(y)\),\(s_1(x)\oplus s_2(z)=s_2(w)\oplus s_1(y)\)。
发现会算错一种 \(z,w\) 在第二棵树属于 \(u\) 的同一子树的情况,所以可以对 \(u\) 的不同子树染不同颜色进行哈希表查询即可。
代码
#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
#define ll long long
#define fir first
#define sec second
#define ump gp_hash_table<ll,int>
using namespace std;
using namespace __gnu_pbds;
inline int read(){
int s=0,k=1;
char c=getchar();
while(c>'9'||c<'0'){
if(c=='-') k=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
s=(s<<3)+(s<<1)+(c^48);
c=getchar();
}
return s*k;
}
mt19937_64 rnd(time(0));
const int N=2e5+5;
int n,dep[N],col[N],C[N];
ll w[N],val[N],ans;
bool nok[N],exi[N];
namespace B{
int head[N],cnt;
struct edge{
int v,nxt;
}e[N<<1];
void add(int u,int v){
e[++cnt].v=v;
e[cnt].nxt=head[u];
head[u]=cnt;
}
void dfs(int x,int fa,int c){
val[x]=val[fa]^w[x];
dep[x]=dep[fa]+1;
col[x]=c; exi[x]=1;
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(v==fa||!nok[v]) continue;
dfs(v,x,c);
}
}
void del(int x,int fa){
exi[x]=0;
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(v==fa||!nok[v]) continue;
del(v,x);
}
}
void clear(int x){
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(nok[v]) del(v,x);
}
}
void sol(int x){
dep[x]=1; val[x]=0;
int tot=0;
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(nok[v]) dfs(v,x,++tot);
}
}
}
namespace A{
int head[N],cnt,tot,rt,mx,siz[N];
pair<int,int>a[N];
ump X,Y,Z[N];
ll dis[N];
bool vis[N],rev[N];
struct edge{
int v,nxt;
}e[N<<1];
void add(int u,int v){
e[++cnt].v=v;
e[cnt].nxt=head[u];
head[u]=cnt;
}
void dfz(int x,int fa){
siz[x]=1;nok[x]=1;
int num=0;
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(v==fa||vis[v]) continue;
dfz(v,x);
siz[x]+=siz[v];
num=max(num,siz[v]);
}
num=max(num,tot-siz[x]);
if(num<mx){
mx=num;
rt=x;
}
}
int cmax(int x,int y){
if(dep[x]>dep[y]) return x;
else return y;
}
ll Xor(int x){
return dis[x]^val[a[x].fir]^val[a[x].sec];
}
void dfs(int x,int fa,int A,int B){
if(!exi[x]) return ;
if(A&&B){
if(col[x]!=col[A]&&col[x]!=col[B]) return ;
if(col[x]==col[A]) A=cmax(A,x);
if(col[x]==col[B]) B=cmax(B,x);
}
else if(A){
if(col[x]==col[A]) A=cmax(A,x);
else B=x;
}
else A=x;
rev[x]=1;
a[x]={A,B};
dis[x]=dis[fa]^w[x];
ans+=X[Xor(x)];
ans+=Y[dis[x]];
if(A){
ans+=Z[0][dis[x]^val[A]];
ans-=Z[col[A]][dis[x]^val[A]];
}
if(B){
ans+=Z[0][dis[x]^val[B]];
ans-=Z[col[B]][dis[x]^val[B]];
}
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(v==fa||vis[v]) continue;
dfs(v,x,A,B);
}
}
void calc(int x,int fa){
if(!rev[x]) return ;
X[dis[x]]++;
Y[Xor(x)]++;
int A,B;
tie(A,B)=a[x];
if(A){
Z[0][dis[x]^val[A]]++;
Z[col[A]][dis[x]^val[A]]++;
}
if(B){
Z[0][dis[x]^val[B]]++;
Z[col[B]][dis[x]^val[B]]++;
}
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(v==fa||vis[v]) continue;
calc(v,x);
}
}
void del(int x,int fa){
if(!rev[x]) return ;
rev[x]=0;
int A,B;
tie(A,B)=a[x];
if(A) Z[col[A]].clear();
if(B) Z[col[B]].clear();
a[x]={0,0};
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(v==fa||vis[v]) continue;
del(v,x);
}
}
void getsz(int x,int fa){
siz[x]=1;nok[x]=0;
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(v==fa||vis[v]) continue;
getsz(v,x);
siz[x]+=siz[v];
}
}
void sol(int x){
vis[x]=1;
B::sol(x);
dis[x]=0; X[0]++;
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(vis[v]) continue;
dfs(v,x,0,0);
calc(v,x);
}
B::clear(x);
X.clear();Y.clear();Z[0].clear();
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(vis[v]) continue;
del(v,x);
}
getsz(x,0);
for(int i=head[x],v;i;i=e[i].nxt){
v=e[i].v;
if(vis[v]) continue;
tot=siz[v];mx=n+1,rt=0;
dfz(v,x); sol(rt);
}
}
void solve(){
rt=0;mx=n+1,tot=n;
dfz(1,0); sol(rt);
printf("%lld\n",ans+n);
}
}
int main(){
// freopen("c.in","r",stdin);
// freopen("c.out","w",stdout);
n=read();
for(int i=2;i<=n;i++){
int x=read();
A::add(x,i);A::add(i,x);
}
for(int i=2;i<=n;i++){
int x=read();
B::add(x,i);B::add(i,x);
}
for(int i=1;i<=n;i++) w[i]=rnd();
A::solve();
return 0;
}

浙公网安备 33010602011771号