Loading

Atcoder_cf17_final_j Tree MST

这是我的第一道黑题!

言归正传,题意是,给定一棵 \(n\) 个节点的树,现有有一张完全图,两点 \(x\),\(y\) 之间的边长为 \(w_x+w_y+dis_{x,y}\),其中 \(dis_{x,y}\) 表示 \(x\)\(y\) 在树上的距离,求完全图的最小生成树。

常规的求最小生成树的算法有 \(kruskal\)\(prim\)。但是这里这张图是完全图,两个算法都会超时的。

所以一个为这个问题量身定做的算法出现了!它就是 \(boruvka\)

算法流程

每轮为当前每个连通块找到与其最近的连通块,并连边,直到只有一个连通块。

正确性

最后的最小生成树上的每个点,显然都会保留它连出的最短的边。

否则断掉现在它连出的一条边,再连最短的边一定更优。

那么每轮过后,把一个连通块缩成一个点,按照上面的结论一直做下去就是对的。

轮数

每次连通块个数至少除以二,因为最坏情况下的连边就是1-2,3-4,5-6,... \(n-1\) -\(n\)

所以最多有 \(\log n\) 轮。

找最短边

做法挺巧妙的。

我们先算出每个点和不同连通块的点的最短边长,以及最近点在哪个块,然后合并到这个点的连通块上去。

\(sum_i\) 表示点 \(i\) 到根节点的树上距离,\(a_i\) 表示点 \(i\) 的权值。

考虑树形 \(dp\),先 \(dfs\) 一遍找出子树内最近的点(其实就是子树里 \(sum_v+a_v\) 最小的点,因为 \(u\)\(v\) 的边长是 \(a_u+a_v+sum_v-sum_u\))。

但是!这个点可能是在同一块里的,怎么办呢?

我们可以不只维护最近的点,另外再维护一个和当前最近点不是同一块的次近点。

这样的话,如果最近点在同一块里,次近点就是答案了。

第二遍 \(dfs\),我们求出子树外最近的点,子树外的点到 \(u\) 的边长是 \(a_u+a_v+sum_v+sum_u-2\times sum_{lca}\)。当 \(lca\) 固定的时候,这个 \(v\) 一定是上一次 \(dfs\) 里算出的 \(lca\) 的子树中, \(sum_v+a_v\) 最小的点、和最小点不在一个块的次小点。那么维护从根节点到 \(u\)\(a_v+sum_v-2\times sum_{lca}\) 最小的点,和不在同一块的次小点就好了。

代码:

#include<bits/stdc++.h>
#define int long long
#define mkp make_pair
#define fi first
#define se second
using namespace std;
const int N=2e5+10;
int n,cnt,res,a[N],f[N],s[N];
int idx,hd[N],to[N<<1],nxt[N<<1],len[N<<1];
pair<int,int>p,s1[N],s2[N],ans[N];
int find(int x)
{
    if(f[x]==x)return x;
    return f[x]=find(f[x]);
}
void add(int u,int v,int w)
{
    ++idx,to[idx]=v,nxt[idx]=hd[u],len[idx]=w,hd[u]=idx;
    return;
}
pair<int,int>cmx(pair<int,int>a,pair<int,int>b,pair<int,int>c,int f)
{
    p=mkp(1e18,-1);
    if(a.se!=f)p=min(p,a);
    if(b.se!=f)p=min(p,b);
    if(c.se!=f)p=min(p,c);
    return p;
}
void dfs1(int u,int fa)
{
    s1[u]=mkp(s[u]+a[u],f[u]);
    s2[u]=mkp(1e18,-1);
    for(int i=hd[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa)continue;
        s[v]=s[u]+len[i];
        dfs1(v,u);
        if(s1[u]>s1[v])s2[u]=cmx(s1[u],s2[u],s2[v],s1[v].se),s1[u]=s1[v];
        else s2[u]=cmx(s2[u],s1[v],s2[v],s1[u].se);
    }
    return;
}
void dfs2(int u,int fa)
{
    s1[u].fi-=2*s[u],s2[u].fi-=2*s[u];
    if(fa)if(s1[u]>s1[fa])s2[u]=cmx(s2[fa],s1[u],s2[u],s1[fa].se),s1[u]=s1[fa];
    else s2[u]=cmx(s2[u],s1[fa],s2[fa],s1[u].se);
    if(s1[u].se!=f[u])ans[f[u]]=min(ans[f[u]],mkp(s1[u].fi+s[u]+a[u],s1[u].se));
    else if(s2[u].se!=f[u])ans[f[u]]=min(ans[f[u]],mkp(s2[u].fi+s[u]+a[u],s2[u].se));
    for(int i=hd[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa)continue;
        dfs2(v,u);
    }
    return;
}
signed main()
{
    scanf("%lld",&n);
    for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
    for(int i=1,u,v,w;i<n;i++)scanf("%lld%lld%lld",&u,&v,&w),add(u,v,w),add(v,u,w);
    for(int i=1;i<=n;i++)f[i]=i;
    cnt=n;
    while(cnt>1)
    {
        dfs1(1,0);
        for(int i=1;i<=n;i++)ans[i]=mkp(1e18,-1);
        dfs2(1,0);
        for(int i=1;i<=n;i++)
        {
            if(f[i]==i&&ans[i].se>0&&find(i)!=find(ans[i].se))f[i]=ans[i].se,res+=ans[i].fi;
        }
        cnt=0;
        for(int i=1;i<=n;i++)cnt+=(find(i)==i);
    }
    printf("%lld",res);
    return 0;
}
posted @ 2024-12-25 09:03  AvisD  阅读(29)  评论(0)    收藏  举报