洛谷 P8820 [CSP-S 2022] 数据传输 题解

首先考虑对于每一次询问暴力 DP。

\(f_{u,i}\) 表示从 \(s\) 开始,传到一个到 \(u\) 距离为 \(i\) 的点,需要的最短时间。

\(k = 3\) 时,可能会传到不在 \(s \to t\) 路径上的点 \(x\)。假设从 \(a\) 点经过 \(x\) 传输到 \(b\) 点,若 \(x\)\(s \to t\) 路径上最近点的距离大于等于 \(2\),则 \(dist(a,b) \le 2\)\(a\) 传到 \(b\) 没有必要经过 \(x\)。所以只需要考虑到路径上的点距离不超过 \(1\) 的点。

\(w_u = \begin{cases}\min \limits_{dist(u,v)\le 1} a_v &(k=3) \\ +\infty &(k \le 2)\end{cases}\)

\(k = 3\) 时,

\(f_{u,0} = \min\{f_{v,0}, f_{v,1}, f_{v,2}\} + a_u\)

\(f_{u,1} = \min\{f_{v,0}, f_{v,1} + w_u\}\)

\(f_{u,2} = f_{v,1}\)

时间复杂度是 \(O(nqk^2)\)

优化这个 DP。传输时间与方向无关,\(s \to t\) 可以拆成 \(s \to lca\)\(t \to lca\),只需要求出这两段的 DP 结果再合并即可。

定义广义矩阵乘法 \(C_{i,j} = \min\limits_{0 \le p < k}\{A_{i,p}+B_{p,j}\}\)

DP 转移矩阵:\(\begin{bmatrix} a_u & 0 & +\infty \\ a_u & w_u & 0 \\ a_u & +\infty & +\infty \end{bmatrix}\)

初始向量:\(\begin{bmatrix} a_u & +\infty & +\infty \end{bmatrix}\)

树链剖分,对于每一条重链预处理前缀矩阵乘积,用线段树维护区间矩阵乘积。询问时跳重链,如果当前点 \(x\)\(lca\) 在同一条重链上,线段树查询。否则一定是重链的前缀,直接乘前缀积即可。注意如果两段都选了 \(lca\),结果应该减去 \(a_{lca}\)

时间复杂度 \(O(nk^3 + qk^2 \log n)\),空间复杂度 \(O(nk^2)\),均优于用倍增维护矩乘的做法。其实时空限制 1s 256MB 也能过,但良心出题人没有卡。

代码:

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int N = 2e5 + 5;
const ll INF = 1e18;
int n, q, k;
ll a[N], w[N];
vector<int> tree[N];

int dep[N], fa[N], siz[N], son[N];

void dfs1(int u, int pre) {
	dep[u] = dep[pre] + 1;
	fa[u] = pre;
	siz[u] = 1;
	for (int v : tree[u])
		if (v != pre) {
			dfs1(v, u);
			siz[u] += siz[v];
			if (siz[v] > siz[son[u]]) son[u] = v;
		}
}

int dfn[N], bac[N], top[N], cnt;

void dfs2(int u, int tp) {
	dfn[u] = ++cnt;
	bac[cnt] = u;
	top[u] = tp;
	if (son[u]) dfs2(son[u], tp);
	for (int v : tree[u])
		if (v != fa[u] && v != son[u])
			dfs2(v, v);
}

inline int LCA(int x, int y) {
	int tx = top[x], ty = top[y];
	while (tx != ty) {
		if (dep[tx] >= dep[ty]) x = fa[tx], tx = top[x];
		else y = fa[ty], ty = top[y];
	}
	if (dep[x] <= dep[y]) return x;
	return y;
}

struct Mat {
	int n;
	ll s[3][3];
	
	Mat() {}
	Mat(int n) : n(n) {
		for (int i = 0; i < n; ++i)
			for (int j = 0; j < k; ++j)
				s[i][j] = INF;
	}
	
	void init(int u) {
		n = k;
		s[0][0] = s[1][0] = s[2][0] = a[u];
		s[0][1] = s[1][2] = 0;
		s[1][1] = k == 3 ? w[u] : INF;
		s[0][2] = s[2][1] = s[2][2] = INF;
	}
	
	void start(int u) {
		n = 1;
		s[0][0] = a[u];
		s[0][1] = s[0][2] = INF;
	}
};

inline Mat operator * (const Mat &A, const Mat &B) {
	Mat C(A.n);
	for (int i = 0; i < A.n; ++i)
		for (int p = 0; p < k; ++p)
			for (int j = 0; j < k; ++j)
				C.s[i][j] = min(C.s[i][j], A.s[i][p] + B.s[p][j]);
	return C;
}

Mat val[N], pre[N];

#define ls(x) (x << 1)
#define rs(x) (x << 1 | 1)

Mat prod[N << 2];

inline void pushup(int x) {
	prod[x] = prod[rs(x)] * prod[ls(x)];
}

void build(int x = 1, int l = 1, int r = n) {
	if (l == r) {
		prod[x] = val[l];
		return;
	}
	int mid = (l + r) >> 1;
	build(rs(x), mid + 1, r);
	build(ls(x), l, mid);
	pushup(x);
}

void query(int L, int R, Mat &res, int x = 1, int l = 1, int r = n) {
	if (L <= l && r <= R) {
		res = res * prod[x];
		return;
	}
	int mid = (l + r) >> 1;
	if (R > mid) query(L, R, res, rs(x), mid + 1, r);
	if (L <= mid) query(L, R, res, ls(x), l, mid);
}

Mat res1, res2;

inline void ask(int x, int lca, Mat &res) {
	res.start(x);
	if (x == lca) return;
	x = fa[x];
	while (top[x] != top[lca]) {
		res = res * pre[dfn[x]];
		x = fa[top[x]];
	}
	query(dfn[lca], dfn[x], res);
}

inline ll solve(int lca) {
	ll ans = res1.s[0][0] + res2.s[0][0] - a[lca];
	for (int i = 0; i < k; ++i)
		for (int j = 0; j < k; ++j)
			if (i + j <= k)
				ans = min(ans, res1.s[0][i] + res2.s[0][j]);
	return ans; 
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);
	cin >> n >> q >> k;
	for (int i = 1; i <= n; ++i) {
		cin >> a[i];
		w[i] = a[i];
	}
	for (int i = 1; i < n; ++i) {
		int u, v;
		cin >> u >> v;
		tree[u].push_back(v);
		tree[v].push_back(u);
		w[u] = min(w[u], a[v]);
		w[v] = min(w[v], a[u]);
	}
	dfs1(1, 0);
	dfs2(1, 1);
	for (int i = 1; i <= n; ++i)
		val[dfn[i]].init(i);
	for (int i = 1; i <= n; ++i) {
		if (bac[i] == top[bac[i]])
			pre[i] = val[i];
		else
			pre[i] = val[i] * pre[i - 1];
	}
	build();
	while (q--) {
		int s, t, lca;
		cin >> s >> t;
		lca = LCA(s, t);
		ask(s, lca, res1);
		ask(t, lca, res2);
		cout << solve(lca) << '\n';
	}
	return 0;
}
posted @ 2022-11-01 23:50  猫猫NOIP2006  阅读(470)  评论(0)    收藏  举报