树的合并 connect

[Noip模拟题]树的合并

Description

话说moreD经过不懈努力,终于背完了循环整数,也终于完成了他的蛋糕大餐。

但是不幸的是,moreD得到了诅咒,受到诅咒的原因至今无人知晓。

moreD在发觉自己得到诅咒之后,决定去寻找闻名遐迩的术士CD帮忙。

话说CD最近在搞OI,遇到了一道有趣的题目:

给定两棵树,则总共有NM种方案把这两棵树通过加一条边连成一棵树,那这NM棵树的直径

(树的直径指的是树上的最长简单路径)

大小之和是多少呢?

CD为了考验moreD是否值得自己费心力为他除去诅咒,于是要他编程回答这个问题,但是这m

oreD早就被诅咒搞晕了头脑,就只好请你帮助他了。

Input

第一行两个正整数N,M,分别表示两棵树的大小。
接下来N-1行,每行两个正整数ai,bi,表示第一棵树上的边。
接下来M-1行,每行两个正整数ci,di,表示第二棵树上的边。
N<=10^5,M<=10^5,1<=ai,bi<=N,1<=ci,di<=M

Sample Input

4 3
1 2
2 3
2 4
1 3
2 3

Sample Output

53

Solution

这道题主流写法是tree dp, 但我队测时没想到dp, 就口胡了一个算法, 居然A了

首先O(n)求出两颗树各自的直径, 并求出直径的两个端点

因为对于树上任何一点, 离它最远的点一定是直径的两个端点之一

因此用lca求出x与两个端点的距离, 取max得出x与离其最远点的距离len[x]

考虑新树的直径是要么是原先两颗树中较大的直径maxd, 要么是两点相连所形成的新路径tr[0].len[i]+tr[1].len[j]+1

接着将len从小到大排序

那么对于第一颗树中的第 i 个 len, 二分求出第二颗树中的第一个 j,使得tr[0].len[i]+tr[1].len[j]+1>=maxd

那么len[k] (1<=k<j)都满足tr[0].len[i]+tr[1].len[k]+1<maxd

此时新树的直径是原先两颗树中较大的直径maxd

因此贡献为(j-1)*maxd

那么len[k] (j<=k<=m)都满足tr[0].len[i]+tr[1].len[k]+1>=maxd

此时新树的直径是两点相连所形成的新路径tr[0].len[i]+tr[1].len[j]+1

因此贡献为(m-j+1)*tr[0].len[i]+tr[1].sum[m]-tr[1].sum[j-1]+(m-j+1), 其中sum为len的前缀和

所以第一颗树中第 i 个len的贡献为(j-1)*maxd+(m-j+1)*tr[0].len[i]+tr[1].sum[m]-tr[1].sum[j-1]+(m-j+1);

总时间复杂度O(n log n)

#include<bits/stdc++.h>
#define int long long
using namespace std;
int read(){
    int x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
const int N=1e5+28;
struct E{int to,nxt;};
struct Tree{
    int head[N],cnt,fa[N][20],dep[N],len[N],dl,dr,sum[N];
    E l[N<<1];
    void Ins(int x,int y){
        l[++cnt].nxt=head[x];
        l[cnt].to=y;
        head[x]=cnt;
    }
    void Dfs(int x,int f){
        fa[x][0]=f;
        dep[x]=dep[f]+1;
        for(int i=1;i<20;i++)fa[x][i]=fa[fa[x][i-1]][i-1];
        for(int i=head[x];i;i=l[i].nxt){
            int y=l[i].to;
            if(y==f)continue;
            Dfs(y,x);
        }
    }
    int Lca(int x,int y){
        if(dep[x]<dep[y])swap(x,y);
        for(int i=19;i>=0;i--){
            int f=fa[x][i];
            if(dep[f]>=dep[y])x=f;
        }
        if(x==y)return x;
        for(int i=19;i>=0;i--){
            int fx=fa[x][i],fy=fa[y][i];
            if(fx!=fy)x=fx,y=fy;
        }
        return fa[x][0];
    }
    int Dis(int x,int y){return dep[x]+dep[y]-2*dep[Lca(x,y)];}
    int apr[N],dis[N];
    queue<int>q;
    int Spfa(int s){
        memset(apr,0,sizeof(apr));
        memset(dis,0x3f,sizeof(dis));
        q.push(s);
        dis[s]=0;
        apr[s]=1;
        while(q.size()){
            int x=q.front();
            q.pop();
            apr[x]=0;
            for(int i=head[x];i;i=l[i].nxt){
                int y=l[i].to;
                if(dis[x]+1<dis[y]){
                    dis[y]=dis[x]+1;
                    if(!apr[y])q.push(y),apr[y]=1;
                }
            }
        }
    }
    int Calc(int n){
        Dfs(1,0);
        dl=dr=1;
        for(int i=1;i<=n;i++)if(dep[i]>dep[dl])dl=i;
        Spfa(dl);
        for(int i=1;i<=n;i++)if(dis[i]>dis[dr])dr=i;
        for(int i=1;i<=n;i++)len[i]=max(Dis(i,dl),Dis(i,dr));
        sort(len+1,len+n+1);
        for(int i=1;i<=n;i++)sum[i]=sum[i-1]+len[i];
        return dis[dr];
    }
    int Match(int n,int x){
        int l=1,r=n,re=n+1;
        while(l<=r){
            int mid=(l+r)>>1;
            if(len[mid]>=x)re=mid,r=mid-1;
            else l=mid+1;
        }
        return re;
    }
    Tree(){
        cnt=0;
        memset(head,0,sizeof(head));
        memset(dep,0,sizeof(dep));
        memset(fa,0,sizeof(fa));
        memset(len,0,sizeof(len));
        memset(sum,0,sizeof(sum));
        memset(l,0,sizeof(l));
    };
}tr[2];
int n,m,mxd;
signed main(){
//  freopen("connect.in","r",stdin);
//  freopen("connect.out","w",stdout);
    n=read(),m=read();
    for(int i=1,x,y;i<n;i++)tr[0].Ins(x=read(),y=read()),tr[0].Ins(y,x);
    for(int i=1,x,y;i<m;i++)tr[1].Ins(x=read(),y=read()),tr[1].Ins(y,x);
    mxd=max(tr[0].Calc(n),tr[1].Calc(m));
    int ans=0;
    for(int i=1;i<=n;i++){
        int p=tr[1].Match(m,mxd-tr[0].len[i]-1);
        ans+=(p-1)*mxd;
        ans+=(m-p+1)*tr[0].len[i]+tr[1].sum[m]-tr[1].sum[p-1]+(m-p+1);
    }
    printf("%lld",ans);
    return 0;
}
posted @ 2019-06-27 16:43 The_KOG 阅读(...) 评论(...) 编辑 收藏