树上依赖性背包 学习笔记 | P6326 Shopping 题解

树上依赖性背包

树上依赖性背包,适用于合并复杂度大,插入复杂度小的情况。


P6326 Shopping

发现题意等价于在树上选一个连通块

完全背包先二进制拆分,然后想出一车假做法。思考正解。直接写 Solution 吧。

#1:点分治 + dfn 序

百家博大学习。WC 好像会了。

一个常见的技巧是,我们树上背包是劣于序列背包的,所以对于这类问题,可以想办法把它拍成序列上的问题。注意到一个根以及它的子树节点的时间戳在 dfn 序列上一定是连续的一段,可以由此入手。

考虑先随便钦定一个根,并且钦定在根至少选 \(1\) 个物品。

然后我们搞出 dfn 序把树拍成序列。由于 dfn 序列上后面的点不可能是前面的点的祖先,考虑倒序设计 dp 简化问题。

\(siz[i]\) 为以 \(i\) 为根子树大小。

\(dp[i][j]\) 为考虑 dfn 序为 \([i,n]\) 的节点,一共花费为 \(j\),所能得到的最大价值。

转移分为两种情况:

  • 不选以 \(i\) 为根子树,包括 \(i\)。那么令 \(dp[i][j]=dp[i+siz_i][j]\) 即可。含义为直接继承这个子树之外 dfn 序上下一个节点的 dp 值。

  • 选以 \(i\) 为根子树,即必须保证 \(i\) 至少选 \(1\) 个。那么 \(dp[i][j]\) 直接可以由 \(dp[i+1][k]\)\(i\) 这个物品做多重背包直接转移即可。

最后答案对 \(dp[1][i]\)\(\max\) 即可。因为根的 dfn 序为 \(1\)

这样做一次 dp 复杂度是 \(O(nm \log m)\) 的。我们还需要对于每个点都钦定为根跑 dp,总复杂度 \(O(n^2m\log m)\) 的。已经比树上直接做是 \(O(nm^2 \log m)\) 优了,但还不够。

发现选的是树上连通块。那么可以考虑直接点分治做,正确性可以保证。

时间复杂度 \(O(n\log m \log m)=O(nm\log n \log m)\),可以通过。注意本题多测。

Code:

WC 了,我代码犯的全是唐诗错误,包括 01 背包正着枚举,\(dp[i+siz_{dfn[i]]}\) 写成 \(dp[i+siz_i]\)

#include<bits/stdc++.h>
#define int long long

using namespace std;

const int Size=(1<<20)+1;
char buf[Size],*p1=buf,*p2=buf;
char buffer[Size];
int op1=-1;
const int op2=Size-1;
#define getchar()                                                              \
(tt == ss && (tt=(ss=In)+fread(In, 1, 1 << 20, stdin), ss == tt)     \
	? EOF                                                                 \
	: *ss++)
char In[1<<20],*ss=In,*tt=In;
inline int read()
{
	int x=0,c=getchar(),f=0;
	for(;c>'9'||c<'0';f=c=='-',c=getchar());
	for(;c>='0'&&c<='9';c=getchar())
		x=(x<<1)+(x<<3)+(c^48);
	return f?-x:x;
}
inline void write(int x)
{
	if(x<0) x=-x,putchar('-');
	if(x>9)  write(x/10);
	putchar(x%10+'0');
}

const int N=505;
vector<int> v[N],E[N];
int n,m;
int w[N];
int c[N];
int d[N];

void chai(int id,int x)
{
    if(!x) return;
    v[id].clear();
    for(int k=0;;k++)
    {
        int nw=1ll<<k;
        if(nw>x)
        {
            nw>>=1;
            v[id].pop_back();
            v[id].push_back(x-nw+1);
            return;
        }
        v[id].push_back(nw);
    }
}

bool vis[N];
int siz[N];
pair<int,int> findG(int p,int fa,int tot)
{
    pair<int,int> nw=make_pair(tot-siz[p],p),ans=make_pair(n+1,-1);
    for(int to:E[p])
    {
        if(to==fa) continue;
        if(vis[to]) continue;
        ans=min(ans,findG(to,p,tot));
        nw.first=max(nw.first,siz[to]);
    }
    return min(ans,nw);
}

void dfs(int p,int fa)
{
    siz[p]=1;
    for(int to:E[p])
    {
        if(to==fa) continue;
        if(vis[to]) continue;
        dfs(to,p);
        siz[p]+=siz[to];
    }
}

