BZOJ 3910: 火车

3910: 火车

Time Limit: 20 Sec  Memory Limit: 512 MB
Submit: 358  Solved: 130
[Submit][Status][Discuss]

Description

A 国有n 个城市,城市之间有一些双向道路相连,并且城市两两之间有唯一
路径。现在有火车在城市 a,需要经过m 个城市。火车按照以下规则行驶:每次
行驶到还没有经过的城市中在 m 个城市中最靠前的。现在小 A 想知道火车经过
这m 个城市后所经过的道路数量。 

Input

第一行三个整数 n、m、a,表示城市数量、需要经过的城市数量,火车开始
时所在位置。 
接下来 n-1 行,每行两个整数 x和y,表示 x 和y之间有一条双向道路。 
接下来一行 m 个整数,表示需要经过的城市。 

Output

一行一个整数,表示火车经过的道路数量。 

Sample Input

5 4 2
1 2
2 3
3 4
4 5
4 3 1 5

Sample Output

9

HINT

N<=500000 ,M<=400000 

Source

分析:

在经历了无数次对题意的错误理解(和被YouSiki的无数次吐槽)之后,我理解对了题意,然后觉得可能是个暴力...

事实证明这就是个暴力,我们每一次暴力标记访问过的节点,然后并查集维护,如果访问过就把这个点并到父亲上,均摊复杂度显然是$O(N)$的...

代码:

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
//by NeighThorn
using namespace std;

const int maxn=500000+5;

int n,m,st,cnt,a[maxn],f[maxn],hd[maxn],to[maxn<<1],fa[maxn][25],dep[maxn],nxt[maxn<<1];
long long ans;

inline int find(int x){
	return f[x]==x?x:f[x]=find(f[x]);
}

inline void add(int x,int y){
	to[cnt]=y;nxt[cnt]=hd[x];hd[x]=cnt++;
}

inline void dfs(int x){
	for(int i=hd[x];i!=-1;i=nxt[i])
		if(to[i]!=fa[x][0])
			fa[to[i]][0]=x,
			dep[to[i]]=dep[x]+1,dfs(to[i]);
}

inline void init(void){
	for(int j=1;j<=20;j++)
		for(int i=1;i<=n;i++)
			fa[i][j]=fa[fa[i][j-1]][j-1];
}

inline int LCA(int x,int y){
	if(dep[x]<dep[y])
		swap(x,y);
	int d=dep[x]-dep[y];
	for(int i=20;i>=0;i--)
		if((d>>i)&1)
			x=fa[x][i];
	if(x==y)
		return x;
	for(int i=20;i>=0;i--)
		if(fa[x][i]!=fa[y][i])
			x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}

signed main(void){
	memset(hd,-1,sizeof(hd));
	scanf("%d%d%d",&n,&m,&st);
	for(int i=1,x,y;i<n;i++)
		scanf("%d%d",&x,&y),add(x,y),add(y,x);
	dep[1]=1,fa[1][0]=1,dfs(1);init();
	for(int i=1;i<=n;i++)
		f[i]=i;
	for(int i=1;i<=m;i++)
		scanf("%d",&a[i]);
	for(int i=1,x,y,lca;i<=m;i++){
		if(find(a[i])!=a[i])
			continue;
		x=st,y=a[i],lca=LCA(x,y);
		while(dep[x]>=dep[lca]){
			if(find(x)==x)
				f[x]=x==1?0:fa[x][0];
			else
				x=find(x);
		}
		while(dep[y]>=dep[lca]){
			if(find(y)==y)
				f[y]=y==1?0:fa[y][0];
			else
				y=find(y);
		}
		x=st,y=a[i],st=a[i];
		ans+=dep[x]-dep[lca];
		ans+=dep[y]-dep[lca];
	}
	printf("%lld\n",ans);
	return 0;
}

  


By NeighThorn

posted @ 2017-03-13 20:20  NeighThorn  阅读(313)  评论(0编辑  收藏  举报