[ICPC 2022 Nanjing Regional E] Color the Tree

给定一个 \(n\) 个节点,以 \(1\) 号节点为根的有根树,初始时所有节点均为白色,给定 \(a_0\sim a_{n - 1}\),定义如下操作:

  • 选择一个节点 \(u\)\(i(0\le i\lt n)\),将 \(u\) 子树内与 \(u\) 的距离为 \(i\) 的节点全部涂黑,代价为 \(a_i\)
    求将全部节点涂黑的最小代价。

\(\mathtt{Data\ Range:}\sum n\le 3\times 10^5, 1\le a_i\le 10^9\)

前排提示:意识流题解。

分析

有两种形式不同的方法,最后我们会知道:两种方法的本质相同。

Solution 1:虚树 dp

观察 1:将深度相同的节点划分为一个等价类。不同等价类之间完全独立。

观察 2:这类问题难贪,考虑 dp。

观察 3:考虑退化成一条链的情况。此时答案为 \(\sum\limits_{i=1}^n \min(a_{0\sim i - 1})\)

观察 4:对于一个等价类,可以使用虚树上 dp。

时间复杂度 \(\mathcal O(n\log n)\),因为这题虚树中的点是自己加的,不用再对 dfs 序排序,配合更优的 RMQ 算法可以做到 \(\mathcal O(n)\)

Solution 2: 长链剖分 dp

考虑设 \(f_{u, i}\) 表示 \(u\) 的子树内深度为 \(i\) 的节点全部涂黑的最小代价。
转移:\(f_{u, i}=\min(a_{i - dep_u}, \sum\limits_{v\in son(u)} f_{v, i})\)
套用长链剖分的模板,唯一需要处理的是 check min 环节。
注意到如果 \(f_{u, *}\) 如果要和 \(a_t\) check min,并且在处理 \(f_{fa_u,*}\) 的时候没有别的子树和 \(f_{u, *}\) 累加,而是直接继承,那么下一轮就要和 \(a_{t+1}\) check min。
依次类推,check min 的范围是一个区间,那么我们只需要维护这个区间,在访问到的时候再处理即可。
时间复杂度和虚树一样,普通写法 \(\mathcal O(n\log n)\),如果采用 \(\mathcal O(n)-\mathcal O(1)\) RMQ 复杂度变为 \(\mathcal O(n)\)

归一

仔细思考一下不难发现两种做法之间的联系:对于长链剖分 dp,\(f_{u, *}\) 只会在这一层对应的虚树节点上发生子树合并,只要适当维护,时间复杂度必然和虚树做法保持一致。
我对长链剖分的理解不深,大概只知道这类题的流程:按深度划分等价类,继承信息。其时间复杂度分析也可以挪用虚树的理论。
那么是否可以提出一个暴论:虚树写法在大部分时间复杂度要求不严的题目上都优于长剖?
至少在代码实现上,必然还是虚树写法更简单。

代码

使用好写的 \(\mathcal O(n\log n)\) 虚树 dp 写法。

#include <bits/stdc++.h>
#define pb emplace_back
#define fir first
#define sec second

using i64 = long long;
using pii = std::pair<int, int>;

void work() {
    int n;
    std::cin >> n;
    
    std::vector<int> a(n);
    for (auto& x : a) std::cin >> x;

    std::vector<int> lg(n + 1, 0);
    std::vector<std::vector<int>> st(20, std::vector<int>(n + 1, 0));
    for (int i = 0; i < n; ++i) st[0][i] = a[i];
    for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
    for (int i = 1; i <= lg[n]; ++i)
        for (int j = 0; j + (1 << i) - 1 < n; ++j)
            st[i][j] = std::min(st[i - 1][j], st[i - 1][j + (1 << i - 1)]);
    
    auto RMQ = [&](int l, int r) {
        if (l > r) return (int)1e9;
        int k = lg[r - l + 1];
        return std::min(st[k][l], st[k][r - (1 << k) + 1]);
    };

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int u, v;
        std::cin >> u >> v;
        adj[u].pb(v), adj[v].pb(u);
    }

    std::vector<int> dep(n + 1, 0), dfn(n + 1, 0);
    std::vector<std::vector<int>> f(20, std::vector<int>(n + 1, 0));
    int cnt = 0;
    std::vector<std::vector<int>> buc(n + 1);
    std::function<void(int, int)> dfs = [&](int u, int ff) {
        dep[u] = dep[ff] + 1;
        dfn[u] = ++cnt;
        f[0][u] = ff;
        buc[dep[u]].pb(u);
        for (int i = 1; (1 << i) <= dep[u]; ++i) f[i][u] = f[i - 1][f[i - 1][u]];
        for (auto& v : adj[u]) if (v != ff) dfs(v, u);
        return;
    };
    
    dfs(1, 0);
    auto LCA = [&](int u, int v) {
        if (dep[u] < dep[v]) std::swap(u, v);
        for (int i = 19; ~i; --i) {
            if (dep[u] - (1 << i) >= dep[v]) u = f[i][u];
            if (u == v) return u;
        }
        for (int i = 19; ~i; --i) if (f[i][u] != f[i][v]) u = f[i][u], v = f[i][v];
        return f[0][u];
    };

    
    i64 ans = 0;
    std::vector<std::vector<int>> G(n + 1);
    std::vector<bool> tag(n + 1, false);
    for (int i = 1; i <= n; ++i) {
        if (buc[i].empty()) continue;
        std::function<int(int, int)> DFS = [&](int u, int ff) {
            int ans = RMQ(i - dep[u], i - dep[ff] - 1);
            if (G[u].empty()) return ans;
            int s = 0;
            for (auto& v : G[u]) if (v != ff) {
                s += DFS(v, u);
                if (s >= ans) break;
            }
            return std::min(ans, s);
        };
        
        std::vector<int> p = buc[i];
        for (auto& v : p) tag[v] = true;
        for (int j = 1; j < buc[i].size(); ++j) {
            int z = LCA(buc[i][j - 1], buc[i][j]);
            if (!tag[z]) tag[z] = true, p.pb(z);
        }
        std::sort(p.begin(), p.end(), [&](const int& lhs, const int& rhs) {
            return dfn[lhs] < dfn[rhs];
        });
        for (int i = 1; i < p.size(); ++i) G[LCA(p[i - 1], p[i])].pb(p[i]);

        ans += DFS(p[0], 0);
        for (auto& v : p) tag[v] = false, G[v].clear();
    }

    std::cout << ans << '\n';

    return;
}

int main() {
    std::cin.tie(nullptr)->sync_with_stdio(false);
    int t;
    std::cin >> t;
    while (t--) work();
    return 0;
}
posted @ 2025-08-13 14:35  ImALAS  阅读(23)  评论(0)    收藏  举报