寻宝游戏题解
寻宝游戏题解
这道题可以用虚树或者倍增+lca来求,这里使用的时倍增+lca。(树剖+lca个人打的超时了,不确定是因为什么)
首先我们需要知道我们要求的其实是一个联通图的路径和的二倍,可以画图理解一下。
而这张图其实是一棵树的联通子树,所以我们要求的其实就是这棵子树的路径和的二倍。
而这个路径和二倍,其实就是所有有宝藏的点组成一个序列\(A[1...K]\)(\(K\)为宝藏点的数量),将所有的\(dist(a_1,a_2)+dist(a_2,a_3)+...+dist(a_k,a_1)\)注意最后一个\(dist\)是\(dist(a_k,a_1)\),你可以把它看作一个环状的序列,但是编号怎么处理,直接使用dfs序即可,而我们加入点\(x\)时就找它dfs序的前驱和后继,如果这个点时边界点,就找序列的另一端(也就是环状序列的前驱后继):\(ans+=dist(x,y)+dist(x,z)-dist(z,y)\)而这个dist又如何求呢?
根据lca的性质\(dist(x,y)=dep(x)+dep(y)-2\times dep(lca)\)
需要注意的是为了求\(lca\)需要一个\(dep\)数组,而为了求\(dist\)需要另一个\(dep\)数组,两者的区别在于,第一个\(dep\)是根据深度每次\(+1\),而第二个是根据父节点和子节点的距离每次\(+dist(fa,son)\)。
代码
#include<bits/stdc++.h>
using namespace std;
const int MN=1e5+100;
int n,m,head[MN],cnt;
long long ans;
int vis[MN];
struct node{
int nxt,to,w;
}e[MN<<1];
inline void add(int a,int b,int w){e[++cnt].nxt=head[a],head[a]=cnt,e[cnt].to=b,e[cnt].w=w;}
int f[MN][20],dep[MN],dfn[MN],name[MN],tot;
long long dis[MN];
inline void dfs(int now,int pre){
f[now][0]=pre,dfn[now]=++tot,name[tot]=now,dep[now]=dep[pre]+1;
for(int i=1;1<<i<dep[now];++i)f[now][i]=f[f[now][i-1]][i-1];
for(int i=head[now];i;i=e[i].nxt){
int to=e[i].to,w=e[i].w;
if(to==pre)continue;
dis[to]=dis[now]+w;
dfs(to,now);
}
}
inline int LCA(int a,int b){
if(dep[a]<dep[b])swap(a,b);
for(int i=16;i>=0;--i){
if(dep[f[a][i]]>=dep[b])a=f[a][i];
}
if(a==b)return a;
for(int i=16;i>=0;--i){
if(f[a][i]!=f[b][i]){
a=f[a][i];
b=f[b][i];
}
}
return f[a][0];
}
long long dist(int a,int b){
return dis[a]+dis[b]-(dis[LCA(a,b)]<<1);
}
set<int> st;
set<int>::iterator it;
int main(){
freopen("game.in","r",stdin);
freopen("game.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1,x,y,z;i<n;++i){
scanf("%d%d%d",&x,&y,&z);
add(x,y,z),add(y,x,z);
}
dfs(1,0);
int x,y,z;
while(m--){
scanf("%d",&x);
x=dfn[x];
if(!vis[name[x]])st.insert(x);
y=name[(it=st.lower_bound(x))==st.begin()?*--st.end():*--it];
z=name[(it=st.upper_bound(x))==st.end()?*st.begin():*it];
if(vis[name[x]])st.erase(x);
x=name[x];
long long d=dist(x,y)+dist(x,z)-dist(z,y);
if(!vis[x])vis[x]=1,ans+=d;
else vis[x]=0,ans-=d;
printf("%lld\n",ans);
}
return 0;
}