洛谷 [P1608] 路径统计（最短路计数）

最短路计数模板

本题要注意重边的处理：同一对点之间的多条边只保留权值最小的一条（权值相同的重边也只算一条），否则最短路条数会被重复计算。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#include <cmath>
#include <queue>
#include <cstring>
using namespace std;
// Capacity bound: node ids are used in the range 1 .. MAXN - 1.
const int MAXN = 2005;

// Graph bookkeeping: head[u] is the index of u's first outgoing edge
// (0 terminates a list), n / m are the node / edge counts read from input,
// and nume is the index of the most recently added edge.
int head[MAXN], n, m, nume;
// Per-node results: dist[v] is the shortest distance from node 1 and
// cnt[v] the number of shortest paths from node 1 to v.
int dist[MAXN], cnt[MAXN];
// Visited ("settled") flags used by Dijkstra.
bool f[MAXN];

// One directed edge in a singly linked per-source edge list (forward star).
struct edge {
    int to;   // target node
    int nxt;  // index of the next edge with the same source; 0 ends the list
    int dis;  // edge weight
} e[MAXN * MAXN];
void adde(int from, int to, int dis) {
    e[++nume].to = to;
    e[nume].nxt = head[from];
    e[nume].dis = dis;
    head[from] = nume;
}
// Fast integer reader over stdin.
// Skips characters up to the next decimal digit, remembering any '-' seen on
// the way, then parses the following run of digits.
// Returns the signed value, or 0 if end of input is reached before a digit.
// (Previously the character was stored in a char and EOF was never checked,
// so truncated input made the skip loop spin forever.)
int init() {
    int rv = 0, fh = 1;
    int c = getchar();  // int, not char, so EOF stays distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') fh = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        rv = rv * 10 + (c - '0');
        c = getchar();
    }
    return fh * rv;
}
// A node id paired with a distance, used both as a priority-queue entry and
// as a sort key.
// The comparison is intentionally inverted: a node with the LARGER dis
// compares as "less". As a result, std::priority_queue<node> pops the smallest
// dis first (min-heap), and std::sort arranges an array by descending dis.
struct node {
    int num, dis;
    bool operator<(const node &rhs) const {
        return rhs.dis < dis;
    }
} a[MAXN];
priority_queue<node> q;
void dij() {
    memset(dist, 0x3f, sizeof(dist));
    dist[1] = 0;
    q.push((node){1, 0});
    while(!q.empty()) {
        node u = q.top(); q.pop();
        if(f[u.num]) continue;
        f[u.num] = 1;
        for(int i = head[u.num]; i; i = e[i].nxt) {
            node v;
            v.num = e[i].to;
            if(dist[v.num] > dist[u.num] + e[i].dis) {
                dist[v.num] = dist[u.num] + e[i].dis;
                v.dis = dist[v.num];
                q.push(v);
            } 
        }
    }
}
// Count shortest paths from node 1 into cnt[], using dist[] computed by dij().
// Nodes are visited from nearest to farthest. Each edge u -> v that lies on a
// shortest path (dist[v] == dist[u] + w) then pushes u's final count into v.
// NOTE(review): the ordering argument assumes positive edge weights; with
// zero-weight edges, equal-distance nodes could be visited in the wrong order.
void cnnt() {
    memset(f, 0, sizeof(f));  // resets dij()'s visited flags; not read below
    for (int i = 1; i <= n; i++) a[i] = node{i, dist[i]};
    // node::operator< is inverted, so this sorts a[1..n] by DEScending dist...
    sort(a + 1, a + n + 1);
    cnt[1] = 1;
    // ...and walking the array backwards visits nodes in ascending dist.
    for (int k = n; k >= 1; k--) {
        int u = a[k].num;
        // Unreachable nodes contribute no paths; skipping them also avoids
        // the signed overflow 0x3f3f3f3f + w can cause for large weights.
        if (dist[u] == 0x3f3f3f3f) continue;
        for (int i = head[u]; i; i = e[i].nxt) {
            int v = e[i].to;
            if (dist[v] == dist[u] + e[i].dis) {
                cnt[v] += cnt[u];
            }
        }
    }
}
// ddd[u][v] holds the cheapest weight read for the ordered pair (u, v);
// 0x3f3f3f3f means no edge was read.
int ddd[MAXN][MAXN];
// Reads n, m and m weighted directed edges, then merges parallel edges into
// the single cheapest edge per ordered pair. Equal-weight duplicates therefore
// count once. Finally prints the shortest distance from node 1 to node n and
// the number of such paths, or "No answer" when n is unreachable.
int main() {
    n = init();
    m = init();
    memset(ddd, 0x3f, sizeof(ddd));
    for (int k = 0; k < m; ++k) {
        int from = init();
        int to = init();
        int w = init();
        if (w < ddd[from][to]) ddd[from][to] = w;
    }
    // Only the surviving cheapest edge of each pair goes into the graph.
    for (int from = 1; from <= n; ++from)
        for (int to = 1; to <= n; ++to)
            if (ddd[from][to] != 0x3f3f3f3f) adde(from, to, ddd[from][to]);
    dij();
    cnnt();
    if (dist[n] != 0x3f3f3f3f) printf("%d %d\n", dist[n], cnt[n]);
    else printf("No answer\n");
    return 0;
}
posted @ 2018-04-03 19:44  Mr_Wolfram  阅读(...)  评论(...)  编辑  收藏