我与 P2680运输计划 的故事

题目链接

这是一个很长的故事。

2021年9月,我因为填错答题卡未过初赛………………

收到成绩的第二天,我抱着沉痛的心理去上了最后一节OI课

我随后点开来这道题。

那时我图个好玩,写了一个超级暴力:

(正解是lca+二分,但是我直接上全源最短路,十分的玄学)

#include<cstdio>
#include<queue>
#include<algorithm>
#define N 500005
#define int long long
#define clear(arr,b) for(int i=1;i<N;i++)arr[i]=b
using namespace std;
int head[N],to[N],nxt[N],val[N],n,m,s,dis[N],tot;
bool vis[N];
void add(int u,int v,int w){
    to[++tot]=v;
    nxt[tot]=head[u];
    head[u]=tot;
    val[tot]=w;
}
void dijkstra(int s){
    priority_queue<pair<int ,int > >q;
    clear(dis,2147483647);
    clear(vis,0);
    dis[s]=0;
    q.push(make_pair(0,s));
    while(!q.empty()){
        int x=q.top().second;q.pop();
        if(vis[x])continue;
        vis[x]=1;
        for(int i=head[x];i;i=nxt[i]){
            int y=to[i],w=val[i];
            if(dis[y]>dis[x]+w){
                dis[y]=dis[x]+w;
                if(!vis[y])q.push(make_pair(-dis[y],y));
            }
        }
    }
}
signed main(){
    scanf("%lld%lld",&n,&m);
    for(int i=1;i<n;i++){
        int a,b,c;
        scanf("%lld%lld%lld",&a,&b,&c);
        add(a,b,c);
        add(b,a,c);
    }
    int ans=0x7ffff;
    for(int i=1;i<=m;i++){
        int a=0,b=0;scanf("%lld%lld",&a,&b);
        dijkstra(a);
        ans=min(dis[b],ans);
    }
    printf("%lld\n",ans);
    return 0;
}

最终我成功骗到了五分。

随后的日子我日益消沉,几乎没怎么碰OI。

这道题也就石沉大海,很久没有翻出来过了。

半年过去了,我因为疫情被迫在家上网课。

于是开始自学起了OI。

从tarjan缩点一直学到Splay。

没有停下来的意思,但是开学了,马上要中考了。

我还是沉迷OI,在学校模拟考也不算太差(是可以直升的分数),也不太在意。

但是中考还是来了,抱着一切希望,去拼了一把,但是落榜了。

去了一个不那么好的学校,阴差阳错的又点开了这道题。

于是我又开始写了。

#include<cstdio>
#include<algorithm>
#define N 600006
using namespace std;
int tot,n,m,head[N],to[N],nxt[N],val[N];
int dep[N],f[N][34],g[N][34],lg[N],dis[N],l,r,v[N];
struct node{
    int u,v,lca,mx,dis;
}a[N];
void add(int u,int v,int w){
    to[++tot]=v;
    nxt[tot]=head[u];
    head[u]=tot;
    val[tot]=w;
}
inline void csh(){for(int i=1;i<=n;i++)lg[i]=lg[i>>1]+1;}
void rep(int x,int fa){
    dep[x]=dep[fa]+1;
    f[x][0]=fa;
    for(int i=1;i<=lg[dep[x]];i++){
        f[x][i]=f[f[x][i-1]][i-1];
        g[x][i]=max(g[x][i],g[f[x][i-1]][i-1]);
    }
    for(int i=head[x];i;i=nxt[i]){
        int y=to[i];
        if(y==fa)continue;
        g[y][0]=val[i];
        v[y]=val[i];
        dis[y]=dis[x]+val[i];
        rep(y,x);
    }
}
pair<int,int> lca(int x,int y){
    if(dep[x]<dep[y])swap(x,y);
    int res=0;
    while(dep[x]!=dep[y]){
        int i=lg[dep[x]-dep[y]];
        res=max(res,g[x][i]);
        x=f[x][i];
    }
    if(x==y)return {res,x};
    for(int i=lg[dep[x]];i>=0;i--){
        if(f[x][i]!=f[y][i]){
            res=max(res,g[x][i]);
            res=max(res,g[y][i]);
            x=f[x][i];y=f[y][i];
        }
    }
    return {res,f[x][0]};
}
bool check(int mid){
    int mx=0;
    for(int i=1;i<=m;i++)mx=max(mx,a[i].dis-a[i].mx);
    if(mx>mid)return 0;
    else return 1;
}
signed main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++){
        int a,b,c;
        scanf("%d%d%d",&a,&b,&c);
        add(a,b,c);
        add(b,a,c);
        r+=c;
    }
    rep(1,0);
    for(int i=1;i<=m;i++){
        scanf("%d%d",&a[i].u,&a[i].v);
        pair<int,int> ans=lca(a[i].u,a[i].v);
        a[i].mx=ans.first;
        a[i].lca=ans.second;
        a[i].dis=dis[a[i].u]+dis[a[i].v]-2*dis[a[i].lca];
    }
    while(l<r){
        int mid=(l+r)/2;
        if(!check(mid))l=mid+1;
        else r=mid; 
    }
    printf("%d",l);
    return 0;
}

