第十六届浙江大学宁波理工学院程序设计大赛 E 雷顿女士与平衡树(并查集)

题意

链接:https://ac.nowcoder.com/acm/contest/2995/E
来源:牛客网

卡特莉正在爬树,此时她又在树梢发现了一个谜题,为了不令她分心以至于发生意外,请你帮她解决这个问题。
具体地来说,我们定义树上从u到v简单路径上所有点权中最大值与最小值的差值为这条路径的"平衡值",记为balance(u,v)。

 

 

 

思路

首先,把这个式子拆成两部分,一部分计算最大值的和,另一部分计算最小值的和。

如何计算最大值的和?

将点按点权从小到大排序,从小到大遍历每个点u,记u为访问过,找u连的点v,如果v访问过了,那么说明v肯定比u小,所以u的权值在这两个集合作为最大值,乘上sz[fu]和sz[fv]即为u对答案的贡献。可能会有疑问为啥要乘上sz[fu],乘上sz[fv]不就是u对v所在集合的贡献吗?实则不然,比如看样例:

1
10
9 9 6 2 4 5 8 5 5 6
2 1
3 1
4 3
5 3
6 4
7 2
8 4
9 5
10 3

 

 蓝色表示每个点按顺序更新后所在集合的sz,当遍历到u=2的时候,发现连的点7和1都访问过,说明权值都是小于等于9的,所以将2和7合并后,u所在集合的sz变成了2,那么2再和1合并时,不仅2可以作为1所在集合的点的最大值,而且7这个点是可以陪着2一起连向1所在集合的,因为最大值依然是2这个点。

对于最小值的和的求法是类似的,按权值从大到小排序……

代码

#include<bits/stdc++.h>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
const int N=5e5+5;
const int mod=1e9+7;
const double eps=1e-8;
const double PI = acos(-1.0);
#define lowbit(x) (x&(-x))
int read()
{
    int x=0;
    char ch=getchar();
    while(!isdigit(ch))
    {
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=x*10+(ch-'0');
        ch=getchar();
    }
    return x;
}
struct node
{
    ll x,w;
}a[N];
bool cmp1(node a,node b)
{
    return a.w<b.w;
}
bool cmp2(node a,node b)
{
    return a.w>b.w;
}
vector<int> g[N];
ll pre[N],mx[N],mn[N],sz[N],vis[N];
ll find(ll x)
{
    if(x==pre[x])
        return x;
    return pre[x]=find(pre[x]);
}
int main()
{
    std::ios::sync_with_stdio(false);
    int t=read();
    while(t--)
    {
        int n=read();
        for(int i=1; i<=n; i++)
        {
            g[i].clear();
            a[i].w=read();
            a[i].x=i;
            mx[i]=mn[i]=a[i].w;
        }
        for(int i=1; i<n; i++)
        {
            int u=read(),v=read();
            g[u].push_back(v);
            g[v].push_back(u);
        }
        sort(a+1,a+n+1,cmp1);
        for(int i=1; i<=n; i++)
            pre[i]=i,sz[i]=1,vis[i]=0;
        ll ans1=0;
        for(int i=1; i<=n; i++)
        {
            ll u=a[i].x,fu=find(u);
            vis[u]=1;
            for(int j:g[u])
            {
                if(!vis[j]) continue;
                int fj=find(j);
                pre[fj]=fu;
                ans1=(ans1+sz[fu]*sz[fj]%mod*a[i].w%mod)%mod;
                (sz[fu]+=sz[fj])%=mod;
            }
            //     cout<<u<<" "<<sz[fu]<<endl;
        }
        // cout<<ans1<<endl;
        sort(a+1,a+1+n,cmp2);
        for(int i=1; i<=n; i++)
            pre[i]=i,sz[i]=1,vis[i]=0;
        ll ans2=0;
        for(int i=1; i<=n; i++)
        {
            int u=a[i].x,fu=find(u);
            vis[u]=1;
            for(int j:g[u])
            {
                if(!vis[j]) continue;
                int fj=find(j);
                pre[fj]=fu;
                ans2=(ans2+sz[fu]*sz[fj]%mod*a[i].w%mod)%mod;
                (sz[fu]+=sz[fj])%=mod;
            }
            //     cout<<u<<" "<<sz[fu]<<endl;
        }
  //      cout<<ans2<<endl;
        cout<<(ans1%mod-ans2%mod+mod)%mod<<endl;
    }
    return 0;
}

  

posted @ 2019-12-08 20:30  MCQ1999  阅读(250)  评论(0编辑  收藏  举报