bzoj2750Road——最短路计数

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=2750

以每个点作为源点,spfa跑出一个最短路图(不一定是树,因为可能很多条最短路一样长);

对于图中的每条边,需要知道从源点到边起点的方案数和边终点的size;

所以对于每张图都dfs求一遍所有点的两个值:从源点到它的方案数(a),它以下的size(b);

由于不能破坏原图,所以可以通过dis[edge[i].to]==dis[edge[i].hd]+edge[i].w来判断这条边是否在最短路图中。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
queue<int>q;
int const MAXN=1505,MAXM=5005,P=1e9+7;
int n,m,head[MAXN],ct,dis[MAXN],p[MAXN];
long long ans[MAXM],a[MAXN],b[MAXN];
bool vis[MAXN],f[MAXN];
struct N{
    int hd,to,next,w;
    N(int h=0,int t=0,int n=0,int w=0):hd(h),to(t),next(n),w(w) {}
}edge[MAXM];
void spfa(int s)
{
    memset(vis,0,sizeof vis);
    memset(dis,0x3f,sizeof dis);
    while(q.size())q.pop();
    q.push(s);dis[s]=0;vis[s]=1;
    while(q.size())
    {
        int x=q.front();q.pop();vis[x]=0;
        for(int i=head[x];i;i=edge[i].next)
        {
            int u=edge[i].to;
            if(dis[u]>dis[x]+edge[i].w)
            {
                dis[u]=dis[x]+edge[i].w;
                if(!vis[u])vis[u]=1,q.push(u);
            }
        }
    }
}
void dfs1(int x)
{
    f[x]=1;
    for(int i=head[x],u;i;i=edge[i].next)
        if(dis[u=edge[i].to]==dis[x]+edge[i].w)
        {
            p[u]++;
            if(!f[u])dfs1(u);
        }
}
void dfs2(int x)
{
    for(int i=head[x],u;i;i=edge[i].next)
        if(dis[u=edge[i].to]==dis[x]+edge[i].w)
        {
            a[u]=(a[u]+a[x])%P;p[u]--;//拓扑 
            if(!p[u])dfs2(u);
        }
}
void dfs3(int x)
{
    b[x]=1;
    for(int i=head[x],u;i;i=edge[i].next)
        if(dis[u=edge[i].to]==dis[x]+edge[i].w)
        {
            if(!b[u])dfs3(u);//记忆化
            b[x]=(b[x]+b[u])%P; 
        }
}
int main()
{
    scanf("%d%d",&n,&m);
    int x,y,z;
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d%d",&x,&y,&z);
        edge[++ct]=N(x,y,head[x],z);head[x]=ct;
    }
    for(int i=1;i<=n;i++)
    {
        spfa(i);
        memset(f,0,sizeof f);
        memset(a,0,sizeof a);
        memset(b,0,sizeof b);
        dfs1(i);a[i]=1;dfs2(i);dfs3(i);
        for(int j=1,u,v;j<=m;j++)
            if(dis[v=edge[j].to]==dis[u=edge[j].hd]+edge[j].w)
                ans[j]=(ans[j]+a[u]*b[v])%P;
    }
    for(int j=1;j<=m;j++)
        printf("%lld\n",ans[j]);
    return 0;
}

 

posted @ 2018-04-17 19:03  Zinn  阅读(121)  评论(0编辑  收藏  举报