[LOJ3124][CTS2019|CTSC2019]氪金手游:树形DP+概率DP+容斥原理
分析
首先容易得出这样一个事实,在若干物品中最先被选出的是编号为\(i\)的物品的概率为\(\frac{W_i}{\sum_{j=1}^{cnt}W_j}\)。
假设树是一棵外向树,即父亲比儿子先选(一个点比它的子树中的所有其他的点先选),我们可以令\(f(i,j)\)表示以\(i\)为根的子树,子树内的总权值为\(j\),子树内的选取顺序合法的概率,转移类似树上分组背包。
那么我们现在需要考虑如何处理儿子比父亲先选的情况,其实可以直接容斥,减去父亲比儿子先选的概率就好了,注意这样的子树不要统计到\(f(i,j)\)的第二维中。
代码
#include <bits/stdc++.h>
#define rin(i,a,b) for(int i=(a);i<=(b);++i)
#define irin(i,a,b) for(int i=(a);i>=(b);--i)
#define trav(i,a) for(int i=head[a];i;i=e[i].nxt)
#define Size(a) (int) a.size()
#define pb push_back
#define mkpr std::make_pair
#define fi first
#define se second
#define lowbit(a) ((a)&(-(a)))
typedef long long LL;
using std::cerr;
using std::endl;
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
const int MAXN=1005;
const int MOD=998244353;
int n,ecnt,head[MAXN];
int p[MAXN][4],siz[MAXN];
int f[MAXN][MAXN*3],g[MAXN*3];
int inv[MAXN*3];
struct Edge{
int to,nxt;
}e[MAXN<<1];
inline void add_edge(int bg,int ed){
++ecnt;
e[ecnt].to=ed;
e[ecnt].nxt=head[bg];
head[bg]=ecnt;
}
inline int qpow(int x,int y){
int ret=1,tt=x%MOD;
while(y){
if(y&1)ret=1ll*ret*tt%MOD;
tt=1ll*tt*tt%MOD;
y>>=1;
}
return ret;
}
void dfs(int x,int pre){
f[x][0]=1;
trav(i,x){
int ver=e[i].to;
if(ver==pre)continue;
dfs(ver,x);
memset(g,0,sizeof g);
if(i&1){
irin(j,siz[x]*3,0)rin(k,1,siz[ver]*3)
g[j+k]=(g[j+k]+1ll*f[x][j]*f[ver][k])%MOD;
}
else{
int sum=0;
rin(j,1,siz[ver]*3)sum=(sum+f[ver][j])%MOD;
irin(j,siz[x]*3,0){
g[j]=(g[j]+1ll*f[x][j]*sum)%MOD;
rin(k,1,siz[ver]*3)g[j+k]=(g[j+k]-1ll*f[x][j]*f[ver][k]%MOD+MOD)%MOD;
}
}
memcpy(f[x],g,sizeof g);
siz[x]+=siz[ver];
}
memset(g,0,sizeof g);
rin(i,0,siz[x]*3)rin(j,1,3)
g[i+j]=(g[i+j]+1ll*f[x][i]*p[x][j]%MOD*j%MOD*inv[i+j])%MOD;
memcpy(f[x],g,sizeof g);
++siz[x];
}
void init(int n){
inv[1]=1;
rin(i,2,n)inv[i]=(-1ll*(MOD/i)*inv[MOD%i]%MOD+MOD)%MOD;
}
int main(){
n=read();init(n*3);
rin(i,1,n){
int sum=0;
rin(j,1,3)sum+=p[i][j]=read();
int invsum=qpow(sum,MOD-2);
rin(j,1,3)p[i][j]=1ll*p[i][j]*invsum%MOD;
}
rin(i,2,n){
int u=read(),v=read();
add_edge(u,v);
add_edge(v,u);
}
dfs(1,0);
int ans=0;
rin(i,1,n*3)ans=(ans+f[1][i])%MOD;
printf("%d\n",ans);
return 0;
}
posted on 2019-05-22 08:13 ErkkiErkko 阅读(206) 评论(0) 编辑 收藏 举报