NOI.AC#486. astrology

参加了这活动总得写点东西吧.jpg

考虑求出恰好选出了\(i\)条不相交路径后第\(i+1\)条路径与之前的重合时的贡献。

\(f_i\)为在树中选出了\(i\)条不相交路径的方案数,那么在选出了\(i\)条路径后第\(i+1\)条相交的方案数就是\(g_i=f_i\times \dbinom{n+1}{2}-f_{i+1}\times (i+1)\)\(g_i\)对答案的贡献就是\(\frac{g_i\times i!\times (i+1)}{\dbinom{n+1}{2}^{i+1}}\)(因为在计算\(g_i\)的时候我们已经确定了第\(i+1\)条边,所以这里我们只要确定\(i\)条边的顺序即可)

考虑如何计算\(f_i\).记\(f_{u,i,0}\)表示在子树\(u\)中有\(i\)条完整的路径的方案数,\(f_{u,i,1}\)表示有\(i\)条完整的路径的同时还有一条向上延伸的路径的方案数。为了方便转移再设辅助状态\(f_{u,i,2}\)表示在\(u\)的子树中有两条路径在根\(u\)处汇合成一条路径的同时有\(i\)条路径的方案数。具体的转移可以看code,注意处理路径是一个点的情况以及路径的开始与结束。

#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<bitset>
#include<math.h>
#include<stack>
#include<queue>
#include<set>
#include<map>
using namespace std;
typedef long long ll;
typedef long double db;
typedef pair<int,int> pii;
const int N=100000;
const db pi=acos(-1.0);
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define go(u,i) for (register int i=head[u];i;i=sq[i].nxt)
#define fir first
#define sec second
#define mp make_pair
#define pb push_back
#define maxd 998244353
#define eps 1e-8
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
    while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
    return x*f;
}

namespace My_Math{
	#define N 100000

	int fac[N+100],invfac[N+100];

	int add(int x,int y) {return x+y>=maxd?x+y-maxd:x+y;}
	int dec(int x,int y) {return x<y?x-y+maxd:x-y;}
	int mul(int x,int y) {return 1ll*x*y%maxd;}
	ll qpow(ll x,int y)
	{
		ll ans=1;
		while (y)
		{
			if (y&1) ans=mul(ans,x);
			x=mul(x,x);y>>=1;
		}
		return ans;
	}
	int inv(int x) {return qpow(x,maxd-2);}

	int C(int n,int m)
	{
		if ((n<m) || (n<0) || (m<0)) return 0;
		return mul(mul(fac[n],invfac[m]),invfac[n-m]);
	}

	int math_init()
	{
		fac[0]=invfac[0]=1;
		rep(i,1,N) fac[i]=mul(fac[i-1],i);
		invfac[N]=inv(fac[N]);
		per(i,N-1,1) invfac[i]=mul(invfac[i+1],i+1);
	}
	#undef N
}
using namespace My_Math;
struct node{int to,nxt;}sq[10010];
int all=0,head[5050];
int n,tmp[5050][3],g[5050][3],f[5050][5050][3],siz[5050];

void addedge(int u,int v)
{
	all++;sq[all].to=v;sq[all].nxt=head[u];head[u]=all;
}

void dfs(int u,int fu)
{
	go(u,i)
	{
		int v=sq[i].to;
		if (v!=fu) dfs(v,u);
	}
	rep(i,0,n+1) rep(j,0,2) tmp[i][j]=0;tmp[0][0]=1;
	go(u,i)
	{
		int v=sq[i].to;
		if (v==fu) continue;
		rep(j,0,siz[u]) rep(k,0,siz[v])
		{
			g[j+k][0]=add(g[j+k][0],mul(tmp[j][0],f[v][k][0]));
			g[j+k][1]=add(g[j+k][1],mul(tmp[j][1],f[v][k][0]));
			g[j+k][1]=add(g[j+k][1],mul(tmp[j][0],f[v][k][1]));
			g[j+k][2]=add(g[j+k][2],mul(tmp[j][2],f[v][k][0]));
			g[j+k+1][2]=add(g[j+k+1][2],mul(tmp[j][1],f[v][k][1]));
		}
		siz[u]+=siz[v];
		rep(j,0,siz[u]) rep(k,0,2) {tmp[j][k]=g[j][k];g[j][k]=0;}
	}
	rep(i,0,siz[u])
	{
		f[u][i][0]=add(f[u][i][0],add(tmp[i][0],tmp[i][2]));
		f[u][i+1][0]=add(f[u][i+1][0],add(tmp[i][0],tmp[i][1]));
		f[u][i][1]=add(f[u][i][1],add(tmp[i][0],tmp[i][1]));
	}
	siz[u]++;
}

int main()
{
	read();n=read();
	math_init();
	rep(i,1,n-1)
	{
		int u=read(),v=read();
		addedge(u,v);addedge(v,u);
	}
	dfs(1,0);
	ll ans=0,cn2=C(n+1,2),inv2=inv(cn2),sum=mul(inv2,inv2);
	rep(i,1,n)
	{
		int tmp=dec(mul(f[1][i][0],cn2),mul(f[1][i+1][0],i+1));
		ans=add(ans,mul(mul(tmp,i+1),mul(sum,fac[i])));
		sum=mul(sum,inv2);
	}
	printf("%d",ans);
	return 0;
}
posted @ 2020-02-18 23:43  EncodeTalker  阅读(101)  评论(0编辑  收藏  举报