HDU-6035 Colorful Tree(树形DP) 2017多校第一场
题意:给出一棵树,树上的每个节点都有一个颜色,定义一种值为两点之间路径中不同颜色的个数,然后一棵树有n*(n-1)/2条
路径,求所有的路径的值加起来是多少。
思路:比赛的时候感觉是树形DP,但是脑袋抽了,忘记树形DP是怎么遍历的了(其实没忘也不会做:)
先给出官方题解吧:
单独考虑每一种颜色,答案就是对于每种颜色至少经过一次这种的路径条数之和。反过来思考只需要求有多少条路径没有经过这种颜色即可。直接做可以采用虚树的思想(不用真正建出来),对每种颜色的点按照 dfs 序列排个序,就能求出这些点把原来的树划分成的块的大小。这个过程实际上可以直接一次 dfs 求出。
其实感觉就前面那两句能看懂= =,不过这个也才是主要的。很显然这题是算每个节点的贡献,但是每个节点都算的话就会算重,去重不容易,所以就可以计算
有哪些路径没有经过这种颜色,这个就很容易一点了,所以对于每种颜色就是求他的联通快的大小。
然后感觉还是不容易,可能是树形DP做少了吧,通过这个题还是学到不少的技巧 比如通过访问节点前,把前面的颜色给存起来,这样到时候访问的就是这颗子树的内容了,更新这些内容的时候一定要注意他们之间的关系
/** @xigua */
#include<stdio.h>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#include<stack>
#include<cstring>
#include<queue>
#include<set>
#include<string>
#include<map>
#define PI acos(-1)
using namespace std;
typedef long long ll;
typedef double db;
const int maxn = 2e5 + 5;
const ll maxm = 1e7;
const int mod = 1e9 + 7 + 0.1;
const int INF = 1e9 + 7;
const ll inf = 1e15 + 5;
const db eps = 1e-9;
const int state = 15;
ll ans, col[maxn];
int head[maxn], cnt, n, siz[maxn], vis[maxn], c[maxn];
struct Edge {
int v, next;
} e[maxn<<1];
void add(int u, int v) {
e[cnt].v = v;
e[cnt].next = head[u];
head[u] = cnt++;
}
void init() {
cnt = ans = 0;
memset(head, -1, sizeof(head));
memset(vis, 0, sizeof(vis));
memset(col, 0, sizeof(col));
}
void dfs(int u, int fa) {
siz[u] = 1;
vis[c[u]] = 1;
int pre = col[c[u]]; //访问节点前的
int num = 0; //因为一个节点有好几个儿子,所以就要把前面的给剔除
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].v;
if (fa == v) continue;
dfs(v, u);
siz[u] += siz[v];
ll tmp = siz[v] - (col[c[u]] - pre - num);
ans -= (tmp - 1) * tmp / 2;
num = col[c[u]] - pre; //num就是当前颜色的个数减去以前的就是节点u的前面的儿子的
}
col[c[u]] = pre + siz[u];
}
void solve() {
int cas = 1;
while (cin >> n) {
init();
for (int i = 1; i <= n; i++)
scanf("%d", c + i);
for (int i = 1; i < n; i++) {
int u, v; scanf("%d%d", &u, &v);
add(u, v); add(v, u);
}
dfs(1, -1);
for (int i = 1; i <= n; i++) {
if (vis[i]) {
ans += (ll)n * (n - 1) / 2;
ll tmp = n - col[i];
ans -= tmp * (tmp - 1) / 2;
}
}
printf("Case #%d: %I64d\n", cas++, ans);
}
}
int main() {
int t = 1, cas = 1;
//freopen("in.txt", "r", stdin);
// scanf("%d", &t);
// init();
while(t--) {
// printf("Case #%d:\n", cas++);
solve();
}
return 0;
}

浙公网安备 33010602011771号