P8906 [USACO22DEC] Breakdown P 题解

P8906 [USACO22DEC] Breakdown P 题解

显然的套路是删边转化为加边。

考虑到维护整条路径不好维护,于是考虑转化维护 \(f_{i,k},g_{i,k}\) 分别表示 \(1,n\)\(i\) 走了 \(k\) 步时的最短路。那么此时 \(k\le 4\)

我们先考虑 \(f\) 的转移,\(g\) 的转移是等价的。

那么对于 \(x\rightarrow y\),若当前 \(1\)\(y\) 的最短距离是 \(t\),那么转移的时候先用 \(x\) 来更新 \(y\),然后就相当于从点 \(y\) 再走 \(p=k-t\) 步。显然 \(p\le 3\),那么这样的转移大概一次是 \(O(n^3)\) 的,然而我们一次更新能接受的复杂度是 \(O(n)\),于是无法接受。

容易一些的,我们首先考虑 \(p=2\) 的情形。考虑到最终会得到更新的也只有 \(n\) 个点,那么我们直接枚举这 \(n\) 个点 \(a\)。这时候我们需要知道从 \(y\)\(a\) 恰好走 \(2\) 步的最短路 \(d_{y,a}\)。那这个东西由于限制了只走 \(2\) 步,是可以在更新每一条边的时候顺带预处理的。于是我们解决了 \(p=2\)

对于 \(p=3\),不难发现现在更新一次的复杂度是 \(O(n^2)\),但是考虑到 \(p=3\) 时当且仅当 \(x=1\),也就是说只有 $n $ 个这样的更新,于是这样的复杂度还是 \(O(n^3)\) 的,可以通过。

其实这道题的核心是想到如何处理 \(p\le 3\) 时的情形。观察到距离为 \(2\) 可以做预处理时这题基本就做完了。

代码:

#include <bits/stdc++.h>
#define N 303
#define K 10
#define int long long
using namespace std;
int n, k;
int w[N][N];
struct node {
	int x, y;
} e[N * N];
int mp[N][N];
int f[N][K]; 
int g[N][K]; 
int d[N][N];
void cm(int &x, int y) {
	x = min(x, y);
}
int add(int x, int y) {
	int ans = 1e18 + 1;
	for (int ck = 0; ck <= 4; ck++)
		for (int a = 1; a <= n; a++)
			ans = min(ans, f[a][ck] + g[a][k - ck]);
	mp[x][y] = 1;
	for (int a = 1; a <= n; a++)
		if (mp[y][a])
			cm(d[x][a], w[x][y] + w[y][a]);
	for (int a = 1; a <= n; a++)
		if (mp[a][x])
			cm(d[a][y], w[a][x] + w[x][y]);
	for (int ck = 0; ck < 4; ck++)
		cm(f[y][ck + 1], f[x][ck] + w[x][y]);
	for (int ck = 0; ck < 4; ck++) {
		for (int a = 1; a <= n; a++)
			if (mp[y][a])
				cm(f[a][ck + 1], f[y][ck] + w[y][a]);
		for (int a = 1; a <= n; a++)
			cm(f[a][ck + 2], f[y][ck] + d[y][a]);
	} 
	if (x == 1) {
		for (int a = 1; a <= n; a++)
			for (int b = 1; b <= n; b++)
				if (mp[a][b])
					cm(f[b][4], f[a][3] + w[a][b]);
	} 
	for (int ck = 0; ck < 4; ck++)
		cm(g[x][ck + 1], g[y][ck] + w[x][y]);
	for (int ck = 0; ck < 4; ck++) {
		for (int a = 1; a <= n; a++)
			if (mp[a][x])
				cm(g[a][ck + 1], g[x][ck] + w[a][x]);
		for (int a = 1; a <= n; a++)
			cm(g[a][ck + 2], g[x][ck] + d[a][x]);
	} 
	if (y == n) {
		for (int a = 1; a <= n; a++)
			for (int b = 1; b <= n; b++)
				if (mp[b][a])
					cm(g[b][4], g[a][3] + w[b][a]);
	}
	return ans >= 1e18 ? -1 : ans;
}
stack<int>q;
signed main() {
	memset(f, 0x3f, sizeof f); 
	memset(g, 0x3f, sizeof g); 
	memset(d, 0x3f, sizeof d); 
	cin >> n >> k;
	f[1][0] = g[n][0] = 0;
	for (int i = 1; i <= n; i++)
		for (int j = 1; j <= n; j++)
			scanf("%lld",  &w[i][j]);
	for (int i = 1; i <= n * n; i++) 
		scanf("%lld%lld", &e[i].x, &e[i].y);
	for (int i = n * n; i; i--) 
		q.push(add(e[i].x, e[i].y));
	while (q.size())
		cout << q.top() << "\n", q.pop();
	return 0;
}
posted @ 2024-09-25 18:24  长安19路  阅读(53)  评论(0)    收藏  举报