int tot;
int id[N];
int dp[N][4005];
void dfs1(int p,int fa)
{
    siz[p]=1;
    id[++tot]=p;
    for(int to:E[p])
    {
        if(to==fa) continue;
        if(vis[to]) continue;
        dfs1(to,p);
        siz[p]+=siz[to];
    }
}

int ans=0;

void dodp(int root)
{
    int nw=ans;
    tot=0;
    dfs1(root,0);
    for(int i=1;i<=tot+1;i++) memset(dp[i],0,sizeof(dp[i]));

    for(int i=tot;i>=1;i--)
    {
        int p=id[i];
        for(int j=c[p];j<=m;j++) dp[i][j]=dp[i+1][j-c[p]]+w[p];
        for(int cnt:v[p])
        {
            int V=cnt*c[p],W=cnt*w[p];
            for(int j=m;j>=V;j--)
                dp[i][j]=max(dp[i][j],dp[i][j-V]+W);
        }

        for(int j=1;j<=m;j++) dp[i][j]=max(dp[i+siz[p]][j],dp[i][j]);
    }
    for(int i=1;i<=m;i++) ans=max(ans,dp[1][i]);
}

void solve(int root)
{
    vis[root]=1;
    for(int to:E[root])
    {
        if(vis[to]) continue;
        dfs(to,0);
        int nxt=findG(to,0,siz[to]).second;
        solve(nxt);
    }
    dodp(root);
    vis[root]=0;
}

void solve()
{
    n=read();
    m=read();
    for(int i=1;i<=n;i++) w[i]=read();
    for(int i=1;i<=n;i++) c[i]=read();
    for(int i=1;i<=n;i++) d[i]=read();
    for(int i=1;i<=n;i++) chai(i,d[i]-1),E[i].clear();
    for(int i=1;i<n;i++)
    {
        int u=read(),v=read();
        E[u].push_back(v);
        E[v].push_back(u);
    }

    dfs(1,0);
    int root=findG(1,0,siz[1]).second;
    solve(root);

    cout<<ans<<"\n";
    ans=0;
}


signed main()
{
    int T=read();
    while(T--) solve();
	return 0;
}

#2:树上启发式合并。

(Waiting)


AI 润色

AI 润色之后 # 树上依赖性背包

适用于合并操作复杂度较高,而插入操作复杂度较低的场景。


P6326 Shopping

在分析题目时,我们发现题目的本质等价于在树中选取一个连通块

首先,将完全背包问题利用二进制拆分处理,再探索多种简化策略,最终确定了有效的解决方案。下面介绍一种基于动态规划(DP)的思路。

方法一:点分治 + dfn 序列

参考百家博等人的思路,主要步骤如下:

  1. 问题转化:树上背包问题通常不如序列背包问题高效,因此可以尝试将树转化为序列。注意,一个节点及其子树在 DFS 得到的 dfn 序列中通常是连续的。
  2. 选取根节点:任选一个根节点,并保证该根至少选择一个物品。
  3. 构造 dfn 序列并逆序 DP:利用 DFS 得到的 dfn 序列,由于序列后面的节点不可能是前面节点的祖先,可以采用逆序 DP 来进行状态转移。

设定如下变量:

  • \(siz[i]\):以 \(i\) 为根的子树大小。
  • \(dp[i][j]\):在 dfn 序列中,从第 \(i\) 个节点开始,当总花费不超过 \(j\) 时所能获得的最大价值。

状态转移分析:

  • 不选方案:不选取以 \(i\) 为根的子树,此时有 \(dp[i][j] = dp[i+siz[i]][j]\)
  • 选取方案:选择以 \(i\) 为根的子树,必须至少选中一个物品,通过多重背包思想更新 \(dp[i][j]\)

最终答案由 \(dp[1][j]\) 的最大值确定,其中索引 1 对应树的根。

该方法单次 DP 的时间复杂度为 \(O(nm\log m)\),但由于需要对每个节点作为根进行处理,总体复杂度为 \(O(n^2m\log m)\)。相比直接在树上求解(\(O(nm^2\log m)\))已经有所提升,但仍有优化空间。

进一步地,由于题目要求选取的是树上连通块,可以直接采用点分治策略将时间复杂度降低至 \(O(nm\log n\log m)\),从而满足较大数据规模的要求(注意测试数据较多)。

示例代码

