[CTS2019]氪金手游

https://www.luogu.org/problemnew/show/P5405

题解

首先考虑一条链的情况。

\(O->O->O->O->O\)

比如说这样一条链。

每个元素应当是这个元素到结尾这条子链中第一个被抽到的,这个概率手算一下发现它是

\[\frac{p_i}{\sum_{j=i}^np_j} \]

所以答案其实就是所有i的积。

\(O->O<-O->O->O\)

如果出现这样的情况怎么办?

考虑如果把这条边断掉,对两条链分别算答案乘起来,这样会算出来的不合法的情况是第二个在第三个前面被抽到的情况,所以我们再减去把那条边再反过来的答案就好了。

对于多反边,可以想到容斥,答案为考虑0条边的-考虑奇数条边的+考虑偶数条边的(这里的考虑是指反向)。

根据十二省联考的经验,链上的情况是可以推广到树上的。

实现的时候不需要枚举子集,直接在树形\(dp\)中把反边的边权记录成\(-1\)就好了,

突然发现我考场好像读错题了。

代码

#include<bits/stdc++.h>
#define N 1009
using namespace std;
typedef long long ll;
const int mod=998244353;
ll dp[N][N*3],g[N*3],ni[N*3],a[N][3];
int n,size[N],head[N],tot;
inline ll rd(){
  ll x=0;char c=getchar();bool f=0;
  while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
  while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
  return f?-x:x;
}
inline ll power(ll x,ll y){
  ll ans=1;
  while(y){
    if(y&1)ans=ans*x%mod;
    x=x*x%mod;
    y>>=1;
  }
  return ans;
}
struct edge{
  int n,to,l;
}e[N<<1];
inline void add(int u,int v){
  e[++tot].n=head[u];e[tot].to=v;head[u]=tot;e[tot].l=1;
  e[++tot].n=head[v];e[tot].to=u;head[v]=tot;e[tot].l=mod-1;
}
inline void MOD(ll &x){x=x>=mod?x-mod:x;}
inline void prework(int n){
  for(int i=1;i<=n;++i)ni[i]=power(i,mod-2);
  return;
}
void dfs(int u,int fa){
  dp[u][0]=1;
  for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa){
    int v=e[i].to;
    dfs(v,u);
    memset(g,0,sizeof(g));
    for(int j=0;j<=size[u];++j)
      for(int k=0;k<=size[v];++k){
        MOD(g[j+k]+=dp[u][j]*dp[v][k]%mod*e[i].l%mod);
        if(e[i].l!=1)MOD(g[j]+=dp[u][j]*dp[v][k]%mod);
      }
    memcpy(dp[u],g,sizeof(dp[u]));
    size[u]+=size[v];
  }
  memset(g,0,sizeof(g));
  for(int i=0;i<=size[u];++i)
    for(int j=1;j<=3;++j){
      MOD(g[i+j]+=dp[u][i]*ni[i+j]%mod*j%mod*a[u][j-1]%mod);
    } 
  memcpy(dp[u],g,sizeof(dp[u]));
  size[u]+=3;
}
int main(){
  n=rd();
  for(int i=1;i<=n;++i){
    ll s=0;
    a[i][0]=rd();a[i][1]=rd();a[i][2]=rd();
    s=power(a[i][0]+a[i][1]+a[i][2],mod-2);
    a[i][0]=a[i][0]*s%mod;
    a[i][1]=a[i][1]*s%mod;
    a[i][2]=a[i][2]*s%mod;
  }
  int x,y;
  for(int i=1;i<n;++i){
    x=rd();
    y=rd();
    add(x,y);
  }
  prework(3*n);
  dfs(1,0);
  ll ans=0;
  for(int i=1;i<=n*3;++i)MOD(ans+=dp[1][i]);
  cout<<ans;
  return 0;
}
posted @ 2019-05-23 10:16  comld  阅读(249)  评论(0编辑  收藏  举报