BZOJ 3991: [SDOI2015]寻宝游戏

随便选一个点当做根,跑每个点的深度(为了求LCA)d [ u ] ,和到根节点的距离(为了更新答案) l [ u ]

我们发现,由关键点和他们的LCA构成的虚树(其实就是忽略其他节点),由于还要回到原点,所以相当于是树的所有边权的2倍

怎么求?对于每一次标记,将所有的标记了的点按时间戳排序,那么答案就是一、二两点之间的距离,二、三点之间距离。。。。最后一个点和第一个点之间的距离的总和;

而两点之间的距离等于 l [ u ] + l [ v ] - 2 * l [ lca ( u , v ) ]

维护这些点时可以用set

#include<cstdio>
#include<iostream>
#include<set>
#include<cmath>
#define getchar() *S++
#define ll long long 
#define R register int
const int N=100010,Inf=0x3f3f3f3f;
char RR[100000005],*S=RR;
using namespace std;
inline int g() {
    R ret=0; register char ch; while(!isdigit(ch=getchar())); 
    do ret=ret*10+(ch^48); while(isdigit(ch=getchar())); return ret;
}
int n,m,num,cnt;
int vr[N<<1],nxt[N<<1],w[N<<1],fir[N],dfn[N],rw[N],d[N],f[N][18];
set<int>s;
set<int>::iterator it;
ll ans,l[N];
bool vis[N];
inline void add(int u,int v,int ww) {vr[++cnt]=v,w[cnt]=ww,nxt[cnt]=fir[u],fir[u]=cnt;}
void dfs(int u) { dfn[u]=++num,rw[num]=u;
    for(R i=fir[u];i;i=nxt[i]) { R v=vr[i];
        if(d[v]) continue; d[v]=d[u]+1; l[v]=l[u]+w[i]; f[v][0]=u; R p=u;
        for(R j=0;f[p][j];++j) f[v][j+1]=f[p][j],p=f[p][j];
        dfs(v);
    }
}
inline int lca(int u,int v) {
    if(d[u]<d[v]) swap(u,v); R lim=log2(d[u])+1;
    for(R j=lim;j>=0;--j) if(d[f[u][j]]>=d[v]) u=f[u][j];
    if(u==v) return u;
    for(R j=lim;j>=0;--j) if(f[u][j]!=f[v][j]) u=f[u][j],v=f[v][j];
    return f[u][0];
}
inline ll dis(int u,int v) {return l[u]+l[v]-2*l[lca(u,v)];}
signed main() {
    fread(RR,sizeof(RR),1,stdin);
    n=g(),m=g();
    for(R i=1,u,v,w;i<n;++i) u=g(),v=g(),w=g(),add(u,v,w),add(v,u,w);
    d[1]=1;dfs(1); 
    s.insert(-Inf),s.insert(Inf);
    for(R i=1,x;i<=m;++i) { x=g(); register long long flg; 
        if(vis[x]) s.erase(dfn[x]),flg=-1; else s.insert(dfn[x]),flg=1; vis[x]^=1;
        it=s.upper_bound(dfn[x]); R r=*it,l=*(--it); if(l>=dfn[x]) l=*(--it);
        //cout<<l<<" "<<dfn[x]<<" "<<r<<endl;
        if(l!=-Inf) ans+=flg*dis(rw[l],x); if(r!=Inf) ans+=flg*dis(x,rw[r]);
        if(l!=-Inf&&r!=Inf) ans-=flg*dis(rw[l],rw[r]); 
        register long long tmp=(s.size()>3)?dis(rw[*++s.begin()],rw[*--(--s.end())]):0;
        printf("%lld\n",ans+tmp);
    }
}

2019.04.18

posted @ 2019-04-18 21:42  LuitaryiJack  阅读(89)  评论(0编辑  收藏  举报