树上背包优化

树上背包优化

树形背包
这道题卡 nw

背景

这是在上课的时候打的,就长话短说。这道题其实我还是不是很明白,不过如果是刷表的方式的话,代码虽然常数会大一点但胜在易于理解。

但如果是打表,它是从孩子向父亲或者兄弟转移的,这种方式我只看了个半懂,这道题是抄抄过了,但是在写另一道更简单的板子的时候却写 \(wa\) 了。直到现在我也并没有很明白,所以这篇博客又是来埋雷的。

假如有这么一棵树,我们画出它的 \(dfs\) 序:

pZAYxxK.png

发现加上这个点的 \(siz\) 之后可以跳过它的儿子。对于刷表法,我们现在就可以从根节点开始,对儿子的可能状态取最大值。但是按后序遍历刷表怎么实现呢?

观察一段转移的过程。

pZAtAPI.png

但前节点加一会到它的父亲或者兄弟,可以想象兄弟之间不断传递最优,最后再到父亲。

在这个传递过程中我们有两种可能。

  1. 当前枚举的背包容量不够,显然必须跳过但前节点的所有儿子。也就是从 \(dp_{i-siz_j,j} \to dp_{i,j}\)
  2. 如果容量够,那么就对取和不取它的子树中取最大值,取的话剩余背包容量要减去当前所需,权值要加上这个点的权值。

转移就很显然了。直接留下一份 \(ai\) 的代码。

#include <bits/stdc++.h>
#define int long long
using namespace std;
constexpr int maxn = 5e4+10;

vector<int> gra[maxn];
vector<int> siz, vi, wi;    // 子树大小、节点重量、节点价值
vector<int> d_siz, d_v, d_w;// DFS序列化后的数据
vector<vector<int>> dp;
int n, V;

void dfs(int u)
{
    siz[u] = 1;

    for (int child : gra[u])
    {
        dfs(child);
        siz[u] += siz[child];
    }

    // 将当前节点信息加入DFS序列
    d_siz.push_back(siz[u]);
    d_v.push_back(vi[u]);
    d_w.push_back(wi[u]);
}

signed main()
{
    #ifndef ONLINE_JUDGE
        freopen("sf.in","r",stdin);
    #endif // ONLINE_JUDGE

    scanf("%lld%lld",&n,&V);

    dp.resize(n+1,vector<int>(V+1,0));
    siz.resize(n+1);
    vi.resize(n+1);
    wi.resize(n+1);

    for (int i = 1,p ; i <= n; ++i)
    {
        scanf("%lld",&p);
        gra[p].push_back(i);
    }

    for (int i = 1; i <= n; ++i)
    {
        scanf("%lld",&vi[i]);
    }
    for (int i = 1; i <= n; ++i)
    {
        scanf("%lld",&wi[i]);
    }

    // 1_base
    d_siz.push_back(0);
    d_v.push_back(0);
    d_w.push_back(0);

    // 从根节点0开始DFS遍历
    dfs(0);

    // 清空原始数据释放内存
    vi.clear();
    wi.clear();
    siz.clear();

    for (int i = 1; i <= n; ++i)
    {
        for (int j = 0; j <= V; ++j)
        {
            if (j >= d_v[i])
            {
                // 选择当前物品:dp[i-1][j - v] + w
                // 不选择当前子树:dp[i - siz][j]
                dp[i][j] = max(dp[i - 1][j - d_v[i]] + d_w[i],
                               dp[i - d_siz[i]][j]);
            }
            else
            {
                // 容量不足,只能选择不包含当前子树
                dp[i][j] = dp[i - d_siz[i]][j];
            }
        }
    }

    printf("%lld\n",dp[n][V]);
    return 0;
}

题外话

我前面不是说在另一到题写瓦了吗。我来贴一下代码和疑问。原题

// wa
#include <bits/stdc++.h>
#define int long long
using namespace std;
constexpr int maxn = 5e2+10;