在编码过程中,常见错误包括:

  • 01 背包中正向枚举导致错误;
  • 状态转移时容易将 \(dp[i+siz[dfn[i]]\) 错写为 \(dp[i+siz[i]]\)

以下为主要代码实现:

#include<bits/stdc++.h>
#define int long long

using namespace std;

const int Size = (1<<20) + 1;
char buf[Size], *p1 = buf, *p2 = buf;
char buffer[Size];
int op1 = -1;
const int op2 = Size - 1;
#define getchar() (tt == ss && (tt=(ss=In)+fread(In, 1, 1<<20, stdin), ss == tt) ? EOF : *ss++)
char In[1<<20], *ss = In, *tt = In;

inline int read() {
    int x = 0, c = getchar(), f = 0;
    for(; c > '9' || c < '0'; f = (c=='-'), c = getchar());
    for(; c >= '0' && c <= '9'; c = getchar())
        x = (x << 1) + (x << 3) + (c ^ 48);
    return f ? -x : x;
}

inline void write(int x) {
    if(x < 0) x = -x, putchar('-');
    if(x > 9) write(x/10);
    putchar(x % 10 + '0');
}

const int N = 505;
vector<int> v[N], E[N];
int n, m;
int w[N], c[N], d[N];

void chai(int id, int x) {
    if(!x) return;
    v[id].clear();
    for(int k = 0;; k++) {
        int nw = 1ll << k;
        if(nw > x) {
            nw >>= 1;
            v[id].pop_back();
            v[id].push_back(x - nw + 1);
            return;
        }
        v[id].push_back(nw);
    }
}

bool vis[N];
int siz[N];

pair<int,int> findG(int p, int fa, int tot) {
    pair<int,int> nw = make_pair(tot - siz[p], p), ans = make_pair(n+1, -1);
    for(int to : E[p]) {
        if(to == fa || vis[to]) continue;
        ans = min(ans, findG(to, p, tot));
        nw.first = max(nw.first, siz[to]);
    }
    return min(ans, nw);
}

void dfs(int p, int fa) {
    siz[p] = 1;
    for(int to : E[p]) {
        if(to == fa || vis[to]) continue;
        dfs(to, p);
        siz[p] += siz[to];
    }
}

int tot, id[N];
int dp[N][4005];

void dfs1(int p, int fa) {
    siz[p] = 1;
    id[++tot] = p;
    for(int to : E[p]) {
        if(to == fa || vis[to]) continue;
        dfs1(to, p);
        siz[p] += siz[to];
    }
}

int ans = 0;

void dodp(int root) {
    tot = 0;
    dfs1(root, 0);
    for(int i = 1; i <= tot + 1; i++) memset(dp[i], 0, sizeof(dp[i]));
    for(int i = tot; i >= 1; i--) {
        int p = id[i];
        for(int j = c[p]; j <= m; j++)
            dp[i][j] = dp[i + 1][j - c[p]] + w[p];
        for(int cnt : v[p]) {
            int V = cnt * c[p], W = cnt * w[p];
            for(int j = m; j >= V; j--)
                dp[i][j] = max(dp[i][j], dp[i][j - V] + W);
        }
        for(int j = 1; j <= m; j++)
            dp[i][j] = max(dp[i][j], dp[i + siz[p]][j]);
    }
    for(int i = 1; i <= m; i++) ans = max(ans, dp[1][i]);
}

void solve(int root) {
    vis[root] = true;
    for(int to : E[root]) {
        if(vis[to]) continue;
        dfs(to, 0);
        int nxt = findG(to, 0, siz[to]).second;
        solve(nxt);
    }
    dodp(root);
    vis[root] = false;
}

void solve() {
    n = read();
    m = read();
    for(int i = 1; i <= n; i++) w[i] = read();
    for(int i = 1; i <= n; i++) c[i] = read();
    for(int i = 1; i <= n; i++) d[i] = read();
    for(int i = 1; i <= n; i++) {
        chai(i, d[i] - 1);
        E[i].clear();
    }
    for(int i = 1; i < n; i++) {
        int u = read(), v = read();
        E[u].push_back(v);
        E[v].push_back(u);
    }
    dfs(1, 0);
    int root = findG(1, 0, siz[1]).second;
    solve(root);
    cout << ans << "\n";
    ans = 0;
}

signed main() {
    int T = read();
    while(T--) solve();
    return 0;
}

方法二:树上启发式合并

(内容待更新)

posted @ 2025-10-21 19:03  Wy_x  阅读(19)  评论(0)    收藏  举报