这个代码可以得45pts,但是我的lg甚至都没有初始化,如果我打的是2015年NOIP,那么我就可以白捡45分。

这不是在考场,我继续思考。

于是我把初始化加上了:

#include<cstdio>
#include<algorithm>
#define N 600006
using namespace std;
int tot,n,m,head[N],to[N],nxt[N],val[N];
int dep[N],f[N][34],g[N][34],lg[N],dis[N],l,r,v[N];
struct node{
    int u,v,lca,mx,dis;
}a[N];
void add(int u,int v,int w){
    to[++tot]=v;
    nxt[tot]=head[u];
    head[u]=tot;
    val[tot]=w;
}
inline void csh(){for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1;}
void rep(int x,int fa){
    dep[x]=dep[fa]+1;
    f[x][0]=fa;
    for(int i=1;i<=lg[dep[x]];i++){
        f[x][i]=f[f[x][i-1]][i-1];
        g[x][i]=max(g[x][i],g[f[x][i-1]][i-1]);
    }
    for(int i=head[x];i;i=nxt[i]){
        int y=to[i];
        if(y==fa)continue;
        g[y][0]=val[i];
        v[y]=val[i];
        dis[y]=dis[x]+val[i];
        rep(y,x);
    }
}
pair<int,int> lca(int x,int y){
    if(dep[x]<dep[y])swap(x,y);
    int res=0;
    while(dep[x]!=dep[y]){
        int i=lg[dep[x]-dep[y]];
        res=max(res,g[x][i]);
        x=f[x][i];
    }
    if(x==y)return {res,x};
    for(int i=lg[dep[x]];i>=0;i--){
        if(f[x][i]!=f[y][i]){
            res=max(res,g[x][i]);
            res=max(res,g[y][i]);
            x=f[x][i];y=f[y][i];
        }
    }
    return {res,f[x][0]};
}
bool check(int mid){
    int mx=0;
    for(int i=1;i<=m;i++)mx=max(mx,a[i].dis-a[i].mx);
    if(mx>mid)return 0;
    else return 1;
}
signed main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++){
        int a,b,c;
        scanf("%d%d%d",&a,&b,&c);
        add(a,b,c);
        add(b,a,c);
        r+=c;
    }
    csh();
    rep(1,0);
    for(int i=1;i<=m;i++){
        scanf("%d%d",&a[i].u,&a[i].v);
        pair<int,int> ans=lca(a[i].u,a[i].v);
        a[i].mx=ans.first;
        a[i].lca=ans.second;
        a[i].dis=dis[a[i].u]+dis[a[i].v]-2*dis[a[i].lca];
    }
    while(l<r){
        int mid=(l+r)/2;
        if(!check(mid))l=mid+1;
        else r=mid; 
    }
    printf("%d",l);
    return 0;
}

搞笑的是,只能得15分了。

于是我重构了check函数,但是还是只能得15分,我很愤怒。

