[LOJ3124][CTS2019|CTSC2019]氪金手游:树形DP+概率DP+容斥原理

分析

首先容易得出这样一个事实,在若干物品中最先被选出的是编号为\(i\)的物品的概率为\(\frac{W_i}{\sum_{j=1}^{cnt}W_j}\)

假设树是一棵外向树,即父亲比儿子先选(一个点比它的子树中的所有其他的点先选),我们可以令\(f(i,j)\)表示以\(i\)为根的子树,子树内的总权值为\(j\),子树内的选取顺序合法的概率,转移类似树上分组背包。

那么我们现在需要考虑如何处理儿子比父亲先选的情况,其实可以直接容斥,减去父亲比儿子先选的概率就好了,注意这样的子树不要统计到\(f(i,j)\)的第二维中。

代码

#include <bits/stdc++.h>

#define rin(i,a,b) for(int i=(a);i<=(b);++i)
#define irin(i,a,b) for(int i=(a);i>=(b);--i)
#define trav(i,a) for(int i=head[a];i;i=e[i].nxt)
#define Size(a) (int) a.size()
#define pb push_back
#define mkpr std::make_pair
#define fi first
#define se second
#define lowbit(a) ((a)&(-(a)))
typedef long long LL;

using std::cerr;
using std::endl;

inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

const int MAXN=1005;
const int MOD=998244353;

int n,ecnt,head[MAXN];
int p[MAXN][4],siz[MAXN];
int f[MAXN][MAXN*3],g[MAXN*3];
int inv[MAXN*3];

struct Edge{
	int to,nxt;
}e[MAXN<<1];

inline void add_edge(int bg,int ed){
	++ecnt;
	e[ecnt].to=ed;
	e[ecnt].nxt=head[bg];
	head[bg]=ecnt;
}

inline int qpow(int x,int y){
	int ret=1,tt=x%MOD;
	while(y){
		if(y&1)ret=1ll*ret*tt%MOD;
		tt=1ll*tt*tt%MOD;
		y>>=1;
	}
	return ret;
} 

void dfs(int x,int pre){
	f[x][0]=1; 
	trav(i,x){
		int ver=e[i].to;
		if(ver==pre)continue;
		dfs(ver,x);
		memset(g,0,sizeof g);
		if(i&1){
			irin(j,siz[x]*3,0)rin(k,1,siz[ver]*3)
				g[j+k]=(g[j+k]+1ll*f[x][j]*f[ver][k])%MOD;
		}
		else{
			int sum=0;
			rin(j,1,siz[ver]*3)sum=(sum+f[ver][j])%MOD;
			irin(j,siz[x]*3,0){
				g[j]=(g[j]+1ll*f[x][j]*sum)%MOD;
				rin(k,1,siz[ver]*3)g[j+k]=(g[j+k]-1ll*f[x][j]*f[ver][k]%MOD+MOD)%MOD;
			}
		}
		memcpy(f[x],g,sizeof g);
		siz[x]+=siz[ver];
	}
	memset(g,0,sizeof g);
	rin(i,0,siz[x]*3)rin(j,1,3)
		g[i+j]=(g[i+j]+1ll*f[x][i]*p[x][j]%MOD*j%MOD*inv[i+j])%MOD;
	memcpy(f[x],g,sizeof g);
	++siz[x];
}

void init(int n){
	inv[1]=1;
	rin(i,2,n)inv[i]=(-1ll*(MOD/i)*inv[MOD%i]%MOD+MOD)%MOD;
}

int main(){
	n=read();init(n*3);
	rin(i,1,n){
		int sum=0;
		rin(j,1,3)sum+=p[i][j]=read();
		int invsum=qpow(sum,MOD-2);
		rin(j,1,3)p[i][j]=1ll*p[i][j]*invsum%MOD;
	}
	rin(i,2,n){
		int u=read(),v=read();
		add_edge(u,v);
		add_edge(v,u);
	}
	dfs(1,0);
	int ans=0;
	rin(i,1,n*3)ans=(ans+f[1][i])%MOD;
	printf("%d\n",ans);
	return 0;
} 

posted on 2019-05-22 08:13  ErkkiErkko  阅读(90)  评论(0编辑  收藏

统计