一些关于树上背包时间复杂度的问题
你正在做一道关于有根树的题目,你目前的算法的时间复杂度为每一个点的所有儿子的子树大小乘积之和,即 \(\sum(\prod_{v \in son_u} sz_v)\)。看上去似乎是 \(n^3\) 的,但是考虑每对点只有可能在它们的 lca 处对总和贡献 \(1\),所以总复杂度其实是 \(n^2\) 的。
你很快切掉了那题,打开了另一题:P2014 选课。你发现这是一道树上背包的模板题,让你求包含根节点,大小为 \(m\) 的连通块包含的最大点权和,你很快看出了 \(nm^2\) 的朴素做法,通过了此题。你点开题解,发现存在 \(nm\) 的做法,那是一种使用 \(dfn\) 的性质来将所有点的 DP 数组揉到一起计算的方法,减少了对不存在状态的枚举。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=307;
int n,m,s[N],lp[N],rp[N],idx,id[N];
vector<int> g[N];
int dp[N][N];
void dfs(int u){
lp[u]=idx;
for(int v:g[u]) dfs(v);
rp[u]=++idx; id[idx]=u;
}
int main(){
cin>>n>>m;
for(int i=1;i<=n;i++){
int p; cin>>p>>s[i];
g[p].push_back(i);
}
dfs(0);
for(int i=1;i<=n+1;i++){
int x=id[i];
for(int j=0;j<=m;j++){
dp[i][j]=dp[i-rp[x]+lp[x]][j];
if(j) dp[i][j]=max(dp[i][j],dp[i-1][j-1]+s[x]);
}
}
cout<<dp[n][m]<<"\n";
return 0;
}
你很快发现这种做法并不具有强普适性。你在思考是否可以直接在朴素代码上进行修改,使得时间降低至 \(nm\)。
你优化出了如下算法:
//to kill a living book
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 307;
int f[N][N], n, m, a[N], sz[N];
vector<int> g[N];
void dfs(int u){
f[u][1] = a[u];
sz[u] = 1;
for(int v: g[u]){
dfs(v);
for(int i = min(m, sz[u]); i >= 1; i --){
for(int j = min(m - i, sz[v]); j >= 1; j --){
f[u][i + j] = max(f[u][i + j], f[u][i] + f[v][j]);
}
}
sz[u] += sz[v];
}
}
signed main(){
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
cin >> n >> m; m ++;
for(int i = 1; i <= n; i ++){
int fa; cin >> fa >> a[i];
g[fa].push_back(i);
}
dfs(0), cout << f[0][m] << "\n";
return 0;
}
注意转移部分:
for(int i = min(m, sz[u]); i >= 1; i --){
for(int j = min(m - i, sz[v]); j >= 1; j --){
f[u][i + j] = max(f[u][i + j], f[u][i] + f[v][j]);
}
}
你发现,此时复杂度其实是代码中的 \(\sum(\min(sz_u,m) \times \min(sz_v,m))\),其中 \(sz_u\) 代表此时 \(v\) 左边的所有子树大小之和。
同样地,你考虑看下两个点 \(u\),\(v\) 何时会产生 \(1\) 的贡献,你画了个图。

考虑到 \(pre\) 区域中的 \(dfn\) 一定会全部小于 \(sz\) 区域中的 \(dfn\)(中序遍历),你考虑钦定 \(pre\) 中 \(dfn\) 比 \(u\) 大的和 \(sz\) 中 \(dfn\) 比 \(v\) 小的所有点作为 \(\min\) 的限制,换句话说,如果 \(pre\) 区域中 \(dfn\) 不小于 \(u\) 的点数不多于 \(m\) 个,且 \(sz\) 区域中 \(dfn\) 不大于 \(v\) 的点数不多于 \(m\) 个,那么 \(u\),\(v\) 便会对答案产生 \(1\) 的贡献。
注意力惊人地,你发现这其实相当于把每一个点按 \(dfn\) 重编号,然后求两个编号差不超过 \(\mathcal{O}(m)\) 的数对的个数,这显然是 \(O(nm)\) 的。
练习题:AT_abc416_f。
本文来自博客园,作者:GE9x,转载请注明原文链接:https://www.cnblogs.com/GE9X/p/19710449

浙公网安备 33010602011771号