2022 CCPC 广州 I-Infection

2022 CCPC 广州 I-Infection

Statement

有一棵 \(n\) 个点的树, 一开始恰好有一个点被感染, 其中点 \(i\) 作为初始感染点的概率为\(\frac{a_i}{\sum_{j=1}^n a_j}\); 此外, 点 \(i\) 被一个相邻的已感染点传染的概率为\(\frac{b_i}{c_i}\). 对每个 \(k=1,\dots,n\), 求最终恰好有 \(k\) 个被感染点的概率.

Solution

我们设\(f_{i,j}\)表示以\(i\)为根的子树中恰好选择(感染)\(j\)个点(包含根, 不包含关键节点, 即初始感染点)的概率, \(g_{i,j}\)表示以\(i\)为根的子树中恰好选择\(j\)个点(包含根, 包含关键节点)的概率.

对 \(u\) 的每个儿子 \(v\in\text{son}_u\), 依次将 \(v\) 的 \(f,g\) 合并到 \(u\) 上, 转移方程如下:

  • \(f_{u,k}'=\sum_{i+j=k}f_{u,i}\times f_{v,j}\).
  • \(g_{u,k}'=\sum_{i+j=k}(g_{u,i}\times f_{v,j}+f_{u,i}\times g_{v,j})\).
  • Initialization: \(f_{u,1}=\frac{b_u}{c_u},g_{u,1}=\frac{a_u}{\sum_{i=1}^na_i}\).
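  • Final: 合并完所有儿子后, 令\(f_{u,0}=1-\frac{b_u}{c_u}\) (表示 \(u\) 没有被传染).
  • Answer: 合并儿子 \(v\) 时, 若 \(u\) 未被感染, 则以 \(v\) 为最高点的感染连通块已经确定, 令 \(ans_k \mathrel{+}= g_{v,k}\times\left(1-\frac{b_u}{c_u}\right)\); 最后对根 \(r\) 令 \(ans_k \mathrel{+}= g_{r,k}\).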

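时间复杂度\(O(n^2)\): Simple Proof: 合并 \(u\) 与儿子 \(v\) 的代价为 \(O(|S_u|\cdot|S_v|)\) (\(S_u\) 为 \(u\) 已合并的部分), 相当于枚举一个来自 \(S_u\)、一个来自 \(v\) 子树的点对; 每个点对仅会在其LCA处被枚举一次, 总点对数为\(O(n^2)\), 因此时间复杂度为\(O(n^2)\).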

Code

# define Fast_IO std::ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
# include "functional"
# include "algorithm"
# include "iostream"
# include "cstdlib"
# include "cstring"
# include "cstdio"
# include "vector"
# include "bitset"
# include "cassert"
# include "random"
# include "queue"
# include "cmath"
# include "ctime"
# include "map"
# include "set"
# define ll long long
# define ld long double
# define rep1(i,a,b) for(ll i=(a);i<=(b);i++)
# define rep2(i,a,b) for(ll i=(b);i>=(a);i--)
# define pii pair<int,int>
# define pll pair<ll,ll>
# define ph push_back
# define pb pop_back
# define eb emplace_back
# define vi vector<int>
# define vll vector<ll>
# define vpi vector<pii >
# define vpll vector<pll >
# define ri(x) scanf("%d",&x)
# define rf(x) scanf("%f",&x)
# define rl(x) scanf("%lld",&x)
# define rd(x) scanf("%lf",&x)
# define rs(s) scanf("%s",s+1)
# define wi(x) printf("%d",x)
# define wl(x) printf("%lld",x)
# define ws(s) printf("%s",s+1)
# define all(v) v.begin(),v.end()
# define fi first
# define se second
# define repauto(Name,v) for(auto Name:v)
# define Endl "\n"
# define ENDL putchar('\n')
using namespace std;
// Greatest common divisor by Euclid's algorithm; GCD(x,0) == x, GCD(0,0) == 0.
template<class I> inline I GCD(I A,I B){return B?GCD(B,A%B):A;}
// Least common multiple. Divides before multiplying to reduce overflow risk.
// Returns 0 when the gcd is 0 (both arguments zero) instead of dividing by zero.
template<class I> inline I LCM(I A,I B){
	const I G=GCD(A,B);
	return G?A/G*B:I(0);
}
// Integer square root: returns floor(sqrt(N)) for N >= 1.
// The bound checks use division instead of squaring so they cannot overflow.
// N <= 0 returns 0 (previously N == 0 divided by zero; negatives have no real root).
template<class I> I Sqrt(I N){
	if(N<=0) return I(0);
	I sqrtN=static_cast<I>(std::sqrt(N));  // floating-point estimate, possibly off by one
	// Step down while sqrtN^2 > N  (equivalently sqrtN > N / sqrtN).
	while(sqrtN>0&&sqrtN>N/sqrtN) --sqrtN;
	// Step up while (sqrtN+1)^2 <= N.
	while(sqrtN+1<=N/(sqrtN+1)) ++sqrtN;
	return sqrtN;
}
// Modular exponentiation: returns X^Y mod Mod1 as a value in [0, Mod1).
// X may be negative or >= Mod1; it is reduced first. Mod1 must be positive.
// A non-positive exponent Y yields 1 % Mod1; previously a negative Y never terminated.
// Products are formed in __int128, so any modulus that fits in long long is safe.
// NOTE: the default modulus is 998244353; callers working modulo 1e9+7 must pass it.
long long Pow(long long X,long long Y,__int128 Mod1=998244353){
	__int128 Base=X%Mod1;      // reduce up front so the loop only sees [0, Mod1)
	if(Base<0) Base+=Mod1;
	__int128 Ans=1%Mod1;       // "1 % Mod1" makes Mod1 == 1 return 0
	for(;Y>0;Y>>=1,Base=Base*Base%Mod1)
		if(Y&1) Ans=Ans*Base%Mod1;
	return (long long)Ans;
}
// Random engine seeded from the clock; nothing in this file draws from it.
mt19937_64 e(time(nullptr));

constexpr int maxm = 2010;       // per-vertex array capacity (assumes n <= 2000 — TODO confirm against limits)
constexpr int Mod = 1000000007;  // every probability below is kept modulo this prime

int N;                           // number of vertices
long long A[maxm];               // a_i: weight of vertex i being the initial infection
long long B[maxm];               // b_i: numerator of vertex i's spread probability
long long C[maxm];               // c_i: denominator of vertex i's spread probability
long long P1[maxm];              // b_i / c_i (mod Mod)
long long P2[maxm];              // a_i / sum(a) (mod Mod)
long long Sum;                   // sum of all a_i
long long Root;                  // declared but not referenced anywhere in this file
int DP1[maxm][maxm];             // DP1[u][k]: k infected in u's subtree, initial vertex outside it
int DP2[maxm][maxm];             // DP2[u][k]: k infected in u's subtree, initial vertex inside it
int Size[maxm];                  // subtree sizes accumulated by DFS
int Tem1[maxm];                  // scratch row used while merging DP1
int Tem2[maxm];                  // scratch row used while merging DP2
int Ans[maxm];                   // Ans[k]: probability that exactly k vertices are infected
vector<int> Edge[maxm];          // adjacency lists of the tree

/*
 * Tree-knapsack DP over the subtree of Now (Fa is its parent, 0 for the root).
 *
 * On return, for 1 <= k <= Size[Now]:
 *   DP1[Now][k] - weight of "Now is infected, exactly k vertices of this subtree
 *                 are infected, the initial vertex is NOT in this subtree".
 *                 Each infected vertex contributes P1, and each uninfected child
 *                 of an infected vertex contributes (1 - P1).
 *   DP2[Now][k] - same, but the initial vertex (contributing P2 instead of P1)
 *                 lies inside this subtree.
 *   DP1[Now][0] = 1 - P1[Now]: Now is not infected at all.
 * Side effect: for every child To, adds DP2[To][k] * (1 - P1[Now]) to Ans[k].
 * That term is a component whose topmost vertex is To, closed off because Now
 * refused the infection.
 */
void DFS(int Now,int Fa=0){
	DP1[Now][1]=P1[Now];   // Now infected from outside
	DP2[Now][1]=P2[Now];   // Now is the initial vertex
	++Size[Now];
	for(const int To:Edge[Now]){
		if(To==Fa) continue;
		DFS(To,Now);
		const int CurSize=Size[Now];
		const int ChildSize=Size[To];
		const long long Refuse=1-P1[Now]+Mod;  // (1 - p_Now) mod Mod, left unreduced
		// Components rooted at To are complete when Now stays healthy.
		for(int k=1;k<=ChildSize;++k)
			Ans[k]=(Ans[k]+(long long)DP2[To][k]*Refuse)%Mod;
		// Knapsack-merge the child's rows into Now's rows via the scratch buffers.
		const int Merged=CurSize+ChildSize;
		fill(Tem1,Tem1+Merged+1,0);
		fill(Tem2,Tem2+Merged+1,0);
		for(int x=0;x<=CurSize;++x){
			const long long NoneHere=DP1[Now][x];
			const long long OneHere=DP2[Now][x];
			for(int y=0;y<=ChildSize;++y){
				Tem1[x+y]=(Tem1[x+y]+NoneHere*DP1[To][y])%Mod;
				Tem2[x+y]=(Tem2[x+y]+NoneHere*DP2[To][y])%Mod;
				Tem2[x+y]=(Tem2[x+y]+OneHere*DP1[To][y])%Mod;
			}
		}
		copy(Tem1,Tem1+Merged+1,DP1[Now]);
		copy(Tem2,Tem2+Merged+1,DP2[Now]);
		Size[Now]+=ChildSize;
	}
	// Only now may the parent choose "Now not infected".
	DP1[Now][0]=(1-P1[Now]+Mod)%Mod;
}

int main(){
# ifdef LH_Frank
    freopen("1.in","r",stdin);
	freopen("1.out","w",stdout);
# endif
	// Input: N, then N-1 undirected edges, then a_i b_i c_i for every vertex.
	if(scanf("%d",&N)!=1||N<1) return 0;   // nothing to answer
	for(int i=1;i<N;++i){
		int U,V;
		scanf("%d%d",&U,&V);
		Edge[U].push_back(V);
		Edge[V].push_back(U);
	}
	for(int i=1;i<=N;++i){
		scanf("%lld%lld%lld",&A[i],&B[i],&C[i]);
		Sum+=A[i];
		// p_i = b_i / c_i (mod Mod); reduce first so the product cannot overflow.
		P1[i]=(B[i]%Mod)*Pow(C[i],Mod-2,Mod)%Mod;
	}
	// 1 / sum(a) is loop-invariant: compute the modular inverse once, not N times.
	// Assumes sum(a) is not a multiple of Mod (otherwise no inverse exists).
	const long long InvSum=Pow(Sum%Mod,Mod-2,Mod);
	for(int i=1;i<=N;++i) P2[i]=(A[i]%Mod)*InvSum%Mod;
	DFS(1);
	// Components whose topmost vertex is the root have no parent that could refuse.
	for(int i=1;i<=N;++i) Ans[i]=(Ans[i]+DP2[1][i])%Mod;
	for(int i=1;i<=N;++i) printf("%d\n",Ans[i]);
	return 0;
}
posted @ 2022-11-14 17:32  FJ-Frank  阅读(363)  评论(0)    收藏  举报