[USACO10MAR] 伟大的奶牛聚集

题目类型:树形\(dp\)

传送门:>Here<

题意:给出一棵有边权树,每个节点有\(c[i]\)个人。现在要求所有人聚集到一个点去,代价为每个人走的距离之和。问选哪个点?

解题思路

暴力做法:枚举聚集点,再\(O(n)\)计算每个点到它的距离,还得用\(lca\)求,复杂度\(O(n^2logn)\)

暴力做法2:我们考虑\(O(n)\)维护一个数组\(t[i]\),表示节点\(i\)的子树内所有人到\(i\)的路程之和。易知根节点的\(t\)值就是聚集到根节点时的答案。转\(n\)次根重新遍历,打擂,复杂度\(O(n^2)\)

其实正解就是对暴力做法的一个改进。暴力做法之所以慢,是因为每到一个新的点\(t\)都要重新计算。没有充分利用历史信息。

我们发现对于所有根节点的子节点,不过是子树外的节点多走了这一条边,子树内的节点少走了这一条边。因此就可以完成\(O(1)\)转移了。$$dp[v] = dp[u] + (TotSize - size[v]) * cost[i] - size[v] * cost[i];$$

反思

这题不太像普通的树形\(dp\),一般的树形\(dp\)根节点的值都由子树转移来。这道题却让子树的值由根节点转移。逆向思维。

我做这道题的盲点在于我一直在考虑\(dp[i]\)表示所有节点到子树\(i\)内的一个节点的最小值。事实上子树内这个概念搞得非常玄也非常难搞,干脆定在\(i\)上有时候是一种更好的思路。如果我能够想到直接定在\(i\)上,也就不难想出转移了。

Code

inf开大,自己很快就调出来了。调试嘛,可能的,老犯的错误也就那么几种。

/*By DennyQi 2018*/
#include <cstdio>
#include <queue>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
#define int ll
const int MAXN = 100010;
const int MAXM = 200010;
const int INF = 106110956700000;
inline int Max(const int a, const int b){ return (a > b) ? a : b; }
inline int Min(const int a, const int b){ return (a < b) ? a : b; }
inline int read(){
    int x = 0; int w = 1; register char c = getchar();
    for(; c ^ '-' && (c < '0' || c > '9'); c = getchar());
    if(c == '-') w = -1, c = getchar();
    for(; c >= '0' && c <= '9'; c = getchar()) x = (x<<3) + (x<<1) + c - '0'; return x * w;
}
int N,x,y,z,TotSize;
int c[MAXN],t[MAXN],size[MAXN],dp[MAXN];
int first[MAXN],nxt[MAXM],to[MAXM],cost[MAXM],cnt;
inline void add(int u ,int v, int w){
	to[++cnt] = v, cost[cnt] = w, nxt[cnt] = first[u], first[u] = cnt;
}
void Dfs(int u, int Fa){
	int v;
	size[u] = c[u];
	for(int i = first[u]; i; i = nxt[i]){
		if((v = to[i]) == Fa) continue;
		Dfs(v, u);
		size[u] += size[v];
		t[u] += t[v] + size[v] * cost[i];
	}
}
void Dp(int u, int Fa){
	int v;
	for(int i = first[u]; i; i = nxt[i]){
		if((v = to[i]) == Fa) continue;
		dp[v] = dp[u] + (TotSize - size[v]) * cost[i] - size[v] * cost[i];
		Dp(v, u);
	}
}
#undef int
int main(){
#define int ll
	N = read();
	for(int i = 1; i <= N; ++i){
		c[i] = read();
		TotSize += c[i];
	}
	for(int i = 1; i < N; ++i){
		x = read(), y = read(), z = read();
		add(x, y, z);
		add(y, x, z);
	}
	Dfs(1, 0);
	dp[1] = t[1];
	Dp(1, 0);
	int ans(INF);
	for(int i=  1; i <= N; ++i){
		ans = Min(ans, dp[i]);
	}
	printf("%lld", ans);
	return 0;
}
posted @ 2018-10-26 11:30  DennyQi  阅读(189)  评论(0编辑  收藏  举报