题解:P10838 『FLA - I』庭中有奇树
幽默样例写错两个地方反而能过导致调了 1h。
不难发现,小 G 的策略只有三种,不作弊老实走,走封锁线路花费 \(10^9\) 速通,或者作弊。
前两种的值都是定值,可以很快求出来。现在来考虑第三种情况。
由于最多封锁 \(m\) 条线路,所以在作弊情况下的第 \(m + 1\) 短路就是作弊后的最优解。所以很容易想到构造分层图跑 \(k\) 短路。
我们可以预处理出每个点分别到 \(s\) 和 \(t\) 的最短路程,再枚举传送的两个点 \(u,v\),将所有值统计出来最后排个序,取第 \(m + 1\) 个,时间复杂度 \(O(n^2)\)。
考虑二分答案,对于当前答案 \(x\),只要路程小于等于 \(x\) 的路径多于 \(m\) 条,则 \(x\) 可行。问题转化为如何统计路程小于等于 \(x\) 的路径条数。
对于一条使用传送的路径来说,总路程分为三部分,\(s\) 走树边到 \(u\),耗费 \(ds_u\);\(u\) 传送到 \(v\),耗费 \(k\);\(v\) 走树边到 \(t\),耗费 \(dt_v\)。(\(ds_i\) 和 \(dt_i\) 分别表示 \(i\) 到 \(s\) 和 \(i\) 到 \(t\) 的距离。)
即 \(dis = ds_u + k + dt_v\)。
我们要求 \(dis \le x\) 的个数,枚举 \(u\),此时 \(ds_u\) 与 \(k\) 均为定值,即求 \(dt_v \le x - k - ds_u\) 的个数,将 \(dt\) 排序二分即可。
最终得到的答案再与不作弊老实走和走封锁线路花费 \(10^9\) 速通的答案比较,取最小值。
时间复杂度 \(O(n\log n \log V)\),\(V\) 是二分值域,开到 \(10^9\) 足够。若大于 \(10^9\) 则不如走封锁路线速通。
代码如下:
#include<bits/stdc++.h>
#define MAXN 500010
#define MAXM 1000010
using namespace std;
typedef long long ll;
struct edge{ ll pre, to, w; };
ll n, m, k, s, t, cnt;
edge e[MAXM];
ll head[MAXN], deep[MAXN], vis[MAXN], dis[MAXN], ans[MAXN], dis_s[MAXN], dis_t[MAXN], dt[MAXN];
ll fa[MAXN][30];
void add_edge(ll u, ll v, ll w){
e[++cnt].pre = head[u];
e[cnt].to = v; e[cnt].w = w;
head[u] = cnt;
}
void bfs(ll root){
queue<ll> s;
s.push(root);
vis[root] = true;
deep[root] = 0; dis[root] = 0;
while(!s.empty()){
ll p = s.front(); s.pop();
for(ll i = head[p]; i; i = e[i].pre){
if(!vis[e[i].to]){
vis[e[i].to] = true;
deep[e[i].to] = deep[p] + 1;
dis[e[i].to] = dis[p] + e[i].w;
fa[e[i].to][0] = p;
s.push(e[i].to);
}
}
}
}
ll get_lca(ll x, ll y){
if(deep[x] < deep[y]) swap(x, y);
ll maxi=0;
while((1 << maxi) <= deep[x]) maxi++;
maxi--;
for(ll i = maxi; i >= 0; i--){
if(deep[x] - (1 << i) >= deep[y]){
x = fa[x][i];
}
}
if(x == y) return x;
for(ll i = maxi; i >= 0; i--){
if(fa[x][i] != fa[y][i]) x=fa[x][i], y=fa[y][i];
}
return fa[x][0];
}
ll get_dis(int u, int v){
ll lca = get_lca(u, v);
ll ans = dis[u] + dis[v] - dis[lca] * 2;
return ans;
}
bool check(ll x){
ll cnt = 0;
for(int i = 1; i <= n; i++){
ll tmp = x - dis_s[i];
cnt += upper_bound(dt + 1, dt + n + 1, tmp) - dt - 1;
for(int j = head[i]; j; j = e[j].pre){
if(dis_t[e[j].to] <= tmp) cnt--;
}
if(dis_t[i] <= tmp) cnt--;
}
return cnt > m;
}
int main(){
scanf("%lld%lld%lld%lld%lld",&n,&m,&k,&s,&t);
for(ll i = 1; i < n; i++){
ll u, v, w;
scanf("%lld%lld%lld",&u,&v,&w);
add_edge(u, v, w); add_edge(v, u, w);
}
bfs(1);
for(ll i = 1; (1 << i) <= n; i++){
for(ll j = 1; j <= n; j++){
fa[j][i] = fa[fa[j][i - 1]][i - 1];
}
}
ll ans = get_dis(s, t);
for(int i = 1; i <= n; i++){
dis_s[i] = get_dis(i, s);
dt[i] = dis_t[i] = get_dis(i, t);
}
sort(dt + 1, dt + n + 1);
ll l = 0, r = 1e9, res = r;
while(l < r){
ll mid = (l + r) >> 1;
if(check(mid)){
r = mid;
res = mid;
}else{
l = mid + 1;
}
}
res = min(res + k, 1000000000ll);
printf("%lld\n",min(ans, res));
}