Boruvka 学习笔记

生成树的另类算法

Boruvka 算法

神秘资料

考虑 \(n\) 个点,对于每个点找出最短的边,若长度相同则选编号小的,

假设这条边从 \(u\) 连向了 \(v\),则将 \(u\)\(v\) 合并。

进行完一轮合并后,原图会形成一些连通块。

接下来对于每个连通块,找到一条连向别的连通块的最短的边,同样若长度相同则选编号小的,

若一个连通块选出了从 \(u\)\(v\) 的边,则将 \(u\)\(v\) 所在的连通块合并。

重复做这个过程,直到所有点连通。

例题

CF888G - Xor-MST

给定 \(n\) 个结点的无向完全图。每个点有一个点权为 \(a_i\)

连接 \(i\) 号结点和 \(j\) 号结点的边的边权为 \(a_i \oplus a_j\)

求这个图的 MST 的权值。

\(1 \le n \le2\times 10^5\)\(0 \le a_i <2^{30}\)

考虑 Boruvka 算法,考虑每次怎么找到最短的边,考虑 01trie。

对于某一个连通块中的某个点 \(u\),需要找到不在这个连通块里的 \(v\),使得 \(a_u \oplus a_v\) 最小,

那么就需要除了这个连通块的 01trie,这个可以通过维护整体的 01trie 和每个连通块的 01trie,

两颗 01trie 相减即可得到除了这个连通块的 01trie。

合并的时候进行类似线段树合并操作即可。

时间复杂度 \(O(n \log n \log V)\)

const int N = 2e5 + 5;
const int M = N * 50;
const int INF = 2e9 + 7;
int n, a[N];
int tr[M][2], idx;
int sz[M], b[M];
int rt[N], w[N], p[N], f[N];
int find(int u) {
	if(f[u] == u) return u;
	return f[u] = find(f[u]);
}
void insert(int & u, int d, int i, int x) {
	if(! u) u = ++ idx;
	sz[u]++;
	if(d < 0) {
		b[u] = i;
		return;
	}
	int v = x >> d & 1;
	insert(tr[u][v], d - 1, i, x);
}
PII query(int u, int v, int x) {
	int res = 0;
	ROF(i, 30, 0) {
		int c = x >> i & 1;
		if(sz[tr[u][c]] - sz[tr[v][c]] > 0) {
			u = tr[u][c];
			v = tr[v][c];
		}
		else {
			u = tr[u][c ^ 1];
			v = tr[v][c ^ 1];
			res |= (1 << i);
		}
	}
	return {res, b[u]};
}
void merge(int & u, int v) {
	if(! u || ! v) {
		u = u | v;
		return;
	}
	merge(tr[u][0], tr[v][0]);
	merge(tr[u][1], tr[v][1]);
	sz[u] = sz[tr[u][0]] + sz[tr[u][1]];
}
void solve() {
	cin >> n;
	FOR(i, 1, n) cin >> a[i];
	sort(a + 1, a + n + 1);
	n = unique(a + 1, a + n + 1) - a - 1;
	FOR(i, 1, n) insert(rt[0], 30, i, a[i]);
	FOR(i, 1, n) insert(rt[i], 30, i, a[i]);
	FOR(i, 1, n) f[i] = i;
	ll ans = 0;
	while(1) {
		int fl = 0;
		FOR(i, 1, n) w[i] = INF;
		FOR(i, 1, n) {
			int u = find(i);
			auto h = query(rt[0], rt[u], a[i]);
			int v = find(SE(h));
			if(u == v) continue;
			if(FI(h) < w[u]) w[u] = FI(h), p[u] = v;
			if(FI(h) < w[v]) w[v] = FI(h), p[v] = u;
		}
		FOR(i, 1, n) {
			if(w[i] < INF && find(i) != find(p[i])) {
				ans += w[i];
				fl = 1;
				merge(rt[find(i)], rt[find(p[i])]);
				f[find(p[i])] = find(i);
			} 
		}
		if(!fl) break;
	}
	cout << ans << endl;
}

[COTS 2016] 建造费 Pristojba

有一张 \(n\) 个点的简单无向图 \(G\)

给定数列 \(p\),边 \((i,j)\)\(i\neq j\))的边权为 \(p_i+p_j\)

然而,不是所有 \(i,j\) 间都有边连接。给定 \(m\) 个三元组形如 \(x,l,r\),表示 \(\forall l\le y\le r\)\(x,y\) 间有边连接。

求出这张无向图的最小生成树的边权和。

考虑 Boruvka 算法。

考虑用数据结构找到最短的边,但是一个点 \(u\) 找到的最短边可能会连向相同连通块的 \(v\)

为了规避这种问题,考虑维护最小次小,满足最小次小所在的连通块不同,

这样如果最小值与 \(u\) 在一个连通块里,就直接选次小的即可。

接下来就使用线段树和 ST 表维护一下即可。

时间复杂度 \(O(n \log^2 n)\)

