[hdu6647]Bracket Sequences on Tree 解题报告

oj:https://gxyzoj.com/d/hzoj/p/3575

因为自己的脑残原因,调了8个小时啊!!!

切入正题

Part 1

假定1为根,可以发现,如果u的两棵子树同构,则他们遍历的顺序不影响答案

所以,就可以将子树按哈希值分类,这道题就变成了一个可重复组合问题,

\(f_i\)表示以1为根时i的方案数,\(a_i\)表示某一种哈希值的个数,\(cnt_u\)表示u的子节点数,式子:

\[f_u=\dfrac{cnt_u!}{\prod_{i=1} a_i} \times \prod_{v\in son(u)} f_v \]

Part 2

若以两个节点i,j为根,两棵树同构,则方案不能重复计算,所以考虑计算以每个节点为根的哈希值

因为n很大,所以不能暴力,考虑树形dp

分为子树外的贡献和子树内的贡献,子树外的显然是父亲为根的贡献\(rt_{fa}\)减去该点所在子树的贡献,子树内的则可以直接dfs求解

所以:rt[u]=h[u]+get_hash(rt[fa]-get_hash(h[u]));

Part 3

接着考虑,因为n很大,所以不能像第一部分一样暴力跑所有点,考虑树形dp

思路和第二部分很像,分为子树内和子树外考虑,记该点为u,儿子为v,节点i的答案为\(ans_i\)

先考虑子树外

显然,u少了子树v,所以要乘\(f_v^{-1}\),而且u的子树少了1,所以要乘\(cnt_u^{-1}\),于此同时,u子树中与v同构的子树会-1,所以要乘\(a_v\)

\(tmp=ans_u \times f_v^{-1} \times cnt_u^{-1} \times a_v\)

再考虑子树内

首先,u对v的贡献为tmp

其次,v的子树多了1个,所以乘\(cnt_v+1\)

第三,与u相同的子树多了一个,所以乘\((a_u+1)^{-1}\)

故有:

\[ans_v=tmp\times (cnt_v+1)\times (a_u+1)^{-1} \]

代码:

#include<bits/stdc++.h>
#define ull unsigned long long
#define ll long long
using namespace std;
const int p=998244353,N=1e5;
int T,head[100005],edgenum;
struct edge{
	int to,nxt;
}e[200005];
void add_edge(int  u,int v)
{
	e[++edgenum].nxt=head[u];
	e[edgenum].to=v;
	head[u]=edgenum;
}
ull h[100005],mask,rt[100005];
ll qpow(ll x,int y)
{
	ll res=1;
	while(y)
	{
		if(y&1) res=res*x%p;
		x=x*x%p;
		y>>=1;
	}
	return res;
}
ull get_hash(ull x)
{
	x^=mask;
	x^=x<<13;
	x^=x>>7;
	x^=x<<17;
	x^=mask;
	return x;
}
ll fac[100005],inv[100005],f[100005];
map<ull,int>mp[100005];
ll d[100005];
void dfs(int u,int fa)
{
	f[u]=1;
	h[u]=1;
	map<ull,bool> mp1;
	mp1.clear();
	for(int i=head[u];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(v==fa) continue;
		dfs(v,u);
		h[u]+=get_hash(h[v]);
		f[u]=f[u]*f[v]%p;
		mp[u][h[v]]+=1;
	}
	if(u!=1)
	f[u]=f[u]*fac[d[u]-1]%p;
	else
	f[u]=f[u]*fac[d[u]]%p;
	for(int i=head[u];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(v==fa||mp1[h[v]]) continue;
		f[u]=f[u]*inv[mp[u][h[v]]]%p;
		mp1[h[v]]=1;
	}
}
map<ull,bool> tr;
void dfs2(int u,int fa)
{
	for(int i=head[u];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(v==fa) continue;
		rt[v]=h[v]+get_hash(rt[u]-get_hash(h[v]));
		ll tmp=f[u]*mp[u][h[v]]%p*qpow(f[v],p-2)%p*qpow(d[u],p-2)%p;
		ull ha=rt[u]-get_hash(h[v]);
		mp[v][ha]++;
		f[v]=f[v]*tmp%p*d[v]%p*qpow(mp[v][ha],p-2)%p;
		dfs2(v,u);
	}
}
void init(int n)
{
	for(int i=1;i<=n;i++)
	{
		head[i]=rt[i]=h[i]=d[i]=f[i]=0;
		mp[i].clear();
	}
	tr.clear();
	edgenum=0;
}
int main()
{
	srand((unsigned)time(0));
	mask=1ll*rand()*rand();
	fac[0]=1;
	for(int i=1;i<=N;i++)
	{
		fac[i]=fac[i-1]*i%p;
	}
	inv[N]=qpow(fac[N],p-2);
	for(int i=N-1;i>=0;i--)
	{
		inv[i]=inv[i+1]*(i+1)%p;
	}
	scanf("%d",&T);
	while(T--)
	{
		int n;
		scanf("%d",&n);
		init(n);
		for(int i=1;i<n;i++)
		{
			int u,v;
			scanf("%d%d",&u,&v);
			d[v]++,d[u]++;
			add_edge(u,v);
			add_edge(v,u);
		}
		dfs(1,0);
		rt[1]=h[1];
		dfs2(1,0);
		ll ans=0;
		for(int i=1;i<=n;i++)
		{
		//	printf("%llu %lld\n",rt[i],f[i]);
			if(!tr[rt[i]])
			{
				ans=(ans+f[i])%p;
				tr[rt[i]]=1;
			}
		}
		printf("%lld\n",ans);
	}
	return 0;
}
posted @ 2024-03-31 21:02  wangsiqi2010916  阅读(83)  评论(0)    收藏  举报