【题解】abc401_f Add One Edge 3

abc401_f Add One Edge 3

简要题意

给定两棵树 \(T_1,T_2\),点数分别为 \(N_1,N_2\)

定义 \(f(i,j)\) 表示将 \(T_1\) 上的点 \(i\)\(T_2\) 上的点 \(j\) 连一条边后得到的新树的直径。

\(\displaystyle \sum_{i=1}^{N_1}\sum_{j=1}^{N_2} f(i,j)\)

\(1\le N_1,N_2\le 2\times 10^5\)

题解

知识点:树的直径,图论,树状数组

定义 \(d_{1,i}\) 表示 \(T_1\)\(i\) 为根后树的最大深度,\(d_{2,i}\) 同理,\(len_1\) 表示 \(T_1\) 的直径,\(len_2\) 同理,\(D=\max\{len_1,len_2\}\)

\(f(i,j)=\max\{D,d_{1,i}+d_{2,j}+1\}\)

显然当 \(d_{1,i}+d_{2,j}+1\le D\) 时,\(f(i,j)=D\)

所以 \(d_{1,i}+d_{2,j}+1\) 有贡献当且仅当 \(d_{1,i}+d_{2,j}+1>D\),即 \(d_{1,j}>D-1-d_{2,i}\)

\(d_{1}\)\(d_{2}\) 是好求的,直接 dp 即可,而 \(len_1,len_2\) 更不必说了。

先把 \(T_1\) 上的每一个点的 \(d_{1,i}\) 丢到值域树状数组上。

枚举 \(T_2\) 上的点 \(j\),查询树状数组中满足上述不等式的 \(d_{1,i}\) 之和,加到贡献中,还要查询不满足上述不等式的点的数量 \(c\),将 \(c\times D\) 加到贡献中,就做完了。

复杂度 \(O(n\log_2 n)\)

启发:

  • 两棵树合并后新的直径的求法。

  • 善用结构体封装。

忘记给树状数组的 \(n\) 赋初值了,导致了赛后过题的悲剧,下次一定要注意。

#include<bits/stdc++.h>
using namespace std;

#define rep(i,l,r) for(int i=(l);i<=(r);++i)
#define per(i,l,r) for(int i=(r);i>=(l);--i)
#define pr pair<int,int>
#define fi first
#define se second
#define pb push_back
#define all(x) (x).begin(),(x).end()
#define sz(x) (x).size()

#define N 202504
#define int long long

struct TREE{
    vector<int>e[N];
    int n,len;
    int dw1[N],dw2[N],up[N],mx[N];
    int d1[N],d2[N];

    inline void input(){
        cin>>n;
        rep(i,1,n-1) {
            int u,v;
            cin>>u>>v;
            e[u].pb(v);
            e[v].pb(u);
        }
    }

    inline void bfs(int st,int *d){
        queue<int>q;
        fill(d+1,d+n+1,-1);

        d[st]=0;
        q.push(st);

        while(!q.empty()){
            int u=q.front();
            q.pop();

            for(int v:e[u]){
                if(d[v]==-1){
                    d[v]=d[u]+1;
                    q.push(v);
                }
            }
        }
    }

    inline void get_len(){
        bfs(1,d1);
        int u=max_element(d1+1,d1+n+1)-d1;
        bfs(u,d2);
        int v=max_element(d2+1,d2+n+1)-d2;
        len=d2[v];
    }

    inline void dfs_dw(int u,int fa){

        dw1[u]=dw2[u]=0;
        for (int v:e[u]){
            if (v!=fa){
                dfs_dw(v,u);
                int d=dw1[v]+1;
                if(d>dw1[u]){
                    dw2[u]=dw1[u];
                    dw1[u]=d;
                }
                else if(d>dw2[u]){
                    dw2[u]=d;
                }
            }
        }
    }

    inline void dfs_up(int u,int fa){
        for(int v:e[u]){
            if (v!=fa){
                up[v]=up[u]+1;
                if(dw1[v]+1==dw1[u]){
                    up[v]=max(up[v],dw2[u]+1);
                }
                else{
                    up[v]=max(up[v],dw1[u]+1);
                }
                dfs_up(v,u);
            }
        }
        mx[u]=max(dw1[u],up[u]);
    }

    inline void calc(){
        dfs_dw(1,0);
        up[1]=0;
        dfs_up(1,0);
    }
}t1,t2;

struct BIT{
    int tr[N],lim;
    #define lb(x) (x&-x)

    inline void upd(int k,int d){
        while(k<=lim){
            tr[k]+=d;
            k+=lb(k);
        }
    }

    inline int ask(int k){
        if(k<0){
            return 0;
        }
        int ans=0;

        while(k){
            ans+=tr[k];
            k-=lb(k);
        }

        return ans;
    }
}t,cnt;

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);

    t1.input();
    t2.input();

    t1.get_len();
    t2.get_len();

    t1.calc();
    t2.calc();

    cnt.lim=t.lim=max(t1.n,t2.n);

    rep(i,1,t1.n){
        cnt.upd(t1.mx[i],1);
        t.upd(t1.mx[i],t1.mx[i]);
    }

    int ans=0,D=max(t1.len,t2.len);

    rep(i,1,t2.n){
        int c=cnt.ask(D-t2.mx[i]-1);

        ans=(ans+c*D);

        c=t1.n-c;
        int sum=t.ask(t.lim)-t.ask(D-t2.mx[i]-1);

        ans=(ans+c*(t2.mx[i]+1)+sum);
    }

    cout<<ans;

    return 0;
}
posted @ 2025-04-26 11:46  Lucyna_Kushinada  阅读(32)  评论(0)    收藏  举报