const int N = 1e5 + 5;
const int INF = 1e9 + 7;
int n, m, a[N];
vector<PII> e[N];
int f[N]; PII p[N];
int find(int u) {
	if(f[u] == u) return u;
	return f[u] = find(f[u]);
}
struct Node {
	PII fi, se;
	Node(PII h = {INF, 0}) {
		fi = h; se = {INF, 0};
	}
	void push(PII h) {
		if(h <= fi) {
			if(SE(fi) != SE(h)) se = fi;
			fi = h;
		}
		else if(h <= se && SE(fi) != SE(h)) {
			se = h;
		}
	}
	friend Node operator + (Node A, Node B) {
		A.push(B.fi); A.push(B.se);
		return A;
	}
};
struct SgT {
	int le[N << 2], ri[N << 2];
	Node F[N << 2];
	void build(int u, int l, int r) {
		le[u] = l, ri[u] = r;
		F[u] = Node();
		if(l == r) {
			return;
		}
		int mid = l + r >> 1;
		build(u << 1, l, mid);
		build(u << 1 | 1, mid + 1, r);
	}
	void modify(int u, int l, int r, PII h) {
		if(l <= le[u] && ri[u] <= r) {
			F[u].push(h);
			return;
		}
		int mid = le[u] + ri[u] >> 1;
		if(l <= mid) modify(u << 1, l, r, h);
		if(mid < r) modify(u << 1 | 1, l, r, h);
	}
	Node query(int u, int p) {
		if(le[u] == ri[u]) {
			return F[u];
		}
		int mid = le[u] + ri[u] >> 1;
		if(p <= mid) return F[u] + query(u << 1, p);
		else return F[u] + query(u << 1 | 1, p);
	}
} t;
int lg[N];
Node F[N][19];
void build() {
	FOR(i, 1, n) F[i][0] = Node({a[i], find(i)});
	FOR(j, 1, lg[n]) FOR(i, 1, n - (1 << j) + 1)
		F[i][j] = F[i][j - 1] + F[i + (1 << j - 1)][j - 1];
}
Node query(int l, int r) {
	int len = lg[r - l + 1];
	return F[l][len] + F[r - (1 << len) + 1][len];
}
void solve() {
	cin >> n >> m;
	FOR(i, 1, n) cin >> a[i];
	REP(_, m) {
		int u, l, r;
		cin >> u >> l >> r;
		e[u].push_back({l, r});
	}
	FOR(i, 1, n) f[i] = i;
	FOR(i, 2, n) lg[i] = lg[i >> 1] + 1;
	int cnt = n; ll ans = 0;
	while(cnt > 1) {
		FOR(i, 1, n) p[i] = {INF, 0};
		build();
		t.build(1, 1, n);
		FOR(i, 1, n) for(auto h : e[i]) {
			auto w = query(FI(h), SE(h));
			FI(w.fi) += a[i]; FI(w.se) += a[i];
			if(SE(w.fi) == find(i)) chmin(p[find(i)], w.se);
			else chmin(p[find(i)], w.fi);
			t.modify(1, FI(h), SE(h), {a[i], find(i)});
		}
		FOR(i, 1, n) {
			auto w = t.query(1, i);
			FI(w.fi) += a[i]; FI(w.se) += a[i];
			if(SE(w.fi) == find(i)) chmin(p[find(i)], w.se);
			else chmin(p[find(i)], w.fi);
		}
		FOR(i, 1, n) if(SE(p[i])) {
			int v = SE(p[i]);
			if(find(i) == find(v)) continue;
			f[find(i)] = find(v);
			ans += FI(p[i]);
			cnt --;
		}
	}
	cout << ans << endl;
}

P6362 平面欧几里得最小生成树

平面上有 \(n\) 个点,第 \(i\) 个点坐标为 \((x_i, y_i)\)。连接 \(i, j\) 两点的边权为 \(\sqrt{(x_i - x_j) ^ 2 + (y_i - y_j) ^ 2}\)。求最小生成树的边权之和。

\(n,|x|,|y| \le 10^5\)

注意到是完全图 MST,又不想学神秘科技,所以考虑 Boruvka 算法。

那么现在就需要做到对于一个在 \(c\) 集合里的点 \(u\),找出不在 \(c\) 集合里距离 \(u\) 最近的点 \(v\),也就是近邻。

这个东西可以考虑使用 KDT 解决。

注意,这里不能使用找出最小次小的方法来获取非 \(c\) 集合的点 \(v\),这样复杂度是完全错的。

所以每次对于一个集合 \(c\),直接删掉 \(c\) 集合中的所有点,然后对于每个点找近邻即可。

注意到 Boruvka 对于一个集合只需要找出一条连到别的集合最短的边,所以对于同一个集合,直接延用一个 \(ans\),在 KDT 上遍历的时候看能不能使 \(ans\) 变得更优即可。

接下来是对于复杂度的一些讨论。

首先 Boruvka 遍历 \(O(\log n)\) 次,KDT 上删除添加是 \(O(\log n)\),所以打底 \(O(n \log^2 n)\)

然后是 KDT 找近邻的部分,但是有人说是 \(O(n^2)\) 的,也有人说是 \(O(n\sqrt{n})\) 的。

