BZOJ4182. Shopping
【题意】
给定一个树,每个点是一个商店,有di个物品,体积为ci,价值为wi,现有m元,求树上的一条链,使得链上的点都买了至少一个,使得总价值最大
【分析】
首先不考虑树上的链,那么问题就是转换为了普通的多重背包问题
现在考虑如何计算所有链的情况,显然直接枚举的总时间复杂度为$O(n^3m)$无法接受
一般这种树上的路径/点对的问题都可以考虑使用点分治来解决
我们考虑每次只计算子树到当前分治中心这一段的价值,然后考虑合并即可
也就是每次把重心作为必选的节点
这样做的时间复杂度就优化为了O(nmlogn)
【代码】
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn=505; const int maxm=5005; int n,m,head[maxn],tot; int w[maxn],c[maxn],d[maxn]; struct edge { int to,nxt; }e[maxn<<1]; void add(int x,int y) { e[++tot].to=y; e[tot].nxt=head[x]; head[x]=tot; } int root,size,gsiz,siz[maxn],vis[maxn]; void findrt(int u,int fa) { int maxsiz=0; siz[u]=1; for(int i=head[u];i;i=e[i].nxt) { int to=e[i].to; if(to==fa || vis[to]) continue; findrt(to,u); siz[u]+=siz[to]; maxsiz=max(maxsiz,siz[to]); } maxsiz=max(maxsiz,size-siz[u]); if(maxsiz<gsiz) { gsiz=maxsiz; root=u; } } int dfn[maxn],dfstime,rev[maxn]; void dfs(int u,int fa=0) { siz[u]=1; for(int i=head[u];i;i=e[i].nxt) { int to=e[i].to; if(to==fa || vis[to]) continue; dfs(to,u); siz[u]+=siz[to]; } dfn[u]=++dfstime; rev[dfstime]=u; } int f[maxn][maxm],pos[maxm],q[maxm]; int ans; void work(int u) { dfstime=0; dfs(u); for(int i=1;i<=dfstime;i++) { int x=rev[i],head,tail; for(int j=0;j<c[x];j++) { head=1,tail=0; int gs=(m-j)/c[x]; for(int k=0;k<=gs;k++) { int v=k*c[x]+j,y=f[i-1][v]-k*w[x]; while(head<=tail && pos[head]<k-d[x]) head++; if(head<=tail) f[i][v]=max(q[head]+k*w[x],f[i-siz[x]][v]); else f[i][v]=f[i-siz[x]][v]; while(head<=tail && y>=q[tail]) tail--; q[++tail]=y; pos[tail]=k; } } } ans=max(ans,f[dfstime][m]); } void solve(int u) { vis[u]=1; work(u); for(int i=head[u];i;i=e[i].nxt) { int to=e[i].to; if(vis[to]) continue; gsiz=size=siz[to]; root=u; findrt(to,u); solve(root); } } int main() { freopen("a.in","r",stdin); freopen("a.out","w",stdout); int T;scanf("%d",&T); while(T--) { tot=ans=0; memset(head,0,sizeof(head)); scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) scanf("%d",&w[i]); for(int i=1;i<=n;i++) scanf("%d",&c[i]); for(int i=1;i<=n;i++) scanf("%d",&d[i]); int x,y; for(int i=1;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } // return 0; memset(vis,0,sizeof(vis)); size=gsiz=n; root=0; findrt(1,0); // return 0; solve(root); printf("%d\n",ans); } return 0; }