[学习笔记]dsu on tree

这种不怎么难写的东西,我学得快忘得也快,也是给自己加深印象,同时留个自己(大概)能看懂的讲解好复习……qwq

先说是什么

dsu on tree中的dsu就是Disjoint Set Union,虽然整个算法跟并茶几(话说并茶几名字好多啊……)没有任何关系……硬要说就是借用了启发式合并的思想吧……

这个算法是拿来解决树上对子树内答案的询问的,当然它并不支持修改

它在暴力的基础上,借助轻重链剖分的性质把复杂度降低到了\(O(n \log n)\)

大致过程

遍历每一个节点,先递归解决轻儿子,完成后消除递归产生的影响

然后解决重儿子,但不消除影响

将轻儿子的答案合并上来

消除上一个过程中轻儿子产生的影响

拿一个例题来说:CF600E

题意:树上每个节点有一个颜色,求每棵子树中出现次数最多的颜色(可能有多个)之和

首先是轻重链剖分,处理出每个节点的重儿子

void dfs(int u, int fa) {
	size[u] = 1;
	for (int i = G.head[u]; ~i; i = G[i].next) {
		int v = G[i].v;
		if (v == fa) continue;
		dfs(v, u);
		size[u] += size[v];
		if (!heavy[u] || size[v] > size[heavy[u]]) heavy[u] = v;
	}
}

然后遍历节点,按上面的流程来(具体看注释)

void update(int u, int fa, int val, const int &hvy) {//val为1暴力合并统计轻儿子的答案,为-1清除对cnt的影响
	cnt[col[u]] += val;
	if (val > 0 && cnt[col[u]] >= max_cnt) {
		if (cnt[col[u]] > max_cnt) sum = 0, max_cnt = cnt[col[u]];
		sum += (LL)col[u];
	}
	for (int i = G.head[u]; ~i; i = G[i].next) {
		int v = G[i].v;
		if (v == fa || v == hvy) continue;
		update(v, u, val, hvy);
	}
}
void dfs(int u, int fa, int opt) {//opt为0表示需要清除掉u的影响,为1表示不需要
	for (int i = G.head[u]; ~i; i = G[i].next) {
		int v = G[i].v;
		if (v == fa || v == heavy[u]) continue;
		dfs(v, u, 0);//递归解决轻儿子,完成后清除影响
	}
	if (heavy[u]) dfs(heavy[u], u, 1);//解决重儿子,保留影响
	update(u, fa, 1, heavy[u]);//合并轻儿子的答案
	ans[u] = sum;
	if (!opt) update(u, fa, -1, 0), sum = 0, max_cnt = 0;//清除影响
}

最后是完整代码:

PS.怎么网上的博客代码个个不一样啊……,蒟蒻我懵逼了好长时间才看明白qwq

#include <cstdio>
#include <cstring>
#include <iostream>
#define MAXN 100005

typedef long long LL;
struct Graph {
	struct Edge {
		int v, next;
		Edge(int _v = 0, int _n = 0):v(_v), next(_n) {}
	} edge[MAXN << 1];
	int head[MAXN], cnt;
	void init() { memset(head, -1, sizeof head); cnt = 0; }
	void add_edge(int u, int v) { edge[cnt] = Edge(v, head[u]); head[u] = cnt++; }
	void insert(int u, int v) { add_edge(u, v); add_edge(v, u); }
	Edge & operator [](int x) { return edge[x]; }
} G;
int col[MAXN], size[MAXN], heavy[MAXN], val[MAXN], cnt[MAXN], N;
LL sum, max_cnt, ans[MAXN];

void dfs(int, int);
void dfs(int, int, int);
void update(int, int, int, const int &);
int main() {
	G.init();
	scanf("%d", &N);
	for (int i = 1; i <= N; ++i) scanf("%d", col + i);
	for (int i = 1; i < N; ++i) {
		int u, v;
		scanf("%d%d", &u, &v);
		G.insert(u, v);
	}
	dfs(1, 0);
	dfs(1, 0, 0);
	for (int i = 1; i <= N; ++i) printf("%I64d ", ans[i]);

	return 0;
}
void dfs(int u, int fa) {
	size[u] = 1;
	for (int i = G.head[u]; ~i; i = G[i].next) {
		int v = G[i].v;
		if (v == fa) continue;
		dfs(v, u);
		size[u] += size[v];
		if (!heavy[u] || size[v] > size[heavy[u]]) heavy[u] = v;
	}
}
void update(int u, int fa, int val, const int &hvy) {
	cnt[col[u]] += val;
	if (val > 0 && cnt[col[u]] >= max_cnt) {
		if (cnt[col[u]] > max_cnt) sum = 0, max_cnt = cnt[col[u]];
		sum += (LL)col[u];
	}
	for (int i = G.head[u]; ~i; i = G[i].next) {
		int v = G[i].v;
		if (v == fa || v == hvy) continue;
		update(v, u, val, hvy);
	}
}
void dfs(int u, int fa, int opt) {
	for (int i = G.head[u]; ~i; i = G[i].next) {
		int v = G[i].v;
		if (v == fa || v == heavy[u]) continue;
		dfs(v, u, 0);
	}
	if (heavy[u]) dfs(heavy[u], u, 1);
	update(u, fa, 1, heavy[u]);
	ans[u] = sum;
	if (!opt) update(u, fa, -1, 0), sum = 0, max_cnt = 0;
}
//Rhein_E

最后是复杂度证明

轻重链剖分保证了每个节点到根的路径上轻边条数不超过\(\log n\)

每个节点被访问一次,要么是它的祖先节点暴力统计轻儿子/消除影响的时候,要么是它自己统计答案的时候

前者\(O(\log n)\)次,后者\(1\)

所以总复杂度是\(O(n \log n)\)

posted @ 2019-03-10 11:28  Rhein_E  阅读(996)  评论(0编辑  收藏  举报