[bzoj3910] 火车

  一开始看错题了...

  求经过的道路数量就求个lca,把路径上的点置为已经过的话,可以写一个并查集,把某个点往上连续已经过的一段点并起来。

 1 #include<cstdio>
 2 #include<iostream>
 3 #include<cstring>
 4 #include<algorithm>
 5 #include<cmath>
 6 #define d double
 7 #define ll long long
 8 using namespace std;
 9 const int maxn=500233;
10 struct zs{int too,pre;}e[maxn<<1];int tot,last[maxn];
11 int dep[maxn],fa[maxn],dfn[maxn],bel[maxn],sz[maxn],tim,f[maxn];
12 bool gg[maxn];
13 int i,j,k,n,m;
14 ll ans;
15  
16 int ra;char rx;
17 inline int read(){
18     rx=getchar(),ra=0;
19     while(rx<'0'||rx>'9')rx=getchar();
20     while(rx>='0'&&rx<='9')ra*=10,ra+=rx-48,rx=getchar();return ra;
21 }
22 void dfs(int x){
23     sz[x]=1,dep[x]=dep[fa[x]]+1;
24     for(int i=last[x];i;i=e[i].pre)if(e[i].too!=fa[x])
25         fa[e[i].too]=x,dfs(e[i].too),sz[x]+=sz[e[i].too];
26 }
27 void dfs2(int x,int chain){
28     int i,mx=0;
29     bel[x]=chain,dfn[x]=++tim;
30     for(i=last[x];i;i=e[i].pre)if(e[i].too!=fa[x]&&sz[e[i].too]>sz[mx])mx=e[i].too;
31     if(!mx)return;
32     dfs2(mx,chain);
33     for(i=last[x];i;i=e[i].pre)if(e[i].too!=fa[x]&&e[i].too!=mx)dfs2(e[i].too,e[i].too);
34 }
35 inline int getlca(int a,int b){
36     while(bel[a]!=bel[b]){
37         if(dep[bel[a]]<dep[bel[b]])swap(a,b);
38         a=fa[bel[a]];
39     }
40     return dep[a]<dep[b]?a:b;
41 }
42  
43  
44 inline int getf(int x){return f[x]!=x?f[x]=getf(f[x]):x;}
45 inline void run(int x,int lca){
46     lca=getf(lca);//printf("lca:   %d\n",lca);
47     while(getf(x)!=lca)
48         x=f[x],f[x]=lca,x=fa[x];
49     gg[lca]=1;
50 }
51  
52 inline void insert(int a,int b){
53     e[++tot].too=b,e[tot].pre=last[a],last[a]=tot,
54     e[++tot].too=a,e[tot].pre=last[b],last[b]=tot;
55 }
56 int main(){int st,lca;
57     n=read(),m=read(),st=read();
58     for(i=1;i<n;i++)insert(read(),read());
59     for(i=1;i<=n;i++)f[i]=i;
60     dfs(1),dfs2(1,1);
61     for(i=1;i<=m;i++){
62         j=read();
63         if(gg[getf(j)])continue;
64         lca=getlca(st,j),run(st,lca),run(j,lca),
65         ans+=dep[j]+dep[st]-(dep[lca]<<1),st=j;
66 //      for(printf("done:  "),j=1;j<=n;j++)if(gg[getf(j)])printf("   %d",j);puts("");
67     }
68     printf("%lld\n",ans);
69 }
View Code

 

posted @ 2016-07-05 19:04  czllgzmzl  阅读(457)  评论(0编辑  收藏  举报