题解:CF2071E LeaFall

分类讨论部分参考了现有的一篇题解(怎么想不到呢),自己推了式子。

一个点成为叶子的充要条件是它自己没有被删除,且它的邻居恰好有一个没有被删除。

对下文变量的一些声明:

  • \(p_i\) 表示 \(i\) 脱落的概率,并令 \(q_i=\frac{1-p_i}{p_i}\)
  • \(a_u\) 表示 \(u\) 的所有邻居都脱落的概率,\(a_u=\Pi p_v\)
  • \(b_u\) 表示 \(u\) 恰有一个邻居未脱落,其他邻居都脱落的概率,\(b_u=\sum _v a_u q_v\)
  • \(d_u\) 表示 \(u\) 恰有一个邻居未脱落,其他邻居都脱落,且 \(u\) 本身未脱落的概率,\(d_u=(1-p_u)b_u\)
  • \(s_u\) 表示固定根节点时, \(u\) 及其直属儿子的 \(d\) 值之和,\(s_u=d_u + \sum_{v \in son(u)} d_v\)

将点对分类,分别计算贡献:

对于 \(dis(u,v)=1\)

此时两点互为邻居,枚举树上每一条边,贡献为 \(a_u a_v q_u q_v\)。这样,第一类贡献就为每条树边的贡献之和。复杂度 \(O(n)\)

对于 \(dis(u,v)=2\)

此时 \(u\)\(v\) 有一个公共邻居 \(z\),枚举 \(z\),对每个 \(z\) 统计它邻居两两贡献。

下式前半部分对应 \(z\) 未脱落的情况,此时 \(u,v\) 的其他邻居必须全部脱落,且 \(u,v\) 本身不脱落;后半部分对应 \(z\) 脱落的情况,\(u\)\(v\) 都恰好剩下一个未脱落的邻居,且 \(u,v\) 本身不脱落。对于某个 \(z\) 的贡献:

\[f(z)=(1-p_z)\sum_{u,v} (1-p_u) \frac{a_u}{p_z} \times (1-p_v) \frac{a_v}{p_z}+ p_z\sum_{u,v} (1-p_u)\frac{b_u-a_uq_z}{p_z} (1-p_v) \frac{b_v-a_vq_z}{p_z} \]

这样子复杂度是 \(O(deg_z^2)\) 的,但是由 \(2 \times \sum_{i<j} x_ix_j = (\sum x_i)^2-\sum x_i^2\) 可以非常容易地优化为 \(O(deg_z)\)。我是另外开了一个函数算这种形式的式子。

最后把所有 \(f(z)\) 累加,就得到了第二类贡献。复杂度 \(O(n)\)

对于 \(dis(u,v)>2\)

计算第三类贡献:

\[ans=\sum_{dis(u,v)>2} d_ud_v \]

这样不太好处理,转化为对每个点求贡献,并转化为总贡献减去不合法(距离太小)的贡献,设 \(sum=\sum_{i=1}^{n} d_i\)

\[\begin{equation*} \begin{aligned} ans&=\frac{1}{2}\sum_{u=1}^n \sum_{dis(u,v)>2} d_u d_v \\ &=\frac{1}{2} \sum_{u=1}^n d_u (sum- \sum_{dis(u,v) \leq 2} d_v) \end{aligned} \end{equation*}\]

考虑如何 \(O(deg_u)\) 获取里面的和式:

\[\sum_{dis(u,v) \leq 2} b_v=b_{fa_{fa_u}} + s_{fa_u}+\sum_{v\in son(u)} s_v \]

第一项即 \(u\) 父亲的父亲,第二项包括 \(u\) 的父亲以及 \(u\) 的父亲的儿子们(包括 \(u\)),第三项为 \(u\) 的儿子与孙子(即儿子的儿子)。

复杂度 \(O(n)\)

上面的式子是写完代码之后重新推的,虽然检查过但是没准还是有错,欢迎在评论区捉虫。

