BZOJ 2125: 最短路 (仙人掌,树链剖分)
第一道仙人掌题.
由于仙人掌中每条边最多只属于一个环,所以同一个环上两点间的最短距离很好算:取沿环两个方向走的距离中较小的一个即可.
code:
#include <bits/stdc++.h>
// Capacity used when sizing the global tables declared below.
#define N 200006
// setIO("name") reopens stdin on the file "name.in" (the argument must be a
// string literal, because it is concatenated with ".in" by the preprocessor).
#define setIO(s) freopen(s".in","r",stdin)
// The code below uses vector, min and abs unqualified.
using namespace std;
int edges; // number of arcs stored so far in the hd/nex/to/val lists (bumped by add)
int n;     // first number read by main; nodes created by build are numbered above it
int m;     // second number read by main: how many edge lines follow
int Q;     // third number read by main: how many query lines follow
int tim;   // timestamp counter, incremented by tarjan and again (after a reset) by dfs2
int tot;   // index of the most recently created node; main starts it at n
// One adjacency entry of the tree stored in G: the neighbour and the edge weight.
struct Edge
{
    int v; // neighbour node
    int c; // weight of the edge leading to v

    // Both arguments default to 0, so Edge() and Edge(v) are valid as well.
    Edge(int v=0,int c=0)
    {
        this->v=v;
        this->c=c;
    }
};
// Adjacency lists of the tree that tarjan/build construct and dfs1/dfs2 walk.
vector<Edge>G[N];
// Linked-list storage of the input graph, indexed by arc id (1-based, see add):
// hd[u] is the first arc of u, nex[] the next arc in the same list, to[] the
// arc's endpoint and val[] its weight. add() stores two arcs per input edge,
// so every arc-indexed table (val included) needs twice the capacity.
int hd[N],to[N<<1],nex[N<<1],val[N<<1];
// fa  : parent (set by tarjan, later overwritten by dfs1).
// low/dfn : values computed by tarjan; dfn is reassigned by dfs2.
// dis : inside tarjan the weight of the edge from the parent; after dfs1 the
//       distance from node 1 along G.
// cdis: per-cycle offsets and cycle totals written by build.
int fa[N],low[N],dfn[N],dis[N],cdis[N];
// Chain-decomposition tables filled by dfs1/dfs2: depth, subtree size, chain
// head, preferred child, and the node holding each dfs2 timestamp.
int dep[N],size[N],top[N],son[N],point[N];
void add(int u,int v,int c)
{
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v,val[edges]=c;
}
// Creates one new node for the cycle ff -> ... -> x -> ff and links it into G.
//   ff : cycle vertex reached by following fa[] upwards from x
//   x  : cycle vertex joined to ff by the closing edge
//   c  : weight of that closing edge
// The new node is connected to ff with weight 0 and to every other cycle
// vertex with the shorter of the two ways around the cycle to ff.
void build(int ff,int x,int c)
{
    // First pass: cdis[v] = length of the route v -> ... -> x -> ff that uses
    // the closing edge; c accumulates the weights seen so far.
    int cur=x;
    while(cur!=ff)
    {
        cdis[cur]=c;
        c+=dis[cur]; // dis[cur] is the weight of the edge fa[cur] - cur here
        cur=fa[cur];
    }
    // c now equals the total length of the cycle; store it on the new node.
    tot+=1;
    const int cyc=tot;
    cdis[cyc]=c;
    G[cyc].push_back(Edge(ff,0));
    G[ff].push_back(Edge(cyc,0));
    // Second pass: attach the remaining cycle vertices to the new node.
    for(cur=x;cur!=ff;cur=fa[cur])
    {
        const int oneWay=cdis[cur];
        const int otherWay=c-oneWay;
        const int w=oneWay<otherWay?oneWay:otherWay;
        G[cyc].push_back(Edge(cur,w));
        G[cur].push_back(Edge(cyc,w));
    }
}
// Depth-first search over the input graph starting at x, entered from ff.
// Assigns fa/dfn/low, copies every edge detected as a bridge into G, and calls
// build() once for each cycle whose vertex with the smallest dfn is x.
// NOTE(review): the parent is skipped by vertex id, so any additional parallel
// edge between x and ff is ignored too — confirm the input has none.
void tarjan(int x,int ff)
{
    fa[x]=ff;
    tim+=1;
    dfn[x]=tim;
    low[x]=tim;
    for(int e=hd[x];e!=0;e=nex[e])
    {
        const int nb=to[e];
        if(nb!=ff)
        {
            if(dfn[nb]==0)
            {
                // Unvisited: descend, remembering the weight of the edge used.
                dis[nb]=val[e];
                tarjan(nb,x);
                if(low[nb]<low[x])
                    low[x]=low[nb];
            }
            else if(dfn[nb]<low[x])
                low[x]=dfn[nb];
            // Nothing below nb reaches x or higher: x - nb lies on no cycle.
            if(low[nb]>dfn[x])
            {
                G[x].push_back(Edge(nb,val[e]));
                G[nb].push_back(Edge(x,val[e]));
            }
        }
    }
    // A neighbour visited after x whose parent is not x closes a cycle at x.
    for(int e=hd[x];e!=0;e=nex[e])
    {
        const int nb=to[e];
        if(fa[nb]==x||dfn[nb]<=dfn[x])
            continue;
        build(x,nb,val[e]);
    }
}
void dfs1(int x,int ff)
{
fa[x]=ff;
for(int i=0;i<G[x].size();++i)
{
int y=G[x][i].v,c=G[x][i].c;
if(y==ff)
continue;
dis[y]=dis[x]+c;
dep[y]=dep[x]+1;
dfs1(y,x);
size[x]+=size[y];
if(size[y]>size[son[x]])
son[x]=y;
}
}
// Second decomposition pass: gives x the next timestamp, records the chain
// head tp for it, then continues the same chain through son[x] before
// starting a fresh chain at every other child.
//   x  : node being visited
//   tp : head of the chain x belongs to
void dfs2(int x,int tp)
{
    tim+=1;
    dfn[x]=tim;
    point[tim]=x; // inverse mapping: timestamp -> node
    top[x]=tp;
    const int heavy=son[x];
    if(heavy!=0)
        dfs2(heavy,tp); // visited first, so it gets timestamp dfn[x]+1
    const int parent=fa[x];
    for(size_t k=0;k<G[x].size();++k)
    {
        const int child=G[x][k].v;
        if(child==parent||child==heavy)
            continue;
        dfs2(child,child);
    }
}
// Returns the lowest common ancestor of x and y in the tree prepared by
// dfs1/dfs2: repeatedly lift whichever node sits on the chain with the deeper
// head until both are on the same chain, then the shallower node is the answer.
int get_lca(int x,int y)
{
    for(;;)
    {
        const int hx=top[x];
        const int hy=top[y];
        if(hx==hy)
            break;
        if(dep[hx]>dep[hy])
            x=fa[hx];
        else
            y=fa[hy];
    }
    if(dep[x]<dep[y])
        return x;
    return y;
}
// Returns the child of lca that lies on the path from lca down to x.
//   x   : a node below lca
//   lca : an ancestor of x
// If x is already lca there is no such child and 0 is returned.
int jump(int x,int lca)
{
    // Head of the last chain left behind. Starts at 0 so the function never
    // returns an indeterminate value when the loop body does not run.
    int las=0;
    while(top[x]!=top[lca])
    {
        las=top[x];
        x=fa[las];
    }
    // Landing exactly on lca means the chain just left hangs directly off it.
    // Otherwise x is further down lca's own chain, and the wanted child is the
    // node that dfs2 stamped immediately after lca.
    return x==lca?las:point[dfn[lca]+1];
}
// Reads the graph and the queries from stdin and prints one distance per query.
// Input: "n m Q", then m lines "u v c" (undirected edge u - v of weight c),
// then Q lines "x y".
int main()
{
    // setIO("input");
    scanf("%d%d%d",&n,&m,&Q);
    for(int k=0;k<m;++k)
    {
        int a,b,w;
        scanf("%d%d%d",&a,&b,&w);
        add(a,b,w);
        add(b,a,w);
    }
    // Nodes created for cycles are numbered n+1, n+2, ...
    tot=n;
    tarjan(1,0);
    // Reset what the decomposition passes recompute from scratch.
    tim=0;
    dep[1]=0;
    dis[1]=0;
    dfs1(1,0);
    dfs2(1,1);
    for(int k=0;k<Q;++k)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        const int lca=get_lca(x,y);
        int ans;
        if(lca>n)
        {
            // The paths meet at a cycle node: go from x and y to the cycle
            // vertices p1/p2 where they enter it, then take the shorter way
            // around the cycle between those two vertices.
            const int p1=jump(x,lca);
            const int p2=jump(y,lca);
            const int along=abs(cdis[p2]-cdis[p1]);
            const int around=cdis[lca]-along;
            ans=dis[x]-dis[p1]+dis[y]-dis[p2]+min(along,around);
        }
        else
        {
            // The paths meet at an original vertex: plain tree distance.
            ans=dis[x]+dis[y]-(dis[lca]<<1);
        }
        printf("%d\n",ans);
    }
    return 0;
}

浙公网安备 33010602011771号