[LOJ3124][CTS2019|CTSC2019]氪金手游:树形DP+概率DP+容斥原理

分析

首先容易得出这样一个事实,在若干物品中最先被选出的是编号为\(i\)的物品的概率为\(\frac{W_i}{\sum_{j=1}^{cnt}W_j}\)

假设树是一棵外向树,即父亲比儿子先选(一个点比它的子树中的所有其他的点先选),我们可以令\(f(i,j)\)表示以\(i\)为根的子树,子树内的总权值为\(j\),子树内的选取顺序合法的概率,转移类似树上分组背包。

那么我们现在需要考虑如何处理儿子比父亲先选的情况,其实可以直接容斥,减去父亲比儿子先选的概率就好了,注意这样的子树不要统计到\(f(i,j)\)的第二维中。

代码

#include <bits/stdc++.h>

#define rin(i,a,b) for(int i=(a);i<=(b);++i)
#define irin(i,a,b) for(int i=(a);i>=(b);--i)
#define trav(i,a) for(int i=head[a];i;i=e[i].nxt)
#define Size(a) (int) a.size()
#define pb push_back
#define mkpr std::make_pair
#define fi first
#define se second
#define lowbit(a) ((a)&(-(a)))
typedef long long LL;

using std::cerr;
using std::endl;

inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

const int MAXN=1005;
const int MOD=998244353;

int n,ecnt,head[MAXN];
int p[MAXN][4],siz[MAXN];
int f[MAXN][MAXN*3],g[MAXN*3];
int inv[MAXN*3];

struct Edge{
    int to,nxt;
}e[MAXN<<1];

inline void add_edge(int bg,int ed){
    ++ecnt;
    e[ecnt].to=ed;
    e[ecnt].nxt=head[bg];
    head[bg]=ecnt;
}

inline int qpow(int x,int y){
    int ret=1,tt=x%MOD;
    while(y){
        if(y&1)ret=1ll*ret*tt%MOD;
        tt=1ll*tt*tt%MOD;
        y>>=1;
    }
    return ret;
} 

void dfs(int x,int pre){
    f[x][0]=1; 
    trav(i,x){
        int ver=e[i].to;
        if(ver==pre)continue;
        dfs(ver,x);
        memset(g,0,sizeof g);
        if(i&1){
            irin(j,siz[x]*3,0)rin(k,1,siz[ver]*3)
                g[j+k]=(g[j+k]+1ll*f[x][j]*f[ver][k])%MOD;
        }
        else{
            int sum=0;
            rin(j,1,siz[ver]*3)sum=(sum+f[ver][j])%MOD;
            irin(j,siz[x]*3,0){
                g[j]=(g[j]+1ll*f[x][j]*sum)%MOD;
                rin(k,1,siz[ver]*3)g[j+k]=(g[j+k]-1ll*f[x][j]*f[ver][k]%MOD+MOD)%MOD;
            }
        }
        memcpy(f[x],g,sizeof g);
        siz[x]+=siz[ver];
    }
    memset(g,0,sizeof g);
    rin(i,0,siz[x]*3)rin(j,1,3)
        g[i+j]=(g[i+j]+1ll*f[x][i]*p[x][j]%MOD*j%MOD*inv[i+j])%MOD;
    memcpy(f[x],g,sizeof g);
    ++siz[x];
}

void init(int n){
    inv[1]=1;
    rin(i,2,n)inv[i]=(-1ll*(MOD/i)*inv[MOD%i]%MOD+MOD)%MOD;
}

int main(){
    n=read();init(n*3);
    rin(i,1,n){
        int sum=0;
        rin(j,1,3)sum+=p[i][j]=read();
        int invsum=qpow(sum,MOD-2);
        rin(j,1,3)p[i][j]=1ll*p[i][j]*invsum%MOD;
    }
    rin(i,2,n){
        int u=read(),v=read();
        add_edge(u,v);
        add_edge(v,u);
    }
    dfs(1,0);
    int ans=0;
    rin(i,1,n*3)ans=(ans+f[1][i])%MOD;
    printf("%d\n",ans);
    return 0;
} 

posted on 2019-05-22 08:13 ErkkiErkko 阅读(...) 评论(...) 编辑 收藏

统计