BZOJ2750: [HAOI2012]Road

题目:http://www.lydsy.com/JudgeOnline/problem.php?id=2750

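对每个点 s 跑一遍 SPFA,求出从 s 出发的最短路。所有满足 dis[v] = dis[u] + w 的边 u→v 构成一张最短路 DAG。对 DAG 中的每条边 u→v,记 a[u] 为从 s 到 u 的最短路条数,b[v] 为在 DAG 中从 v 出发的路径条数(包括只停在 v 的那一条)。那么以 s 为起点、经过这条边的最短路条数就是 a[u]*b[v]。对所有 s 累加(模 1000000007)即为答案。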

(按数据范围开的数组大小开错了,卡了好久……要完。)

#include<algorithm>
#include<cctype>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<queue>
// Loop / clearing helpers. Every macro argument is parenthesised so that an
// expression bound such as `n&3` or `x?y:z` is evaluated as one operand
// (unparenthesised, `i<=n&3` would parse as `(i<=n)&3`).
#define rep(i,l,r) for (int i=(l);i<=(r);i++)
#define down(i,l,r) for (int i=(l);i>=(r);i--)
#define clr(x,y) memset((x),(y),sizeof(x))
// Typed constants instead of object-like macros (C++98-compatible).
typedef long long ll;
const int mm=1000000007;   // modulus applied to every path count
const int maxm=5050;       // capacity of per-edge arrays (edges are 1-based)
const int maxn=2050;       // capacity of per-node arrays (nodes are 1-based)
using namespace std;
// One directed edge in the forward-star adjacency list: from -> obj with
// weight c; pre is the index of the previously inserted edge leaving `from`
// (0 ends the list).
struct data{
    int from;
    int obj;
    int pre;
    int c;
}e[maxm];

int tot;            // number of edges inserted so far (edges are 1-based)
int n;              // number of nodes
int m;              // number of edges
int head[maxn];     // head[u]: most recently inserted edge leaving u (0 = none)
int vis[maxn];      // scratch visited / in-queue flags
int dis[maxn];      // shortest distance from the current source
int d[maxn];        // remaining shortest-path-DAG in-degree, consumed by dfs()
ll ans[maxm];       // per-edge answer, reduced mod mm
ll a[maxn];         // a[v]: number of shortest paths from the source to v
ll b[maxn];         // b[v]: number of shortest-path-DAG paths starting at v
// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; if a '-' appears among
// them the result is negated. Returns 0 if end of input is reached before any
// digit, instead of looping forever.
int read(){
    int x=0,f=1;
    // int, not char: getchar() returns either an unsigned-char value or EOF,
    // and both must stay distinguishable and valid arguments for isdigit().
    int ch=getchar();
    while (ch!=EOF && !isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();}
    while (ch!=EOF && isdigit(ch)) {x=x*10+(ch-'0'); ch=getchar();}
    return x*f;
}
// Appends directed edge x -> y with weight z to the forward-star list of x.
// Edges are numbered from 1 so that index 0 can terminate a list.
void insert(int x,int y,int z){
    const int id=++tot;
    e[id].from=x;
    e[id].obj=y;
    e[id].c=z;
    e[id].pre=head[x];
    head[x]=id;
}
void spfa(int s){
    queue<int >q; clr(dis,127/3); clr(vis,0); dis[s]=0;  
    q.push(s);
    while (!q.empty()){
        int u=q.front(); q.pop(); vis[u]=1;
        for (int j=head[u];j;j=e[j].pre){
            int v=e[j].obj;
            if (dis[v]>dis[u]+e[j].c){
                dis[v]=dis[u]+e[j].c;
                if (!vis[v]) {
                    vis[v]=1; q.push(v);
                }
            }
        }
        vis[u]=0;
    }
}
// Propagates shortest-path counts forward over the shortest-path DAG
// (edges u -> v with dis[v] == dis[u] + c), adding a[u] into each successor
// modulo mm. d[v] counts v's DAG in-edges not yet processed, so v is expanded
// only after every predecessor has contributed (in-degree counting, as in
// Kahn's topological sort).
void dfs(int u){
    for (int k=head[u];k;k=e[k].pre){
        const int to=e[k].obj;
        if (dis[to]!=dis[u]+e[k].c) continue;
        a[to]=(a[to]+a[u])%mm;
        if (--d[to]==0) dfs(to);
    }
}
void dfs2(int u){
    b[u]=1;
    for (int j=head[u];j;j=e[j].pre){
        int v=e[j].obj;
        if (dis[v]==dis[u]+e[j].c){
            if (!b[v]) dfs2(v);
            b[u]=(b[u]+b[v])%mm;
        }
    }
}
// Marks (vis) every node reachable from u along shortest-path DAG edges
// (dis[v] == dis[u] + c). For each such edge u -> v it increments d[v], so
// afterwards d[v] is v's in-degree within the reachable part of the DAG.
void get(int u){
    vis[u]=1;
    for (int k=head[u];k!=0;k=e[k].pre){
        const int to=e[k].obj;
        const bool onShortest=(dis[to]==dis[u]+e[k].c);
        if (!onShortest) continue;
        ++d[to];
        if (!vis[to]) get(to);
    }
}
int main(){
    n=read(); m=read();
    int x,y,z;
    rep(i,1,m){
        x=read(); y=read(); z=read();
        insert(x,y,z);
    } 
    rep(i,1,n){
        spfa(i); clr(a,0); clr(b,0); clr(vis,0); clr(d,0); 
        get(i);
        a[i]=1; dfs(i);
        dfs2(i);
        rep(j,1,m) if (dis[e[j].from]+e[j].c==dis[e[j].obj]) ans[j]=(ans[j]+a[e[j].from]*b[e[j].obj])%mm;
    }
    rep(i,1,m) printf("%lld\n",ans[i]);
    return 0;
}

 

posted on 2015-12-02 14:27  ctlchild  阅读(233)  评论(0)  编辑  收藏  举报

导航