#include<cstdio>
#include<algorithm>
#define N 600006
#define int long long
using namespace std;
int tot,n,m,head[N],to[N],nxt[N],val[N];
int dep[N],f[N][34],g[N][34],lg[N],dis[N],l,r;
struct node{
    int u,v,lca,mx,dis;
}a[N];
void add(int u,int v,int w){
    to[++tot]=v;
    nxt[tot]=head[u];
    head[u]=tot;
    val[tot]=w;
}
inline void csh(){for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1;}
void rep(int x,int fa){
    dep[x]=dep[fa]+1;
    f[x][0]=fa;
    for(int i=1;i<=lg[dep[x]];i++){
        f[x][i]=f[f[x][i-1]][i-1];
        g[x][i]=max(g[x][i],g[f[x][i-1]][i-1]);
    }
    for(int i=head[x];i;i=nxt[i]){
        int y=to[i];
        if(y==fa)continue;
        g[y][0]=val[i];
        dis[y]=dis[x]+val[i];
        rep(y,x);
    }
}
pair<int,int> lca(int x,int y){
    if(dep[x]<dep[y])swap(x,y);
    int res=0;
    while(dep[x]!=dep[y]){
        int i=lg[dep[x]-dep[y]];
        res=max(res,g[x][i]);
        x=f[x][i];
    }
    if(x==y)return {res,x};
    for(int i=lg[dep[x]];i>=0;i--){
        if(f[x][i]!=f[y][i]){
            res=max(res,g[x][i]);
            res=max(res,g[y][i]);
            x=f[x][i];y=f[y][i];
        }
    }
    return {res,f[x][0]};
}
bool check(int mid){
    int mx=0,sec=0,p=0;
    for(int i=1;i<=m;i++){
        if(a[i].dis>sec&&a[i].dis<mx)sec=mx;
        if(a[i].dis>mx)sec=mx,mx=a[i].dis,p=i;
        if(a[i].dis==mx&&a[i].mx>a[p].mx)sec=mx,mx=a[i].dis,p=i;
    }
    mx=max(sec,mx-a[p].mx);
    if(mx>mid)return 0;
    else return 1;
}
signed main(){
    scanf("%lld%lld",&n,&m);
    for(int i=1;i<n;i++){
        int a,b,c;
        scanf("%lld%lld%lld",&a,&b,&c);
        add(a,b,c);
        add(b,a,c);
        r+=c;
    }
    csh();
    rep(1,0);
    for(int i=1;i<=m;i++){
        scanf("%lld%lld",&a[i].u,&a[i].v);
        pair<int,int> ans=lca(a[i].u,a[i].v);
        a[i].mx=ans.first;
        a[i].lca=ans.second;
        a[i].dis=dis[a[i].u]+dis[a[i].v]-2*dis[a[i].lca];
        //printf("%d %d %d\n",a[i].mx,a[i].lca,a[i].dis);
    }
    while(l<r){
        int mid=(l+r)/2;
        if(!check(mid))l=mid+1;
        else r=mid; 
    }
    printf("%lld",l);
    return 0;
}

最后我仔细思考,先进行树上差分,再前缀和,统计出链被经过的次数,当一个点被经过的次数大于等于所在链总dis超过mid的链的个数时,说明这个mid还有救,再判断一下这个点到他父节点的距离就行了,如果这个距离减去之后答案小于等于mid,那么这个check就成立了。

bool check(int mid){
    int cnt=0,del=0;
    for(int i=1;i<=n;i++)tmp[i]=0;
    for(int i=1;i<=m;i++){
        if(a[i].dis>mid){
            tmp[a[i].u]++,tmp[a[i].v]++,tmp[a[i].lca]-=2;
            del=max(del,a[i].dis-mid);
            cnt++;
        }
    }
    if(!cnt)return 1;
    for(int i=n;i>=1;i--)tmp[f[dfn[i]][0]]+=tmp[dfn[i]];
    for(int i=2;i<=n;i++)if(tmp[i]==cnt&&dis[i]-dis[f[i][0]]>=del)return 1;
    return 0;
}

最终,check函数长这样,各位可以自学体会一下,实际上就是树上的种种操作。

最后我成功AC了本题

#include<cstdio>
#include<algorithm>
#define N 600006
#define int long long
using namespace std;
int tot,n,m,head[N],to[N],nxt[N],val[N];
int dep[N],f[N][34],g[N][34],lg[N],dis[N],l,r,dfn[N],cnt;
int tmp[N];
struct node{
    int u,v,lca,mx,dis;
}a[N];
void add(int u,int v,int w){
    to[++tot]=v;
    nxt[tot]=head[u];
    head[u]=tot;
    val[tot]=w;
}
inline void csh(){for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1;}
void rep(int x,int fa){
    dfn[++cnt]=x;
    dep[x]=dep[fa]+1;
    f[x][0]=fa;
    for(int i=1;i<=lg[dep[x]];i++){
        f[x][i]=f[f[x][i-1]][i-1];
        g[x][i]=max(g[x][i],g[f[x][i-1]][i-1]);
    }
    for(int i=head[x];i;i=nxt[i]){
        int y=to[i];
        if(y==fa)continue;
        g[y][0]=val[i];
        dis[y]=dis[x]+val[i];
        rep(y,x);
    }
}
pair<int,int> lca(int x,int y){
    if(dep[x]<dep[y])swap(x,y);
    int res=0;
    while(dep[x]!=dep[y]){
        int i=lg[dep[x]-dep[y]];
        res=max(res,g[x][i]);
        x=f[x][i];
    }
    if(x==y)return {res,x};
    for(int i=lg[dep[x]];i>=0;i--){
        if(f[x][i]!=f[y][i]){
            res=max(res,g[x][i]);
            res=max(res,g[y][i]);
            x=f[x][i];y=f[y][i];
        }
    }
    return {res,f[x][0]};
}
bool check(int mid){
    int cnt=0,del=0;
    for(int i=1;i<=n;i++)tmp[i]=0;
    for(int i=1;i<=m;i++){
        if(a[i].dis>mid){
            tmp[a[i].u]++,tmp[a[i].v]++,tmp[a[i].lca]-=2;
            del=max(del,a[i].dis-mid);
            cnt++;
        }
    }
    if(cnt==0)return 1;
    for(int i=n;i>=1;i--)tmp[f[dfn[i]][0]]+=tmp[dfn[i]];
    for(int i=2;i<=n;i++)if(tmp[i]==cnt&&dis[i]-dis[f[i][0]]>=del)return 1;
    return 0;
}
signed main(){
    scanf("%lld%lld",&n,&m);
    for(int i=1;i<n;i++){
        int a,b,c;
        scanf("%lld%lld%lld",&a,&b,&c);
        add(a,b,c);
        add(b,a,c);
        r+=c;
    }
    csh();
    rep(1,0);
    for(int i=1;i<=m;i++){
        scanf("%lld%lld",&a[i].u,&a[i].v);
        pair<int,int> ans=lca(a[i].u,a[i].v);
        a[i].mx=ans.first;
        a[i].lca=ans.second;
        a[i].dis=dis[a[i].u]+dis[a[i].v]-2*dis[a[i].lca];
        //printf("%d %d %d\n",a[i].mx,a[i].lca,a[i].dis);
    }
    while(l<r){
        int mid=(l+r)/2;
        if(!check(mid))l=mid+1;
        else r=mid; 
    }
    printf("%lld",l);
    return 0;
}

