P3642 [APIO2016] 烟花表演 解题报告

简要题意

给定一颗有根树,边有边权。你可以花费 \(1\) 的代价使任意一条边的边权减一或加一。询问使所有叶子到根的距离相等的最小代价。

分析

首先看上去就很 dp,于是考虑状态设计。设 \(f_{u,i}\) 表示使 \(u\) 子树内的所有叶子到 \(u\) 的距离为 \(i\) 的最小代价。

那么存在转移(其中 \(w_{u,v}\) 是边 \((u,v)\) 的边权,下文简写为 \(w\)):

\[f_{u,i}=\sum_{v \in son_u} \min_{j \le i} f_{v,j}+\lvert w_{u,v}-(i-j)\rvert \]

注意到 dp 转移形如若干个绝对值之和,所以 \(F_u(i)=f_{u,i}\) 是一个凸函数。接下来考虑分类讨论:

我们约定 \(L,R\) 分别为 \(F_v\) 中斜率为 \(0\) 的一段的左右端点。

  • \(i < L\):考虑从 \(j\)\(j-1\) 时函数值的增量,因为在 \((-\infty,L]\) 上斜率小于等于 \(-1\),因此 \(f_{v,i}\) 的增量不小于 \(1\);又因为 \(\lvert w-(i-j) \rvert\) 的增量至少为 \(-1\),因此这个变化不会变优,因此当 \(j=i\) 是函数值最小,有 \(F_u(i)=F_v(i)+w\)

  • \(i \ge L\):此时 \(j\) 可以取 \([L,R]\) 中的值,因此 \(F_v(j)\) 的部分一定最小,考虑怎么让绝对值函数的值尽可能地小。

    • \(j=i-w \in [L,R]\),那么此时绝对值函数的值可以取到 \(0\),因此有 \(F_u(i)=F_v(L)\)

    • \(j=i-w < L\),那么此时 \(\lvert j-(i-w)\rvert= j-i+w\),因此有 \(F_u(i)=F_v(L)+L-i+w\)

    • \(j=i-w > R\),那么此时 \(\lvert j-(i-w)\rvert=-j+i-w\),因此有 \(F_u(i)=F_v(L)+i-R-w\)

整理后有:

\[F_u(i)= \begin{cases} F_v(i)+w & i <L \\ F_v(L)+L-i+w & L \le i < L+w \\ F_v(L) & L+w \le i \le R+w \\ F_v(L)+i-R-w & i > R+w \end{cases} \]

然后就是体现到 Slope Trick 上(假设第一段函数为 \(y=kx+b\)):

  • \(i <L\) 时:等价于 \(b\)\(w\)

  • \(L \le i < L+w\) 时:等价于在 \(L\) 处加入一段斜率为 \(-1\) 的直线;

  • \(L+w \le i \le R+w\) 时:等价于将 \([L,R]\) 平移到 \([L+w,R+w]\)

  • \(i > R+w\) 时:等价于将 \(R+w\) 之后的斜率都改为 \(1\)

于是上述操作对拐点集合的影响为:去掉 \(L\) 及之后的拐点,并加入 \(L+w\)\(R+w\)

那么怎么找 \(L,R\) 呢?注意到每一次合并后,斜率为正的函数有且仅有一段(\(R+w\)),因此对于一个有 \(k\) 个儿子的点,它的 \(L,R\) 就是弹出 \(k-1\) 个拐点后剩下的那两个。

统计答案的时候我们注意到 \(F_1(0)\) 就是所有边权和,然后统计每一个拐点的贡献即可。

代码

#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
#define Inf (1ll<<60)
#define For(i,s,t) for(int i=s;i<=t;++i)
#define Down(i,s,t) for(int i=s;i>=t;--i)
#define bmod(x) ((x)>=p?(x)-p:(x))
#define lowbit(x) ((x)&(-(x)))
#define End {printf("NO\n");exit(0);}
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
inline void ckmx(int &x,int y){x=(x>y)?x:y;}
inline void ckmn(int &x,int y){x=(x<y)?x:y;}
inline void ckmx(ll &x,ll y){x=(x>y)?x:y;}
inline void ckmn(ll &x,ll y){x=(x<y)?x:y;}
inline int min(int x,int y){return x<y?x:y;}
inline int max(int x,int y){return x>y?x:y;}
inline ll min(ll x,ll y){return x<y?x:y;}
inline ll max(ll x,ll y){return x>y?x:y;}
inline int read(){
    register int x=0,f=1;
    char c=getchar();
    while(c<'0' || '9'<c) f=(c=='-')?-1:1,c=getchar();
    while('0'<=c && c<='9') x=(x<<1)+(x<<3)+c-'0',c=getchar();
    return x*f;
}
void write(int x){
    if(x>=10) write(x/10);
    putchar(x%10+'0');
}
const int N=6e5+100;
int n,m,ls[N],rs[N],dist[N],fa[N],d[N],rt[N],tot;
ll ans,val[N],len[N];
int merge(int x,int y){
    if(!x || !y) return x|y;
    if(val[x]<val[y]) swap(x,y);
    rs[x]=merge(rs[x],y);
    if(d[ls[x]]<d[rs[x]])
        swap(ls[x],rs[x]);
    dist[x]=dist[rs[x]]+1;
    return x;
}
void insert(int& p,ll x){
    ++tot,val[tot]=x;
    p=merge(p,tot);
}
void pop(int& p){
    p=merge(ls[p],rs[p]);
}
int main()
{
#if !ONLINE_JUDGE
    freopen("test.in","r",stdin);
    freopen("test.out","w",stdout);
#endif 
    dist[0]=-1;
    n=read(),m=read();
    For(i,2,n+m) fa[i]=read(),len[i]=read(),++d[fa[i]],ans+=len[i];
    Down(u,n+m,2){
        if(u<=n) while(d[u]>1) pop(rt[u]),--d[u];
        ll R=val[rt[u]];pop(rt[u]);
        ll L=val[rt[u]];pop(rt[u]);
        insert(rt[u],L+len[u]);
        insert(rt[u],R+len[u]);
        rt[fa[u]]=merge(rt[fa[u]],rt[u]);
    }
    while(d[1]) pop(rt[1]),--d[1];//这里不需要留 R
    while(rt[1]) ans-=val[rt[1]],pop(rt[1]);
    printf("%lld",ans);
    return 0;
}
posted @ 2025-09-18 19:21  XiaoZi_qwq  阅读(5)  评论(0)    收藏  举报