[树形dp] CF1280D Miss Punyverse

posted on 2024-04-24 05:38:12 | under | source

显然先令 \(a_i=w_i-b_i\),那么判定 \(\sum a\) 是否是正数即可。

然后反悔贪心?试了下发现不太行。

观察到数据范围开了个标准的 \(n^2\),于是考虑树形 \(\rm dp\)。定义 \(f_{u,i}\) 表示 \(u\) 子树内,分为 \(i\) 个连通块时,合法连通块有几个。

不过还需记录此时最上面连通块的点权和,\(f\) 数组的规模就会到达 \(n^3\),不太行。

但是结合做题经验,我们猜测 \(f_{u,i}\) 的大小只有 \(1\),也就是说可以取到一个绝对最优的值。

试了下,发现确实可行。不过需要改变 \(f_{u,i}\) 定义:分为 \(i\) 个连通块,除最上面的连通块外的合法连通块数量 \(cnt\),并记下最上面的连通块权值和 \(sum\)

可以保证,\(cnt_1>cnt_2\) 时,\(f_1\) 一定不劣于 \(f_2\)。相等时取 \(\max sum\) 即可。

转移很简单(记得卡好枚举的上界),于是这道题就做完了。

但提交前,你可能会注意到 \(N=\sum n\le 100000\) 这一点。不过,因为单组数据 \(O(n^2)\),令所有 \(n\) 相等,总共 \(O(n^2\frac Nn)=O(nN)\),也就是 \(\rm 3e8\) 左右了,可过。

代码

#include<bits/stdc++.h>
using namespace std;

#define LL long long 
#define pir pair<int, LL>
const int N = 3e3 + 5;
int T, n, m, _a, a[N], siz[N], u, v;
pir f[N][N], g[N]; 
vector<int> to[N];

inline pir mer(pir A, pir B) {return A.first == B.first ? pir{A.first, max(A.second, B.second)} : (A.first > B.first ? pir{A.first, A.second} : pir{B.first, B.second});}
inline void dfs(int u, int fa){
	f[u][1] = {0, a[u]}, siz[u] = 1;
	for(int i = 2; i <= m; ++i) f[u][i] = {-1, 0};
	for(auto v : to[u])
		if(v ^ fa){
			dfs(v, u);
			for(int i = 1; i <= min(m, siz[u] + siz[v]); ++i) g[i] = {-1, 0};
			for(int i = 1; i <= min(m, siz[u]); ++i)
				for(int j = 1; j <= siz[v] && i + j - 1 <= m; ++j){
					//不合并 
					if(i + j <= m)
						g[i + j] = mer(g[i + j], {f[u][i].first + f[v][j].first + (f[v][j].second > 0), f[u][i].second});
					//合并 
					if(i + j - 1 <= m)
						g[i + j - 1] = mer(g[i + j - 1], {f[u][i].first + f[v][j].first, f[u][i].second + f[v][j].second}); 
				}
			siz[u] += siz[v];
			for(int i = 1; i <= min(m, siz[u] + siz[v]); ++i) f[u][i] = g[i];
		}
}
signed main(){
	cin >> T;
	while(T--){
		scanf("%d%d", &n, &m);
		for(int i = 1; i <= n; ++i) scanf("%d", &a[i]), to[i].clear();
		for(int i = 1; i <= n; ++i) scanf("%d", &_a), a[i] = _a - a[i];
		for(int i = 1; i < n; ++i) scanf("%d%d", &u, &v), to[u].push_back(v), to[v].push_back(u);
		dfs(1, 0); 
		printf("%d\n", f[1][m].first + (f[1][m].second > 0));
	}
	return 0;
}
posted @ 2026-01-13 11:16  Zwi  阅读(0)  评论(0)    收藏  举报