题解:[ABC314F] A Certain Game
解题思路
很容易想到并查集可以维护集合合并,每一个集合用 vector 存储集合中的元素,并记录集合大小 \(siz\)。
对于每一次集合 \(x,y\) 的比赛,集合 \(x\) 中所有元素的获胜期望增加 \(\dfrac{siz_x}{siz_x+siz_y}\),集合 \(y\) 中所有元素的获胜期望增加 \(\dfrac{siz_y}{siz_x+siz_y}\)。
每一次合并集合 \(x\) 和 \(y\),暴力将 \(y\) 中所有元素加入 \(x\),更新 \(siz_x=siz_x+siz_y\)。
对于每一个人,他打多少次比赛就更新多少次答案,所以复杂度最劣情况为 \(O(n^2)\)。
考虑优化,发现可以给集合打 tag 更新,最后所有元素加上最终集合 tag。
每个集合维护一个 \(sum\)。合并集合 \(x\) 和 \(y\) 时,\(sum_x=sum_x+\dfrac{siz_x}{siz_x+siz_y}\),\(sum_y=sum_y+\dfrac{siz_x}{siz_x+siz_y}\)。
将 \(y\) 中答案合并到 \(x\) 中时,原本答案在最后应 \(+sum_y\),但直接合并最后就变成 \(+sum_x\),因此合并答案时应将答案 \(+sum_y-sum_x\)。
打了 tag 后,合并复杂度为 \(O(siz_y)\),总复杂度卡到最劣还是 \(O(n^2)\)。发现已经可以按秩合并(将 \(siz\) 小集合的合并到大的集合中),按秩合并总复杂度 \(O(n\log n)\)。
复杂度证明:对于每一个人,他所在的集合合并时若将这个人暴力合并,集合大小至少 \(\times2\),集合大小最大为 \(n\),每个人最多暴力合并 \(\log n\) 次。
代码
#include<map>
#include<set>
#include<cmath>
#include<ctime>
#include<queue>
#include<stack>
#include<cstdio>
#include<vector>
#include<string>
#include<bitset>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
using namespace std;
const int MAXN=2e5+10;
const int N=2e5;
const int INF=0x3f3f3f3f;
const long long LINF=0x3f3f3f3f3f3f3f3f;
const int mod=998244353;
int n;
int pre[MAXN],siz[MAXN];
int inv[MAXN];
int Ans[MAXN];
int find(int x){
if(pre[x]==x){
return x;
}
pre[x]=find(pre[x]);
return pre[x];
}
vector <int> vec[MAXN];
vector <int> ans[MAXN];
int sum[MAXN];
void join(int x,int y){
for(int i=0;i<siz[y];i++)//暴力将y中的元素合并到x中
{
vec[x].push_back(vec[y][i]);
ans[x].push_back((1ll*ans[y][i]+mod+sum[y]-sum[x])%mod);
}
siz[x]+=siz[y];
pre[y]=x;
}
void init(){
inv[1]=1;
for(int i=2;i<=n;i++)//预处理逆元
{
inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
}
for(int i=1;i<=n;i++)
{
vec[i].push_back(i);
ans[i].push_back(0);
pre[i]=i;
siz[i]=1;
}
}
signed main(){
scanf("%d",&n);
init();
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
x=find(x);
y=find(y);
if(siz[x]<siz[y]){//按秩合并优化
swap(x,y);
}
//更新sum
sum[x]+=1ll*siz[x]*inv[siz[x]+siz[y]]%mod;
sum[x]%=mod;
sum[y]+=1ll*siz[y]*inv[siz[x]+siz[y]]%mod;
sum[y]%=mod;
join(x,y);
}
int tot=find(1);//最终集合
for(int i=0;i<n;i++)
{
Ans[vec[tot][i]]=(ans[tot][i]+sum[tot])%mod;
}
for(int i=1;i<=n;i++)
{
printf("%d ",Ans[i]);
}
return 0;
}

浙公网安备 33010602011771号