[Luogu] P5021 赛道修建
Description
\(C\)城将要举办一系列的赛车比赛。在比赛前,需要在城内修建\(m\)条赛道。
\(C\)城一共有\(n\)个路口,这些路口编号为\(1,2,…,n\),有\(n−1\)条适合于修建赛道的双向通行的道路,每条道路连接着两个路口。其中,第\(i\)条道路连接的两个路口编号为\(a_i\)和\(b_i\),该道路的长度为\(l_i\)。借助这\(n-1\)条道路,从任何一个路口出发都能到达其他所有的路口。
一条赛道是一组互不相同的道路\(e_1,e_2,…,e_k\),满足可以从某个路口出发,依次经过道路\(e_1,e_2,…,e_k\)(每条道路经过一次,不允许调头)到达另一个路口。一条赛道的长度等于经过的各道路的长度之和。为保证安全,要求每条道路至多被一条赛道经过。
目前赛道修建的方案尚未确定。你的任务是设计一种赛道修建的方案,使得修建的\(m\)条赛道中长度最小的赛道长度最大(即\(m\)条赛道中最短赛道的长度尽可能大)
Solution
首先看到最小的最大,就想到要二分。我们二分\(m\)条赛道中长度最小的赛道长度\(mid\),那么\(check\)的就是长度\(\ge{mid}\)的赛道最多能不能\(\ge{m}\),如果可以,那么\(l=mid+1\),否则\(r=mid-1\)。
这个还是可以想到的,就是\(check\)不太好写。注意到对于某个节点\(x\)和它的一个子节点\(y\),一定是选若干条以\(y\)为链顶的链,将剩下的最短的链\(l_1\)和对应最短的满足\(len_1+len_2\ge{mid}\)的链\(l_2\)拼成一条赛道,然后如果最后还剩下一条链\(l_3\),长度就一定是最大的,将它和\(x\rightarrow{y}\)这条边拼在一起,构成一条赛道。
这其实就是贪心的思想。因为\(x\rightarrow{y}\)这条边一定能且只能和一条以\(y\)为顶的链构成赛道,然后如果把某条\(l_1\)对应的\(l_2\)替换成更大的,可能反而会找不到一条链\(l_3\),满足\(len_3+len_{x\rightarrow{y}}\ge{mid}\)。
具体实现用\(multiset\)。
Code
#include <bits/stdc++.h>
using namespace std;
int n, m, tot, res, sum, l = 1e9, r, hd[50005], to[100005], nxt[100005], w[100005];
multiset < int > g[50005];
int read()
{
	int x = 0, fl = 1; char ch = getchar();
	while (ch < '0' || ch > '9') { if (ch == '-') fl = -1; ch = getchar();}
	while (ch >= '0' && ch <= '9') {x = (x << 1) + (x << 3) + ch - '0'; ch = getchar();}
	return x * fl;
}
void add(int x, int y, int z)
{
	tot ++ ;
	to[tot] = y;
	nxt[tot] = hd[x];
	w[tot] = z;
	hd[x] = tot;
	return;
}
int dfs(int x, int fa, int d)
{
	g[x].clear();
	for (int i = hd[x]; i; i = nxt[i])
	{
		int y = to[i], z = w[i];
		if (y == fa) continue;
		int now = dfs(y, x, d) + z;
		if (now >= d) sum ++ ;
		else g[x].insert(now);
	}
	int mx = 0;
	while (g[x].size())
	{
		int cnt = *g[x].begin();
		if (g[x].size() == 1) return max(mx, cnt); 
		multiset < int > :: iterator it = g[x].lower_bound(d - cnt);
		if (it == g[x].begin()) it ++ ;
		g[x].erase(g[x].begin());
		if (it == g[x].end()) mx = max(mx, cnt);
		else sum ++ , g[x].erase(it);
	}
	return mx;
}
int check(int x)
{
	sum = 0;
	dfs(1, 0, x);
	return (sum >= m);
}
int main()
{
	n = read(); m = read();
	for (int i = 1; i <= n - 1; i ++ )
	{
		int x = read(), y = read(), z = read();
		add(x, y, z); add(y, x, z);
		l = min(l, z); r += z;
	}
	while (l <= r)
	{
		int mid = (l + r) >> 1;
		if (check(mid)) l = mid + 1, res = mid;
		else r = mid - 1;
	}
	printf("%d\n", res);
	return 0;
}

 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号