但是这玩意很难卡,其中一个原因是因为题目中的值域很小且都是整数。还有一点就是 Boruvka 会使点合并,即使第一轮卡很满,后面依然会跑很快,所以不是很能卡掉,再加个随机旋转就基本卡不掉了。

const int N = 1e5 + 5;
const int INF = 1e8 + 7;
const ll LNF = 1e18;
int n, cmp;
struct Point {
	double x[2];
	bool operator < (const Point A) const {
		return x[cmp] < A.x[cmp];
	}
} a[N];
struct KDT {
	double L[2], R[2];
	int ls, rs, p, b;
} t[N]; int rt, tot;
struct Node {
	double w; int c;
	bool operator > (Node & A) const {
		if(w == A.w) return c > A.c;
		return w > A.w;
	}
} p[N];
int f[N];
vector<int> e[N];
inline double sq(double x) {
	return x * x;
}
inline double dist(Point A, Point B) {
	return sq(A.x[0] - B.x[0]) + sq(A.x[1] - B.x[1]);
}
inline int find(int u) {
	if(f[u] == u) return u;
	return f[u] = find(f[u]);
}
inline void pushup(int u) {
	int l = t[u].ls, r = t[u].rs;
	REP(i, 2) {
		if(t[u].b) t[u].L[i] = t[u].R[i] = a[t[u].p].x[i];
		else t[u].L[i] = INF, t[u].R[i] = - INF;
		if(l) {
			chmin(t[u].L[i], t[l].L[i]);
			chmax(t[u].R[i], t[l].R[i]);
		}
		if(r) {
			chmin(t[u].L[i], t[r].L[i]);
			chmax(t[u].R[i], t[r].R[i]);
		}
	}
}
int build(int l, int r, int o) {
	if(l > r) return 0;
	int mid = l + r >> 1;
	int u = ++ tot;
	cmp = o;
	nth_element(a + l, a + mid, a + r + 1);
	t[u].b = 1;
	t[u].p = mid;
	t[u].ls = build(l, mid - 1, o ^ 1);
	t[u].rs = build(mid + 1, r, o ^ 1);
	pushup(u);
	return u;
}
void query(int u, int v) {
	if(! u) return;
	int c = find(v);
	if(t[u].b) chmin(p[c], (Node){dist(a[v], a[t[u].p]), find(t[u].p)});
	int o0 = t[u].L[0] <= a[v].x[0] && a[v].x[0] <= t[u].R[0];
	int o1 = t[u].L[1] <= a[v].x[1] && a[v].x[1] <= t[u].R[1];
	if(! o0 && ! o1) {
		double val = LNF;
		chmin(val, dist(a[v], {t[u].L[0], t[u].L[1]}));
		chmin(val, dist(a[v], {t[u].L[0], t[u].R[1]}));
		chmin(val, dist(a[v], {t[u].R[0], t[u].L[1]}));
		chmin(val, dist(a[v], {t[u].R[0], t[u].R[1]}));
		if(val >= p[c].w) return;
	}
	if(o0 && ! o1) {
		double val = LNF;
		chmin(val, sq(a[v].x[1] - t[u].L[1]));
		chmin(val, sq(a[v].x[1] - t[u].R[1]));
		if(val >= p[c].w) return;
	}
	if(! o0 && o1) {
		double val = LNF;
		chmin(val, sq(a[v].x[0] - t[u].L[0]));
		chmin(val, sq(a[v].x[0] - t[u].R[0]));
		if(val >= p[c].w) return;
	}
	query(t[u].ls, v);
	query(t[u].rs, v);
}
void insert(int u, int l, int r, int q, int x) {
	int mid = l + r >> 1;
	if(mid == q) {
		t[u].b = x;
		pushup(u);
		return;
	}
	if(q < mid) insert(t[u].ls, l, mid - 1, q, x);
	else insert(t[u].rs, mid + 1, r, q, x);
	pushup(u);
}
void solve() {
	cin >> n;
	double alpha = 1.14;
	FOR(i, 1, n) {
		double X, Y;
		cin >> X >> Y;
		a[i].x[0] = X * cos(alpha) - Y * sin(alpha);
		a[i].x[1] = X * sin(alpha) + Y * cos(alpha);
	}
	rt = build(1, n, 0);
	FOR(i, 1, n) f[i] = i;
	int cnt = n;
	double ans = 0;
	while(cnt > 1) {
		FOR(i, 1, n) p[i] = {LNF, 0};
		FOR(i, 1, n) e[i].clear();
		FOR(i, 1, n) e[find(i)].push_back(i);
		FOR(i, 1, n) if(! e[i].empty()) {
			FORV(x, e[i]) insert(rt, 1, n, * x, 0);
			FORV(x, e[i]) query(rt, * x);
			FORV(x, e[i]) insert(rt, 1, n, * x, 1);
		}
		FOR(i, 1, n) if(p[i].c) {
			int u = find(i);
			int v = find(p[i].c);
			if(u == v) continue;
			f[v] = u;
			cnt --;
			ans += sqrt(p[i].w);
		}
	}
	cout << fixed << setprecision(10);
	cout << ans << endl;
}
posted @ 2025-06-25 19:44  KevinLikesCoding  阅读(136)  评论(0)    收藏  举报