【洛谷P5405】氪金手游
题目
题目链接:https://www.luogu.com.cn/problem/P5405
小刘同学是一个喜欢氪金手游的男孩子。
他最近迷上了一个新游戏,游戏的内容就是不断地抽卡。现在已知:
- 卡池里总共有 \(N\) 种卡,第 \(i\) 种卡有一个权值 \(W_i\),小刘同学不知道 \(W_i\) 具体的值是什么。但是他通过和网友交流,他了解到 \(W_i\) 服从一个分布。
- 具体地,对每个 \(i\),小刘了解到三个参数 \(p_{i,1},p_{i,2},p_{i,3}\),\(W_i\) 将会以 \(p_{i,j}\) 的概率取值为 \(j\),保证 \(p_{i,1}+p_{i,2}+p_{i,3}=1\)。
小刘开始玩游戏了,他每次会氪一元钱来抽一张卡,其中抽到卡 \(i\) 的概率为:
小刘会不停地抽卡,直到他手里集齐了全部 \(N\) 种卡。
抽卡结束之后,服务器记录下来了小刘第一次得到每张卡的时间 \(T_i\)。游戏公司在这里设置了一个彩蛋:公司准备了 \(N-1\) 个二元组 \((u_i,v_i)\),如果对任意的 \(i\),成立 \(T_{u_i}<T_{v_i}\),那么游戏公司就会认为小刘是极其幸运的,从而送给他一个橱柜的手办作为幸运大奖。
游戏公司为了降低获奖概率,它准备的这些 \((u_i,v_i)\) 满足这样一个性质:对于任意的 \(\varnothing\ne S\subsetneq\{1,2,\ldots,N\}\),总能找到 \((u_i,v_i)\) 满足:\(u_i\in S,v_i\notin S\) 或者 \(u_i\notin S,v_i\in S\)。
请你求出小刘同学能够得到幸运大奖的概率,可以保证结果是一个有理数,请输出它对 \(998244353\) 取模的结果。
思路
题目中给出了关于所有二元组的性质。如果我们把二元组 \((u_i,v_i)\) 看作一条从 \(u_i\) 向 \(v_i\) 的有向边,那么其实等价于连边后会形成一棵树(无视边的方向)。
先考虑如果这棵树是一棵外向树怎么办。不妨设 \(1\) 为根,对于一个点 \(x\),它被抽到的时间必须小于它子树内的点被抽到的时间,它子树外的点与它无关。记 \(sum[x]\) 表示 \(x\) 子树内所有点的 \(w\) 之和,这种情况的概率为
后面那一坨东西等比数列求和一下,发现原式等于 \(\frac{w_x}{sum[x]}\)。
设 \(f[x][i]\) 表示点 \(x\) 的子树内,所有点的 \(w\) 之和为 \(i\) 的情况下,满足子树内所有条件的期望。
那么合并 \(x\) 和它的一棵子树 \(y\) 时只需要树上背包枚举到子树大小就可以做到 \(O(n^2)\) 了。
当这棵树不是一棵外向树时,考虑把反向边的贡献容斥掉。
那么合并 \(x\) 的一个子树 \(y\) 时,如果 \(x\) 与 \(y\) 的连边是 \(y\to x\) 的,如果算 \(y\) 子树的贡献,那么就需要加上 \(y\) 子树内 \(w\) 之和,容斥系数乘上 \(-1\);如果不算 \(y\) 子树的贡献,那么就不加 \(y\) 子树内 \(w\) 之和,容斥系数也不用乘。
预处理逆元可以做到 \(O(n^2)\)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1010,MOD=998244353;
int n,tot,a[N][4],head[N],siz[N];
ll ans,inv[N],f[N][N*3],g[N*3];
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
ll fpow(ll x,ll k)
{
ll ans=1;
for (;k;k>>=1,x=x*x%MOD)
if (k&1) ans=ans*x%MOD;
return ans;
}
void dfs(int x,int fa)
{
ll res=fpow(a[x][1]+a[x][2]+a[x][3],MOD-2);
f[x][1]=1LL*a[x][1]*res%MOD;
f[x][2]=2LL*a[x][2]*res%MOD;
f[x][3]=3LL*a[x][3]*res%MOD;
siz[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
dfs(v,x);
for (int j=1;j<=siz[x]*3;j++)
g[j]=f[x][j],f[x][j]=0;
for (int j=1;j<=siz[x]*3;j++)
for (int k=1;k<=siz[v]*3;k++)
{
f[x][j+k]=(f[x][j+k]+((i&1)?1LL:-1LL)*g[j]*f[v][k])%MOD;
if (!(i&1)) f[x][j]=(f[x][j]+1LL*g[j]*f[v][k])%MOD;
}
siz[x]+=siz[v];
}
}
for (int i=1;i<=siz[x]*3;i++)
f[x][i]=f[x][i]*inv[i]%MOD;
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d",&n);
for (int i=1;i<=n;i++)
scanf("%d%d%d",&a[i][1],&a[i][2],&a[i][3]);
for (int i=1;i<=n*3;i++) inv[i]=fpow(i,MOD-2);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
dfs(1,0);
for (int i=1;i<=n*3;i++)
ans=(ans+f[1][i])%MOD;
cout<<(ans+MOD)%MOD;
return 0;
}