最后防线 解题报告

简要题意

给定一棵 \(n\) 个节点的有根树,每个点有点权 \(a\)。求出一个访问顺序,使得所有点都在其祖先之后被访问,设第 \(i\) 个节点是第 \(p_i\) 个被访问到的,最小化 \(\sum \limits_{1\ le i \le n}p_ia_i\)

数据范围:\(n \le 2\times 10^4\)

分析

首先考虑没有访问顺序限制怎么做。那么就贪心地先访问点权最大的点即可。

那么加入访问顺序呢?我们考虑当前权值最大的点,那么我们一定要尽可能优先访问它,但是它再先也不能超过它父亲。总之就是这个点一定在它父亲下一个被访问,那么我们就可以考虑合并这两个点,然后就变成子问题。

那么我们考虑合并要合并什么:权值和个数。

那么怎么确定谁先谁后呢?考虑相邻两个点交换后对总代价的影响,这是 naive 的。

最后用一个支持删除、插入、访问最小元素的数据结构维护即可。

代码

#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
#define Inf (1ll<<60)
#define For(i,s,t) for(int i=s;i<=t;++i)
#define Down(i,s,t) for(int i=s;i>=t;--i)
#define ls (i<<1)
#define rs (i<<1|1)
#define bmod(x) ((x)>=p?(x)-p:(x))
#define lowbit(x) ((x)&(-(x)))
#define End {printf("NO\n");exit(0);}
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
inline void ckmx(int &x,int y){x=(x>y)?x:y;}
inline void ckmn(int &x,int y){x=(x<y)?x:y;}
inline void ckmx(ll &x,ll y){x=(x>y)?x:y;}
inline void ckmn(ll &x,ll y){x=(x<y)?x:y;}
inline int min(int x,int y){return x<y?x:y;}
inline int max(int x,int y){return x>y?x:y;}
inline ll min(ll x,ll y){return x<y?x:y;}
inline ll max(ll x,ll y){return x>y?x:y;}
char buf[1<<20],*p1,*p2;
#define gc() (p1 == p2 ? (p2 = buf + fread(p1 = buf, 1, 1 << 20, stdin), p1 == p2 ? EOF : *p1++) : *p1++)
#define read() ({\
    int x = 0, f = 1;\
    char c = gc();\
    while(c < '0' || c > '9') f = (c == '-') ? -1 : 1, c = gc();\
    while(c >= '0' && c <= '9') x = x * 10 + (c & 15), c = gc();\
    f * x;\
})
void write(int x){
    if(x>=10) write(x/10);
    putchar(x%10+'0');
}
const int N=2e5+100;
int n,rt,fa[N],nxt[N],f[N],a[N],tot;
int find(int x){return f[x]=(f[x]==x ? x : find(f[x]));}
ll ans;
bool vis[N];
struct Node{
    int u,ed;ll f,t;
    bool operator <(const Node x) const{
        return f*x.t>x.f*t || (f*x.t==t*x.f && u<x.u);
    }
    bool operator >(const Node x) const{
        return f*x.t<x.f*t || (f*x.t==t*x.f && u>x.u);
    }
}b[N];
set<Node> s;
int main()
{
#if !ONLINE_JUDGE
    freopen("line.in","r",stdin);
    freopen("line.out","w",stdout);
#endif 
    n=read(),rt=read();
    For(i,1,n) a[i]=read(),f[i]=i;
    int x,y;
    For(i,2,n) x=read(),y=read(),fa[y]=x;
    For(i,1,n) s.insert(b[i]=Node{i,i,a[i],1});
    vis[0]=true;
    while(!s.empty()){
        Node tp=*s.begin();s.erase(tp);
        int u=tp.u,p=find(fa[u]);
        if(vis[fa[u]]){
            ans+=(tot+1)*b[u].f;
            tot+=b[u].t;
            while(u) vis[u]=true,u=nxt[u];
            continue;
        }
        s.erase(b[p]);
        ans+=b[u].f*b[p].t;
        b[p].f+=b[u].f;
        b[p].t+=b[u].t;
        nxt[b[p].ed]=u;
        f[u]=b[p].u;
        b[p].ed=b[u].ed;
        s.insert(b[p]);
    }
    printf("%lld",ans);
    return 0;
}
posted @ 2025-10-16 16:32  XiaoZi_qwq  阅读(1)  评论(0)    收藏  举报