#include<bits/stdc++.h>
#define Spc putchar(' ')
#define End putchar('\n')
#define For(i,il,ir) for(int i=(il);i<=(ir);++i)
#define Fr(i,il,ir) for(int i=(il);i<(ir);++i)
#define Forr(i,ir,il) for(int i=(ir);i>=(il);--i)
#define ForE(u) for(int i=head[u];~i;i=e[i].nxt)
#define fi first
#define se second
#define mk make_pair
#define pb emplace_back
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
namespace _TvT_{
    template<typename T>
    inline void rd(T& x){
        bool f=0;x=0;char ch=getchar();
        while(ch<'0'||ch>'9'){ if(ch=='-') f=1; ch=getchar(); }
        while(ch>='0'&&ch<='9') x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
        if(f) x=-x;
    }
    template<typename T,typename... Args>
    void rd(T& first,Args&... args){ rd(first),rd(args...); }
    int write_num[50];
    template<typename T>
    inline void write(T x){
        int len=0;
        if(x<0) putchar('-'),x=-x;
        do write_num[len++]=x%10; while(x/=10);
        while(len--) putchar(write_num[len]+'0');
    }
    template<typename T,typename... Args>
    void write(T first,Args... args){ write(first),Spc,write(args...); }
}using namespace _TvT_;
const int maxn=1e5+10;
const ll mod=998244353;
const ll inv2=499122177;
void qadd(ll &x,ll y){ x=(x+y>=mod)?(x+y-mod):(x+y); }

ll qp(ll x,ll b){ ll res=1ll; x%=mod; for(;b;x=x*x%mod,b>>=1) if(b&1) res=res*x%mod; return res; }
ll inv(ll x){ return qp(x,mod-2); }

int n;
vector<int> ve[maxn];

ll p[maxn],q[maxn];// p[i]:脱落概率,q[i]=(1-p[i])/p[i];
ll a[maxn],b[maxn],d[maxn],s[maxn];
//a[u]:且所有邻居都脱落的概率;b[u]:恰有一个邻居未脱落的概率。
//d[u]:自己未脱落,且恰好有一个邻居未脱落的概率;s[u]:u 及其直属儿子的 d 之和。
int fa[maxn];
void dfs(int u,int ff){
    fa[u]=ff,s[u]=d[u],qadd(s[ff],d[u]);
    for(int v:ve[u]) if(v^ff) dfs(v,u);
}
void prework(){
    For(u,1,n){
        a[u]=1ll,b[u]=0ll;
        for(int v:ve[u]) a[u]=a[u]*p[v]%mod;
        for(int v:ve[u]) qadd(b[u],a[u]*q[v]%mod);
        d[u]=b[u]*(mod+1-p[u])%mod;
    }
    dfs(1,0);
}
ll solve1(){
    ll res=0;
    For(u,1,n) for(int v:ve[u]) if(u<v)
        qadd(res,a[u]*a[v]%mod*q[u]%mod*q[v]%mod);
    return res;
}

vector<ll> va;
ll calc(){
    ll suma=0,sumb=0;
    for(ll x:va) qadd(suma,x),qadd(sumb,x*x%mod);
    ll res=(suma*suma%mod-sumb+mod)%mod*inv2%mod;
    return res;
}
ll solve2(){
    ll res=0;
    For(z,1,n)
    {
        va.clear();
        for(int v:ve[z]) va.pb(a[v]*(mod+1-p[v])%mod*inv(p[z])%mod);
        qadd(res,calc()*(mod+1-p[z])%mod);

        va.clear();
        for(int v:ve[z]){
            ll bef=(b[v]-a[v]*q[z]%mod+mod)%mod;
            va.pb(bef*(mod+1-p[v])%mod);
        }
        qadd(res,calc()*inv(p[z])%mod);
    }
    return res;
}
ll solve3()
{
    ll res=0,sum=0;
    For(u,1,n) qadd(sum,d[u]);
    For(u,1,n){
        ll tmp=(s[fa[u]]+d[fa[fa[u]]])%mod;
        for(int v:ve[u]) if(v^fa[u]) qadd(tmp,s[v]);
        tmp=(sum-tmp+mod)%mod,qadd(res,tmp*d[u]%mod);
    }
    res=res*inv2%mod;
    return res;
}
void solve()
{
    rd(n);
    For(i,1,n){
        ll pp,qq;rd(pp,qq);
        p[i]=pp*inv(qq)%mod;
        q[i]=inv(p[i])*(mod+1-p[i])%mod;
    }
    For(i,2,n){
        int u,v;rd(u,v);
        ve[u].pb(v),ve[v].pb(u);
    }
    prework();
    ll res=(solve1()+solve2()+solve3())%mod;
    write(res),End;
}
void clear(){ For(i,0,n) ve[i].clear(),s[i]=0; }
signed main(){
    int T;rd(T);while(T--) solve(),clear();
    return 0;
}
posted @ 2025-03-10 22:14  wanggk  阅读(46)  评论(0)    收藏  举报