题解:CF2071E LeaFall
分类讨论部分参考了现有的一篇题解(怎么想不到呢),自己推了式子。
一个点成为叶子的充要条件是它自己没有被删除,且它的邻居恰好有一个没有被删除。
对下文变量的一些声明:
- \(p_i\) 表示 \(i\) 脱落的概率,并令 \(q_i=\frac{1-p_i}{p_i}\)。
- \(a_u\) 表示 \(u\) 的所有邻居都脱落的概率,\(a_u=\Pi p_v\)。
- \(b_u\) 表示 \(u\) 恰有一个邻居未脱落,其他邻居都脱落的概率,\(b_u=\sum _v a_u q_v\)。
- \(d_u\) 表示 \(u\) 恰有一个邻居未脱落,其他邻居都脱落,且 \(u\) 本身未脱落的概率,\(d_u=(1-p_u)b_u\)。
- \(s_u\) 表示固定根节点时, \(u\) 及其直属儿子的 \(d\) 值之和,\(s_u=d_u + \sum_{v \in son(u)} d_v\)。
将点对分类,分别计算贡献:
对于 \(dis(u,v)=1\) :
此时两点互为邻居,枚举树上每一条边,贡献为 \(a_u a_v q_u q_v\)。这样,第一类贡献就为每条树边的贡献之和。复杂度 \(O(n)\)。
对于 \(dis(u,v)=2\):
此时 \(u\) 和 \(v\) 有一个公共邻居 \(z\),枚举 \(z\),对每个 \(z\) 统计它邻居两两贡献。
下式前半部分对应 \(z\) 未脱落的情况,此时 \(u,v\) 的其他邻居必须全部脱落,且 \(u,v\) 本身不脱落;后半部分对应 \(z\) 脱落的情况,\(u\) 和 \(v\) 都恰好剩下一个未脱落的邻居,且 \(u,v\) 本身不脱落。对于某个 \(z\) 的贡献:
这样子复杂度是 \(O(deg_z^2)\) 的,但是由 \(2 \times \sum_{i<j} x_ix_j = (\sum x_i)^2-\sum x_i^2\) 可以非常容易地优化为 \(O(deg_z)\)。我是另外开了一个函数算这种形式的式子。
最后把所有 \(f(z)\) 累加,就得到了第二类贡献。复杂度 \(O(n)\)。
对于 \(dis(u,v)>2\)
计算第三类贡献:
这样不太好处理,转化为对每个点求贡献,并转化为总贡献减去不合法(距离太小)的贡献,设 \(sum=\sum_{i=1}^{n} d_i\)。
考虑如何 \(O(deg_u)\) 获取里面的和式:
第一项即 \(u\) 父亲的父亲,第二项包括 \(u\) 的父亲以及 \(u\) 的父亲的儿子们(包括 \(u\)),第三项为 \(u\) 的儿子与孙子(即儿子的儿子)。
复杂度 \(O(n)\)。
上面的式子是写完代码之后重新推的,虽然检查过但是没准还是有错,欢迎在评论区捉虫。
#include<bits/stdc++.h>
#define Spc putchar(' ')
#define End putchar('\n')
#define For(i,il,ir) for(int i=(il);i<=(ir);++i)
#define Fr(i,il,ir) for(int i=(il);i<(ir);++i)
#define Forr(i,ir,il) for(int i=(ir);i>=(il);--i)
#define ForE(u) for(int i=head[u];~i;i=e[i].nxt)
#define fi first
#define se second
#define mk make_pair
#define pb emplace_back
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
namespace _TvT_{
template<typename T>
inline void rd(T& x){
bool f=0;x=0;char ch=getchar();
while(ch<'0'||ch>'9'){ if(ch=='-') f=1; ch=getchar(); }
while(ch>='0'&&ch<='9') x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f) x=-x;
}
template<typename T,typename... Args>
void rd(T& first,Args&... args){ rd(first),rd(args...); }
int write_num[50];
template<typename T>
inline void write(T x){
int len=0;
if(x<0) putchar('-'),x=-x;
do write_num[len++]=x%10; while(x/=10);
while(len--) putchar(write_num[len]+'0');
}
template<typename T,typename... Args>
void write(T first,Args... args){ write(first),Spc,write(args...); }
}using namespace _TvT_;
const int maxn=1e5+10;
const ll mod=998244353;
const ll inv2=499122177;
void qadd(ll &x,ll y){ x=(x+y>=mod)?(x+y-mod):(x+y); }
ll qp(ll x,ll b){ ll res=1ll; x%=mod; for(;b;x=x*x%mod,b>>=1) if(b&1) res=res*x%mod; return res; }
ll inv(ll x){ return qp(x,mod-2); }
int n;
vector<int> ve[maxn];
ll p[maxn],q[maxn];// p[i]:脱落概率,q[i]=(1-p[i])/p[i];
ll a[maxn],b[maxn],d[maxn],s[maxn];
//a[u]:且所有邻居都脱落的概率;b[u]:恰有一个邻居未脱落的概率。
//d[u]:自己未脱落,且恰好有一个邻居未脱落的概率;s[u]:u 及其直属儿子的 d 之和。
int fa[maxn];
void dfs(int u,int ff){
fa[u]=ff,s[u]=d[u],qadd(s[ff],d[u]);
for(int v:ve[u]) if(v^ff) dfs(v,u);
}
void prework(){
For(u,1,n){
a[u]=1ll,b[u]=0ll;
for(int v:ve[u]) a[u]=a[u]*p[v]%mod;
for(int v:ve[u]) qadd(b[u],a[u]*q[v]%mod);
d[u]=b[u]*(mod+1-p[u])%mod;
}
dfs(1,0);
}
ll solve1(){
ll res=0;
For(u,1,n) for(int v:ve[u]) if(u<v)
qadd(res,a[u]*a[v]%mod*q[u]%mod*q[v]%mod);
return res;
}
vector<ll> va;
ll calc(){
ll suma=0,sumb=0;
for(ll x:va) qadd(suma,x),qadd(sumb,x*x%mod);
ll res=(suma*suma%mod-sumb+mod)%mod*inv2%mod;
return res;
}
ll solve2(){
ll res=0;
For(z,1,n)
{
va.clear();
for(int v:ve[z]) va.pb(a[v]*(mod+1-p[v])%mod*inv(p[z])%mod);
qadd(res,calc()*(mod+1-p[z])%mod);
va.clear();
for(int v:ve[z]){
ll bef=(b[v]-a[v]*q[z]%mod+mod)%mod;
va.pb(bef*(mod+1-p[v])%mod);
}
qadd(res,calc()*inv(p[z])%mod);
}
return res;
}
ll solve3()
{
ll res=0,sum=0;
For(u,1,n) qadd(sum,d[u]);
For(u,1,n){
ll tmp=(s[fa[u]]+d[fa[fa[u]]])%mod;
for(int v:ve[u]) if(v^fa[u]) qadd(tmp,s[v]);
tmp=(sum-tmp+mod)%mod,qadd(res,tmp*d[u]%mod);
}
res=res*inv2%mod;
return res;
}
void solve()
{
rd(n);
For(i,1,n){
ll pp,qq;rd(pp,qq);
p[i]=pp*inv(qq)%mod;
q[i]=inv(p[i])*(mod+1-p[i])%mod;
}
For(i,2,n){
int u,v;rd(u,v);
ve[u].pb(v),ve[v].pb(u);
}
prework();
ll res=(solve1()+solve2()+solve3())%mod;
write(res),End;
}
void clear(){ For(i,0,n) ve[i].clear(),s[i]=0; }
signed main(){
int T;rd(T);while(T--) solve(),clear();
return 0;
}

浙公网安备 33010602011771号