【BZOJ-3910】火车 倍增LCA + 并查集

3910: 火车

Time Limit: 20 Sec  Memory Limit: 512 MB
Submit: 262  Solved: 90
[Submit][Status][Discuss]

Description

A 国有n 个城市,城市之间有一些双向道路相连,并且城市两两之间有唯一路径。现在有火车在城市 a,需要经过m 个城市。火车按照以下规则行驶:每次行驶到还没有经过的城市中在 m 个城市中最靠前的。现在小 A 想知道火车经过这m 个城市后所经过的道路数量。 

Input

第一行三个整数 n、m、a,表示城市数量、需要经过的城市数量,火车开始时所在位置。 
接下来 n-1 行,每行两个整数 x和y,表示 x 和y之间有一条双向道路。 
接下来一行 m 个整数,表示需要经过的城市。 

Output

一行一个整数,表示火车经过的道路数量。 

Sample Input

5 4 2
1 2
2 3
3 4
4 5
4 3 1 5

Sample Output

9

HINT

N<=500000 ,M<=400000 

Source

Solution

水题- -最多算个并查集的有趣应用

很显然直接询问用LCA统计答案即可

至于处理走过的路径,拿并查集维护一下,很简单的把起止点到LCA的点合并一下,询问的两个点如果属于一个集合显然走过

Code

#include<iostream>
#include<cmath>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int read()
{
    int x=0,f=1; char ch=getchar();
    while (ch<'0' || ch>'9') {if (ch=='-') f=-1; ch=getchar();}
    while (ch>='0' && ch<='9') {x=x*10+ch-'0'; ch=getchar();}
    return x*f;
}
#define maxn 501000
int n,m,a;
struct EdgeNode{int to,next;}edge[maxn<<1];
int head[maxn],cnt;
void add(int u,int v) {cnt++; edge[cnt].next=head[u]; head[u]=cnt; edge[cnt].to=v;}
void insert(int u,int v) {add(u,v); add(v,u);}
int deep[maxn],father[maxn][20],ffff[maxn];
long long ans;
void dfs(int now)
{
    for (int i=1; i<=19; i++)
        if (deep[now]>=(1<<i))
            father[now][i]=father[father[now][i-1]][i-1];
        else
            break;
    for (int i=head[now]; i; i=edge[i].next)
        if (edge[i].to!=father[now][0])
            {
                deep[edge[i].to]=deep[now]+1;
                father[edge[i].to][0]=now;
                dfs(edge[i].to);
            }
}
int LCA(int x,int y)
{
    if (deep[x]<deep[y]) swap(x,y);
    int dd=deep[x]-deep[y];
    for (int i=0; (1<<i)<=dd; i++)
        if (dd&(1<<i)) x=father[x][i];
    for (int i=19; i>=0; i--)
        if (father[x][i]!=father[y][i])
            x=father[x][i],y=father[y][i];
    if (x==y) return x;
    return father[x][0];
}
int find(int x) {if (ffff[x]==x) return x; ffff[x]=find(ffff[x]); return ffff[x];}
int ff1,ff2;
int main()
{
    n=read(),m=read(),a=read();
    for (int u,v,i=1; i<=n-1; i++)
        u=read(),v=read(),insert(u,v);
    dfs(1);
    for (int i=1; i<=n; i++) ffff[i]=i;
    for (int i=1; i<=m; i++)
        {
            int x=read();
            int fa=find(a),fx=find(x); 
            if (fa==fx) continue;
            int lca=LCA(a,x);
            ans+=deep[a]-deep[lca]+deep[x]-deep[lca];
            int ta=a,tx=x,flca; flca=find(lca);
            while (find(ta)!=flca) {ff1=find(ta); ffff[ff1]=flca; ta=father[ff1][0];}
            while (find(tx)!=flca) {ff2=find(tx); ffff[ff2]=flca; tx=father[ff2][0];}
            a=x;
        }
    cout<<ans;
    return 0;
}

电脑炸出奇怪的错误,所以写的比较鬼畜- -

posted @ 2016-06-18 11:21  DaD3zZ  阅读(292)  评论(0编辑  收藏