Loading

题解:[ABC314F] A Certain Game

解题思路

很容易想到并查集可以维护集合合并,每一个集合用 vector 存储集合中的元素,并记录集合大小 \(siz\)

对于每一次集合 \(x,y\) 的比赛,集合 \(x\) 中所有元素的获胜期望增加 \(\dfrac{siz_x}{siz_x+siz_y}\),集合 \(y\) 中所有元素的获胜期望增加 \(\dfrac{siz_y}{siz_x+siz_y}\)

每一次合并集合 \(x\)\(y\),暴力将 \(y\) 中所有元素加入 \(x\),更新 \(siz_x=siz_x+siz_y\)

对于每一个人,他打多少次比赛就更新多少次答案,所以复杂度最劣情况为 \(O(n^2)\)

考虑优化,发现可以给集合打 tag 更新,最后所有元素加上最终集合 tag。

每个集合维护一个 \(sum\)。合并集合 \(x\)\(y\) 时,\(sum_x=sum_x+\dfrac{siz_x}{siz_x+siz_y}\)\(sum_y=sum_y+\dfrac{siz_x}{siz_x+siz_y}\)

\(y\) 中答案合并到 \(x\) 中时,原本答案在最后应 \(+sum_y\),但直接合并最后就变成 \(+sum_x\),因此合并答案时应将答案 \(+sum_y-sum_x\)

打了 tag 后,合并复杂度为 \(O(siz_y)\),总复杂度卡到最劣还是 \(O(n^2)\)。发现已经可以按秩合并(将 \(siz\) 小集合的合并到大的集合中),按秩合并总复杂度 \(O(n\log n)\)

复杂度证明:对于每一个人,他所在的集合合并时若将这个人暴力合并,集合大小至少 \(\times2\),集合大小最大为 \(n\),每个人最多暴力合并 \(\log n\) 次。

代码

#include<map>
#include<set>
#include<cmath>
#include<ctime>
#include<queue>
#include<stack>
#include<cstdio>
#include<vector>
#include<string>
#include<bitset>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
using namespace std;
const int MAXN=2e5+10;
const int N=2e5;
const int INF=0x3f3f3f3f;
const long long LINF=0x3f3f3f3f3f3f3f3f;
const int mod=998244353;
int n;
int pre[MAXN],siz[MAXN];
int inv[MAXN];
int Ans[MAXN];
int find(int x){
    if(pre[x]==x){
        return x;
    }
    pre[x]=find(pre[x]);
    return pre[x];
}
vector <int> vec[MAXN];
vector <int> ans[MAXN];
int sum[MAXN];
void join(int x,int y){
    for(int i=0;i<siz[y];i++)//暴力将y中的元素合并到x中
    {
        vec[x].push_back(vec[y][i]);
        ans[x].push_back((1ll*ans[y][i]+mod+sum[y]-sum[x])%mod);
    }
    siz[x]+=siz[y];
    pre[y]=x;
}
void init(){
    inv[1]=1;
    for(int i=2;i<=n;i++)//预处理逆元
    {
        inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
    }
    for(int i=1;i<=n;i++)
    {
        vec[i].push_back(i);
        ans[i].push_back(0);
        pre[i]=i;
        siz[i]=1;
    }
}
signed main(){
    scanf("%d",&n);
    init();
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        x=find(x);
        y=find(y);
        if(siz[x]<siz[y]){//按秩合并优化
            swap(x,y);
        }
        //更新sum
        sum[x]+=1ll*siz[x]*inv[siz[x]+siz[y]]%mod;
        sum[x]%=mod;
        sum[y]+=1ll*siz[y]*inv[siz[x]+siz[y]]%mod;
        sum[y]%=mod;
        join(x,y);
    }
    int tot=find(1);//最终集合
    for(int i=0;i<n;i++)
    {
        Ans[vec[tot][i]]=(ans[tot][i]+sum[tot])%mod;
    }
    for(int i=1;i<=n;i++)
    {
        printf("%d ",Ans[i]);
    }
    return 0;
}
posted @ 2025-02-09 07:02  Mathew_Miao  阅读(10)  评论(0)    收藏  举报