全源最短路——Johnson 算法

一、问题引入

目前我们所知道的一些常见的最短路算法有 dijkstra、spfa、floyd。

dijkstra 和 spfa 是单源最短路,floyd 是全源最短路。

如果我们需要在 \(O(nm)\) 等级的时间复杂度下求出全源最短路,并且图存在负权,那么它就叉掉了这三种最短路算法,因为 dijkstra 无法处理负权,spfa 跑 \(n\) 次虽然一般跑不满,但是只要卡一下,就可以卡到 \(O(n^2m)\) 的时间复杂度,floyd 时间复杂度 \(O(n^3)\),也过不了,这个时候,就出现了 Johnson 算法,它是一种依靠 dijkstra 和 spfa 的算法。

二、Johnson 算法流程

  • 新建超级源点,给每个点连接一条边。
  • 算出每个点到超级源点的最短距离 \(h_i\),使用 spfa,因为此时存在负权。
  • 给每条边的 \(w_i\) 更新为 \(h_u-h_v+w_i\)\(u\)\(v\) 为这条边连接的两个节点。
  • 计算最短路,使用 dijkstra,此时已经不存在负权。
  • 输出时要减去 \(h_j-h_i\)

三、算法正确性证明

考虑如何将 dijkstra 优化成可以求负边权的算法。

首先有一种很容易想到的思路,就是将所有边都加上一个数,使得所有边的权值都变成非负整数,但是这种想法是错的,因为如果这样就会出现路径经过的边数不同导致最短路计算错误,所以我们需要使每个点加上一个数,使得任何一个最短路多余的值进行消掉之后只剩下开头点和结尾点,才能正确算出最短路。所以我们要将每一个点 \(i\) 设置一个值 \(h_i\),然后假设一条路径的边进行加工后,假设它的路径起点为 \(s\),终点为 \(t\),则路径为 \(s \rightarrow p_1 \rightarrow p_2 \rightarrow p_3 \rightarrow \dots \rightarrow p_k \rightarrow t\),其权值为 \(w(s,p_1)+h_s-h_{p_1}+w(p_1,p_2)+h_{p_1}-h_{p_2}+w(p_2,p_3)+h_{p_2}-h_{p_3}+\dots+w(p_k,t)+h_{p_k}-h_t\),消掉后变成了 \(w(s,p_1)+w(p_1,p_2)+w(p_2,p_3)+\dots+w(p_k,t)+h_s-h_t\),于是这样权值就和边的数量一点关系也没有了,目前已经证明了一半,那如何证明更新后的边权一定非负?首先对于任意一条边 \((u,v)\),它一定满足 \(h_v \le h_u+w(u,v)\),移项后得 \(0 \le h_u-h_v+w(u,v)\),则 \(h_u-h_v+w(u,v) \ge 0\),于是我们就证明了任意一条边更新后权值绝对非负。

时间复杂度为 \(O(nm \log m)\)(spfa 的时间复杂度是 \(O(nm)\),重定义边权的时间复杂度是 \(O(n+m)\),dijkstra 的时间复杂度是 \(O(nm \log m)\),算时间复杂度肯定取最大值,所以是 \(O(nm \log m)\))。

四、例题

P5905 【模板】全源最短路(Johnson)

代码:

#include <bits/stdc++.h>
using namespace std;
namespace fast_IO {
#define FASTIO
#define IOSIZE 100000
	char ibuf[IOSIZE], obuf[IOSIZE];
	char *p1 = ibuf, *p2 = ibuf, *p3 = obuf;
#ifdef ONLINE_JUDGE
#define getchar() ((p1==p2)and(p2=(p1=ibuf)+fread(ibuf,1,IOSIZE,stdin),p1==p2)?(EOF):(*p1++))
#define putchar(x) ((p3==obuf+IOSIZE)&&(fwrite(obuf,p3-obuf,1,stdout),p3=obuf),*p3++=x)
#endif//fread in OJ, stdio in local
	
#define isdigit(ch) (ch>47&&ch<58)
#define isspace(ch) (ch<33)
	template<typename T> inline T read() {
		T s = 0;
		int w = 1;
		char ch;
		while (ch = getchar(), !isdigit(ch) and (ch != EOF)) if (ch == '-') w = -1;
		if (ch == EOF) return false;
		while (isdigit(ch)) s = s * 10 + ch - 48, ch = getchar();
		return s * w;
	}
	template<typename T> inline bool read(T &s) {
		s = 0;
		int w = 1;
		char ch;
		while (ch = getchar(), !isdigit(ch) and (ch != EOF)) if (ch == '-') w = -1;
		if (ch == EOF) return false;
		while (isdigit(ch)) s = s * 10 + ch - 48, ch = getchar();
		return s *= w, true;
	}
	inline bool read(char &s) {
		while (s = getchar(), isspace(s));
		return true;
	}
	inline bool read(char *s) {
		char ch;
		while (ch = getchar(), isspace(ch));
		if (ch == EOF) return false;
		while (!isspace(ch)) *s++ = ch, ch = getchar();
		*s = '\000';
		return true;
	}
	template<typename T> inline void print(T x) {
		if (x < 0) putchar('-'), x = -x;
		if (x > 9) print(x / 10);
		putchar(x % 10 + 48);
	}
	inline void print(char x) {
		putchar(x);
	}
	inline void print(char *x) {
		while (*x) putchar(*x++);
	}
	inline void print(const char *x) {
		for (int i = 0; x[i]; i++) putchar(x[i]);
	}
#ifdef _GLIBCXX_STRING
	inline bool read(std::string& s) {
		s = "";
		char ch;
		while (ch = getchar(), isspace(ch));
		if (ch == EOF) return false;
		while (!isspace(ch)) s += ch, ch = getchar();
		return true;
	}
	inline void print(std::string x) {
		for (int i = 0, n = x.size(); i < n; i++)
			putchar(x[i]);
	}
#endif//string
	template<typename T, typename... T1> inline int read(T& a, T1&... other) {
		return read(a) + read(other...);
	}
	template<typename T, typename... T1> inline void print(T a, T1... other) {
		print(a);
		print(other...);
	}
	
	struct Fast_IO {
		~Fast_IO() {
			fwrite(obuf, p3 - obuf, 1, stdout);
		}
	} io;
	template<typename T> Fast_IO& operator >> (Fast_IO &io, T &b) {
		return read(b), io;
	}
	template<typename T> Fast_IO& operator << (Fast_IO &io, T b) {
		return print(b), io;
	}
#define cout io
#define cin io
#define endl '\n'
}
using namespace fast_IO;
const int N = 3e3+5;
struct node
{
	int x;
	int w;
	int operator<(const node&a)const
	{
		return w>a.w;
	}
};
vector<node>a[N];
int vis[N];
long long h[N];
long long d[N];
int t[N];
signed main()
{
	int n,m;
	cin >> n >> m;
	for(int i = 1;i<=m;i++)
	{
		int x,y,w;
		cin >> x >> y >> w;
		a[x].push_back({y,w});
	}
	for(int i = 1;i<=n;i++)
	{
		a[0].push_back({i,0});
	}
	memset(h,0x3f,sizeof(h));
	queue<int>q;
	q.push(0);
	vis[0] = 1;
	h[0] = 0;
	while(q.size())
	{
		int x = q.front();
		q.pop();
		vis[x] = 0;
		for(int i = 0;i<a[x].size();i++)
		{
			int v = a[x][i].x;
			int w = a[x][i].w;
			if(h[v]>h[x]+w)
			{
				h[v] = h[x]+w;
				if(!vis[v])
				{
					vis[v] = 1;
					q.push(v);
					t[v]++;
					if(t[v] == n+1)
					{
						printf("-1");
						return 0;
					}
				}
			}
		}
	}
	for(int i = 1;i<=n;i++)
	{
		for(int j = 0;j<a[i].size();j++)
		{
			int v = a[i][j].x;
			a[i][j].w+=h[i]-h[v];
		}
	}
	for(int i = 1;i<=n;i++)
	{
		priority_queue<node>q;
		q.push({i,0});
		memset(d,0x3f,sizeof(d));
		memset(vis,0,sizeof(vis));
		d[i] = 0;
		while(q.size())
		{
			int x = q.top().x;
			q.pop();
			if(vis[x])
			{
				continue;
			}
			vis[x] = 1;
			for(int i = 0;i<a[x].size();i++)
			{
				int v = a[x][i].x;
				int w = a[x][i].w;
				if(d[v]>d[x]+w)
				{
					d[v] = d[x]+w;
					q.push({v,d[v]});
				}
			}
		}
		long long sum = 0;
		for(int j = 1;j<=n;j++)
		{
			if(d[j] == d[0])
			{
				sum+=(long long)j*(long long)1000000000;
			}
			else
			{
				sum+=j*(d[j]+h[j]-h[i]);
			}
		}
		cout << sum << "\n";
	}
	return 0;
}
posted @ 2025-01-17 16:05  林晋堃  阅读(234)  评论(0)    收藏  举报