P5521 题解

一道比较不错的思维题。

对于树上的每一个节点,我们考虑设节点 \(i\) 要放 \(w_i\) 朵梅花,如果从梅岭带出 \(ans_i\) 朵梅花,就在节点 \(i\) 上放 \(w_i\) 朵梅花。

具体地,有以下两种情况:

第一种情况,梅花直接放所有子节点再放父节点。则易知 \(w_i+\sum w_j\)

第二种情况,如果已经在节点 \(j\) 放了梅花,显然都会将其子节点的梅花收走。那么放一个子节点的花费为:

\[\displaystyle\sum_{j=1}^{k-1}w_i+ans_k \]

因为要放所有子节点,所以取其最大值。则有:

\[ans_i=\max\left\{w_i+\sum w_j,\max_{k=1\rightarrow n}\left\{\sum_{j=1}^{k-1}w_j+ans_k\right\}\right\} \]

注意到 \(ans_i\) 与子节点顺序有联系,那么就需要排序。

这里给出结论:直接按 \(ans_i-w_i\) 从大到小排序。

证明:

设两个节点 \(i\)\(j\) 相邻,且交换前的值是 \(\max\left\{ans_j+w_i,ans_i\right\}\),交换后的值是 \(\max\left\{ans_i+w_j,ans_j\right\}\)

若满足条件 \(ans_i-w_i>ans_j-w_j\),则有 \(ans_j+w_i<ans_i+w_j\)

因为 \(w\in\N^+\),所以有:

\[\begin{cases} ans_j+w_i>ans_j \\ ans_i+w_j>ans_i \\ ans_j+w_i>ans_i \\ \end{cases} \]

显然原顺序优于交换顺序。

证毕。

至于时间复杂度的话,设树上每一个节点与 \(m\) 个节点相邻,则总的时间复杂度为 \(O(nm\log n)\),可以通过本题(或许可以钦定为 \(O(n)\)?)。

#include <iostream>
#include <algorithm>
#include <vector>
#define MAXN 100005
using namespace std;
int n, p;
struct edge{int w, to, nxt;}e[MAXN << 1];
int head[MAXN], cnt = 1;
struct node{
	int ans, w;
	bool friend operator<(node a, node b){
		return a.ans - a.w > b.ans - b.w;
	};
}a[MAXN];
int read(){
	int t = 1, x = 0;char ch = getchar();
	while(!isdigit(ch)){if(ch == '-')t = -1;ch = getchar();}
	while(isdigit(ch)){x = (x << 1) + (x << 3) + (ch ^ 48);ch = getchar();}
	return x * t;
}
void write(int x){
	if(x < 0){putchar('-');x = -x;}
	if(x >= 10)write(x / 10);
	putchar(x % 10 ^ 48);
}
void add(int u, int v, int w){
	cnt++;e[cnt].to = v;e[cnt].nxt = head[u];head[u] = cnt;
}
void dfs(int now){
	vector <node> tmp;tmp.clear();a[now].ans = a[now].w;
	for(int i = head[now] ; i != 0 ; i = e[i].nxt){
		int v = e[i].to;dfs(v);
		tmp.push_back(a[v]);
		a[now].ans += a[v].w;
	}
	sort(tmp.begin(), tmp.end());int tot = 0;
	for(int i = 0 ; i < tmp.size() ; i ++){
		a[now].ans = max(a[now].ans, tmp[i].ans + tot);
		tot += tmp[i].w;
	}
}
int main(){
	n = read();
	for(int i = 2 ; i <= n ; i ++)p = read(),add(p, i, 0);
	for(int i = 1 ; i <= n ; i ++)a[i].w = read();
	dfs(1);
	for(int i = 1 ; i <= n ; i ++){
		if(i != 1)putchar(' ');
		write(a[i].ans);
	}
 	putchar('\n');return 0;
}
posted @ 2023-11-09 16:46  tsqtsqtsq  阅读(50)  评论(0)    收藏  举报