树的合并 connect

树的合并 connect

题目描述

 

话说moreD经过不懈努力,终于背完了循环整数,也终于完成了他的蛋糕大餐。

         但是不幸的是,moreD得到了诅咒,受到诅咒的原因至今无人知晓。

         moreD在发觉自己得到诅咒之后,决定去寻找闻名遐迩的术士CD帮忙。

         话说CD最近在搞OI,遇到了一道有趣的题目:

         给定两棵树,则总共有N*M种方案把这两棵树通过加一条边连成一棵树,那这N*M棵树的直径大小之和是多少呢?

        

         CD为了考验moreD是否值得自己费心力为他除去诅咒,于是要他编程回答这个问题,但是这moreD早就被诅咒搞晕了头脑,就只好请你帮助他了。

 

 

输入

 

第一行两个正整数N,M,分别表示两棵树的大小。
接下来N-1行,每行两个正整数ai,bi,表示第一棵树上的边。
接下来M-1行,每行两个正整数ci,di,表示第二棵树上的边。

 

 

输出

 

一行一个整数,表示答案。

 

 

样例输入

4 3
1 2
2 3
2 4
1 3
2 3

样例输出

53

提示

 

【数据范围】

对于20%的数据满足N<=300,M<=300

对于50%的数据满足N,M<=3000

对于100%的数据满足N<=10^5,M<=10^5,1<=ai,bi<=N,1<=ci,di<=M

 

【提示】

         树的直径指的是树上的最长简单路径。

 

 

来源

1019


挺好想的题

令a[i]表示i开头第一棵树上最长链

b[i]表示i开头第二棵树上最长链

Max为两棵树直径的max

对于a[i]和b[j]

贡献为max(Max,a[i]+b[j]+1)

拿个指针扫扫就行了

#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 100005
using namespace std;
int n,m,tot,head[maxn],f[maxn],g[maxn],dp[maxn],t1,t2;
long long a[maxn],b[maxn],sum[maxn],ans,ma;
struct node{
    int v,nex;
}e[maxn*2];
void lj(int t1,int t2){
    tot++;e[tot].v=t2;e[tot].nex=head[t1];head[t1]=tot;
}
void dfs1(int k,int fa){
    int Max=-1e9,max2=-1e9;
    for(int i=head[k];i;i=e[i].nex){
        if(e[i].v!=fa){
            dfs1(e[i].v,k);
            if(f[e[i].v]>Max){
                max2=max(max2,Max);
                Max=f[e[i].v];
            }
            else max2=max(max2,f[e[i].v]);
        }
    }
    if(Max==-1e9)f[k]=0,g[k]=-1e9;
    else {f[k]=Max+1;g[k]=max2+1;}
    //cout<<k<<' '<<f[k]<<' '<<g[k]<<endl;
}
void dfs2(int k,int fa){
    for(int i=head[k];i;i=e[i].nex){
        if(e[i].v!=fa){
            dp[e[i].v]=max(dp[e[i].v],dp[k]+1);
            if(f[e[i].v]==f[k]-1){
                dp[e[i].v]=max(dp[e[i].v],g[k]+1);
            }
            else dp[e[i].v]=max(dp[e[i].v],f[k]+1);
            dfs2(e[i].v,k);
        }
    }
}
void Q(){
    for(int i=1;i<=n;i++)head[i]=f[i]=g[i]=dp[i]=0;
    tot=0;
}
int main()
{
    cin>>n>>m;
    for(int i=1;i<n;i++){
        scanf("%d%d",&t1,&t2);
        lj(t1,t2);lj(t2,t1);
    }
    dfs1(1,0);dfs2(1,0);
    for(int i=1;i<=n;i++)a[i]=max(f[i],dp[i]);
    Q();
    for(int i=1;i<m;i++){
        scanf("%d%d",&t1,&t2);
        lj(t1,t2);lj(t2,t1);
    }
    dfs1(1,0);dfs2(1,0);
    for(int i=1;i<=m;i++){
        b[i]=max(f[i],dp[i]);
    }
    sort(a+1,a+n+1);sort(b+1,b+m+1);
    for(int i=1;i<=n;i++)sum[i]=sum[i-1]+a[i];
    ma=max(a[n],b[m]);
    int l=1;
    for(int i=m;i>=1;i--){
        while(b[i]+a[l]+1<ma&&l<=n)l++;
        long long num=n-l+1;
        long long tmp=b[i]*num;tmp+=sum[n]-sum[l-1];tmp+=num;
        tmp+=ma*(n-num);
        ans+=tmp;
    }
    cout<<ans<<endl;
    return 0;
}
 

 

posted @ 2018-08-30 19:39  liankewei123456  阅读(262)  评论(0编辑  收藏  举报