P3354 [IOI 2005] Riv 河流

看完题第一反应大概是 dp[u, j] 表示以u为根, 并且选择了j个节点的最小花费, 可是这样我们会发现: 我们好像没办法计算上一个点到这一个点的花费, 因为花费是 $w[i] \times (dep[u] - dep[lst])$, 而在DP中, 我们应该只关注上一个状态(无后效性), 所以这个DP状态设计是有问题的;

根据上一个DP状态设计的缺点我们可以发现, 我们要知道每个点最后会会聚到哪个点, 这样我们才能计算出花费, 所以我们这样设计 dp[u, i, j] 表示u最终会聚到i, 并且选择了j个节点的最小花费; 然后我们考虑转移, 每个点有选和不选两种状态:

  • 如果是不选择的话很好转移: cur = min(cur, dp[u, fa, i - j] + dp[v, fa, j])(这里不能直接用dp[u, fa, j]代替cur, 因为当j=0的时候, 如果使用dp[u, fa, j], 就有因为dp[u, fa, i] + dp[v, i, 0] > dp[u, fa, i]而忽略掉dp[v, fa, 0]的值, 而j=0表示v节点不贡献任何选择点, 但是要加上v的花费, 所以一定要算上dp[v, fa, 0]);
  • 如果是选的话, 我们再背包问题的一般做法是: 先不管这个节点, 把其他的物品合并之后, 再平移加上这个点的体积和价值, 所以在这里也是一样的, 那么就出现了一个问题, 如果直接使用dp[u, i, j]的话, 其实是包含了不能平移的部分, 所以我们需要再开一个新的数组来处理这个需要平移的背包, 最后再进行合并, 我们用 f[u, j] 表示以u节点为根, 选了u节点, 并且总共选择了j个节点的花费;

最后就是代码了, 很多细节要注意:

inline void solve() {
    int n, m;
    cin >> n >> m;

    vector<vpii> g(n + 1);
    vi a(n + 1);
    for (int i = 1; i <= n; i++) {
        int v, d;
        cin >> a[i] >> v >> d;
        g[i].push_back({v, d});
        g[v].push_back({i, d});
    }

    vector<vvi> dp(n + 1, vvi(n + 1, vi(m + 1, inf)));
    vvi f(n + 1, vi(m + 1, inf));
    vi stk, dep(n + 1);
    auto dfs = [&](auto&& dfs, int u, int fa) -> void {
        stk.push_back(u);
        f[u][0] = 0;
        for (int fa : stk) dp[u][fa][0] = a[u] * (dep[u] - dep[fa]);
        for (auto [v, w] : g[u]) {
            if (v == fa) continue;
            dep[v] = dep[u] + w, dfs(dfs, v, u);

            for (int i = m; i >= 0; i--) {
                int tmp = inf;
                for (int j = 0; j <= i; j++) {
                    if (f[u][i - j] != inf && dp[v][u][j] != inf) {
                        tmp = min(tmp, f[u][i - j] + dp[v][u][j]);
                    }
                }
                f[u][i] = tmp;
            }

            for (int fa : stk) {
                for (int i = m; i >= 0; i--) {
                    int tmp = inf;
                    for (int j = 0; j <= i; j++) {
                        if (dp[v][fa][j] != inf && dp[u][fa][i - j] != inf) {
                            tmp = min(tmp, dp[v][fa][j] + dp[u][fa][i - j]);
                        }
                    }
                    dp[u][fa][i] = tmp;
                }
            }
        }

        for (int fa : stk) {
            for (int i = m; i >= 1; i--) {
                dp[u][fa][i] = min(dp[u][fa][i], f[u][i - 1]);
            }
        }

        stk.pop_back();
    };
    dfs(dfs, 0, -1);

    cout << dp[0][0][m] << endl;
}
posted @ 2026-06-07 12:44  RCells  阅读(1)  评论(0)    收藏  举报