int n,m,dp[maxn][maxn],wi[maxn];
int siz[maxn];
vector<int> gra[maxn];
int id[maxn],idx;

void dfs(int u)
{
    siz[u]=1;
    for(const int &v : gra[u])
    {
        dfs(v);
        siz[u]+=siz[v];
    }
    id[++idx]=u;
}

signed main()
{
    #ifndef ONLINE_JUDGE
    freopen("cjdl.in","r",stdin);
    freopen("cjdl.out","w",stdout);
    #endif // ONLINE_JUDGE

    scanf("%lld%lld",&n,&m);
    for(int i=1,u ;i<=n;++i)
    {
        scanf("%lld%lld",&u,wi+i);
        gra[u].emplace_back(i);
    }

    dfs(0);

    for(int i=1;i<=n;++i)
    {
        for(int j=1;j<=m+1;++j)
        {
            // 这里点的代价是 1 
            dp[i][j]=max(dp[i-siz[id[i]]][j],dp[i-1][j-1]+wi[id[i]]);
        }
    }

    printf("%lld\n",dp[n][m+1]);

    return 0;
}

\(ac_code\)

#include <bits/stdc++.h>
#define int long long
using namespace std;
constexpr int maxn = 5e2+10;

int n,m,dp[maxn][maxn],wi[maxn];
int siz[maxn];
vector<int> gra[maxn];
int id[maxn],idx;

void dfs(int u)
{
    id[++idx]=u;// 先序遍历
    siz[u]=1;
    for(const int &v : gra[u])
    {
        dfs(v);
        siz[u]+=siz[v];
    }
}

signed main()
{
    #ifndef ONLINE_JUDGE
    freopen("cjdl.in","r",stdin);
    freopen("cjdl.out","w",stdout);
    #endif // ONLINE_JUDGE

    scanf("%lld%lld",&n,&m);
    for(int i=1,u ;i<=n;++i)
    {
        scanf("%lld%lld",&u,wi+i);
        gra[u].emplace_back(i);
    }

    dfs(0);

    for(int i=n+1;i>=1;--i)// 这里是到这跑得,所以应该和后续遍历一样
    {
        for(int j=1;j<=m+1;++j)
        {
            dp[i][j]=max(dp[i+siz[id[i]]][j],dp[i+1][j-1]+wi[id[i]]);
        }
    }

    printf("%lld\n",dp[1][m+1]);

    return 0;
}

\(zzr\) 的刷表

#include <bits/stdc++.h>
using namespace std;
constexpr int maxn=3e2+10;
vector<int> gra[maxn];
int s[maxn];
int siz[maxn];
int xv[maxn];
void dfs(int u)
{
    xv[++xv[0]]=u;
    siz[u]=1;
    for(int v:gra[u])
    {
        dfs(v);
        siz[u]+=siz[v];
    }
}
int dp[maxn][maxn];
int main()
{
    int n,m;
    cin>>n>>m;
    ++m;// 预留0节点
    for(int i=1;i<=n;++i)
    {
        int k;
        cin>>k>>s[i];
        gra[k].emplace_back(i);
    }
    memset(dp,-0x3f,sizeof(dp));// 初始化
    dfs(0);
    dp[1][0]=0;
    for(int i=1;i<=n+1;++i)
    {
        for(int j=0;j<=m;++j)// 刷表
        {
            dp[i+siz[xv[i]]][j]=max(dp[i+siz[xv[i]]][j],dp[i][j]);
        }
        for(int j=1;j<=m;++j)
        {
            dp[i+1][j]=max(dp[i+1][j],dp[i][j-1]+s[xv[i]]);
        }
    }
    int res=0;
    for(int i=0;i<=m;++i)
    {
        res=max(res,dp[n+2][i]);
    }
    cout<<res;
}
posted @ 2025-11-25 21:04  玖玮  阅读(13)  评论(0)    收藏  举报