[Codeforces 1016F]Road Projects

Description

题库链接

给你一棵 $n$ 个节点的树,定义 $1$ 到 $n$ 的代价是 $1$ 到 $n$ 节点间的最短路径的长度。现在给你 $m$ 组询问,让你添加一条边权为 $w$ 的边(不与原图重复),求代价的最大值。询问之间相互独立。

$1\leq n,m\leq 3\times 10^5$

Solution

辣鸡老余毁我青春

把 $1$ 至 $n$ 的路径提取出来,显然图就变成了一条链加上若干子树。贪心的思想是找这样一组点,满足

  1. 其所属的子树来自链上两个不同的点(可以选链上的点)
  2. 一定是该子树内最深的那个点

然后我们可以遍历这条链,找出满足上述所有情况的最大值。最后 $O(1)$ 回答询问即可。

注意的是要特判 $1$ 或 $n$ 结点外连边的情况。

Code

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 300000+5;

struct tt {int to, next, cost; } edge[N<<1];
int path[N], top;
int n, m, u, v, c, vis[N], sz[N], flag, f, szf;
ll dist[N], maxn = -1, ans, dis[N];

bool dfs(int u, int fa) {
    sz[u] = 1;
    for (int v, i = path[u]; i; i = edge[i].next)
        if ((v = edge[i].to) != fa) {
            dis[v] = dis[u]+edge[i].cost;
            if (dfs(v, u)) vis[u] = 1;
            else dist[u] = max(dist[u], dist[v]+edge[i].cost), flag += (u == n), f += (u == 1);
            sz[u] += sz[v];
        }
    return vis[u] |= (u == n);
}
void dfs2(int u, int fa) {
    for (int v, i = path[u]; i; i = edge[i].next)
        if ((v = edge[i].to) != fa && vis[v]) {
            if (u == 1) szf = sz[u]-sz[v];
            if (dist[fa] > 0) maxn = max(maxn, dist[fa]+dis[fa]);
            if (dist[u] > 0) maxn = max(maxn, dis[fa]);
            if (maxn != -1 && u != 1) ans = max(ans, maxn+dist[u]+dis[n]-dis[u]);
            if (dist[fa] == 0 && fa != 0) maxn = max(maxn, dis[fa]);
            dfs2(v, u);
        }
    if (u == n) {
        if (dist[fa] > 0) maxn = max(maxn, dist[fa]+dis[fa]);
        if (dist[u] > 0) maxn = max(maxn, dis[fa]);
        if (maxn != -1) ans = max(ans, maxn+dist[u]);   
    }
}
void add(int u, int v, int c) {edge[++top] = (tt){v, path[u], c}; path[u] = top; }
void work() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i < n; i++) {
        scanf("%d%d%d", &u, &v, &c);
        add(u, v, c), add(v, u, c);
    }
    dfs(1, 0); dfs2(1, 0);
    while (m--) {
        scanf("%d", &c);
        if (flag > 1 || sz[n] >= 3) printf("%I64d\n", dis[n]);
        else if (f > 1 || szf >= 3) printf("%I64d\n", dis[n]);
        else printf("%I64d\n", min(c+ans, dis[n])); 
    }
}
int main() {work(); return 0; }
posted @ 2018-08-31 10:03  NaVi_Awson  阅读(388)  评论(0编辑  收藏  举报