发现最大值不需要了,所以拿掉,顺便整理一下代码:

#include<cstdio>
#include<algorithm>
#define N 600006
#define int long long
using namespace std;
int tot,n,m,head[N],to[N],nxt[N],val[N];
int dep[N],f[N][34],lg[N],dis[N],l,r,dfn[N],cnt;
int tmp[N];
struct node{
    int u,v,lca,mx,dis;
}a[N];
void add(int u,int v,int w){
    to[++tot]=v;
    nxt[tot]=head[u];
    head[u]=tot;
    val[tot]=w;
}
inline void csh(){for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1;}
void rep(int x,int fa){
    dfn[++cnt]=x;
    dep[x]=dep[fa]+1;
    f[x][0]=fa;
    for(int i=1;i<=lg[dep[x]];i++)f[x][i]=f[f[x][i-1]][i-1];
    for(int i=head[x];i;i=nxt[i]){
        int y=to[i];
        if(y==fa)continue;
        dis[y]=dis[x]+val[i];
        rep(y,x);
    }
}
int lca(int x,int y){
    if(dep[x]<dep[y])swap(x,y);
    while(dep[x]!=dep[y]){
        int i=lg[dep[x]-dep[y]];
        x=f[x][i];
    }
    if(x==y)return x;
    for(int i=lg[dep[x]];i>=0;i--)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];  
    return f[x][0];
}
bool check(int mid){
    int cnt=0,del=0;
    for(int i=1;i<=n;i++)tmp[i]=0;
    for(int i=1;i<=m;i++){
        if(a[i].dis>mid){
            tmp[a[i].u]++,tmp[a[i].v]++,tmp[a[i].lca]-=2;
            del=max(del,a[i].dis-mid);
            cnt++;
        }
    }
    if(cnt==0)return 1;
    for(int i=n;i>=1;i--)tmp[f[dfn[i]][0]]+=tmp[dfn[i]];
    for(int i=2;i<=n;i++)if(tmp[i]==cnt&&dis[i]-dis[f[i][0]]>=del)return 1;
    return 0;
}
signed main(){
    scanf("%lld%lld",&n,&m);
    for(int i=1;i<n;i++){
        int a,b,c;
        scanf("%lld%lld%lld",&a,&b,&c);
        add(a,b,c);
        add(b,a,c);
        r+=c;
    }
    csh();
    rep(1,0);
    for(int i=1;i<=m;i++){
        scanf("%lld%lld",&a[i].u,&a[i].v);
        int ans=lca(a[i].u,a[i].v);
        a[i].lca=ans;
        a[i].dis=dis[a[i].u]+dis[a[i].v]-2*dis[ans];
    }
    while(l<r){
        int mid=(l+r)/2;
        if(!check(mid))l=mid+1;
        else r=mid; 
    }
    printf("%lld",l);
    return 0;
}

很好,作为SFLS前最惨OIer,希望天下没有OIer和我一样惨。

完结撒花吧,虽然挺难受的。

posted @ 2022-08-22 16:06  灵长同志  阅读(37)  评论(0)    收藏  举报