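10-1 Disjoint Sets
Disjoint Sets (Union-Find)
A disjoint-set structure (Disjoint Sets), also known as Union-Find, is a data structure that maintains a collection of pairwise disjoint sets. It supports two core operations:
- Find: determine which set an element belongs to (returns the set's representative element)
- Union: merge two sets into one
Both operations look simple, and with two optimizations, path compression and union by rank, the amortized cost of a single operation drops to O(α(n)), where α is the inverse Ackermann function. For any input size that can occur in practice α(n) is at most 4, so it can be treated as a constant.
Disjoint sets are a core building block of Kruskal's minimum spanning tree algorithm in graph theory, and are also widely used for connectivity queries, partitioning into equivalence classes, and incremental (edge-insertion-only) dynamic connectivity problems.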
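Basic concepts
Representing a set
Each set is represented by a rooted tree:
- The root of the tree is the set's representative element
- Every node has a parent pointer that points to its parent node
- The root's parent pointer points to itself
Set {0, 1, 2}:                 Set {3, 4}:
    0 (representative)             3 (representative)
   / \                             |
  1   2                            4
parent[0] = 0 (the root points to itself)
parent[1] = 0
parent[2] = 0
parent[3] = 3
parent[4] = 3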
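To make the parent-array representation concrete, here is a minimal sketch that recovers the sets from the array drawn above. The helper names `find_root` and `groups` are illustrative and not part of the original text.

```python
from collections import defaultdict

def find_root(parent, x):
    """Follow parent pointers until reaching a node that is its own parent."""
    while parent[x] != x:
        x = parent[x]
    return x

def groups(parent):
    """Group every element under the root (representative) of its tree."""
    by_root = defaultdict(list)
    for x in range(len(parent)):
        by_root[find_root(parent, x)].append(x)
    return dict(by_root)

parent = [0, 0, 0, 3, 3]  # the two trees shown above
print(groups(parent))     # {0: [0, 1, 2], 3: [3, 4]}
```

The keys of the result are the representatives, so two elements are in the same set exactly when they appear under the same key.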
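Core operations
MakeSet(x): create a set that contains only x
    parent[x] = x
Find(x): find the representative (root) of the set that contains x
    follow parent pointers upward until parent[x] == x
Union(x, y): merge the two sets that contain x and y
    rootX = Find(x)
    rootY = Find(y)
    parent[rootX] = rootY (hang the root of one tree under the root of the other)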
Naive implementation
The simplest implementation applies no optimization. Take 5 elements as an example:
Initial state (MakeSet):
    parent = [0, 1, 2, 3, 4] (every element is its own singleton set)
Union(0, 1):
    Find(0) = 0, Find(1) = 1
    parent[0] = 1
    parent = [1, 1, 2, 3, 4]
Union(2, 3):
    Find(2) = 2, Find(3) = 3
    parent[2] = 3
    parent = [1, 1, 3, 3, 4]
Union(0, 3):
    Find(0): 0 → 1 (root), so rootX = 1
    Find(3): 3 → 3 (root), so rootY = 3
    parent[1] = 3
    parent = [1, 3, 3, 3, 4]
Resulting sets:
    {0, 1, 2, 3} (representative 3)
    {4} (representative 4)
Find(0): 0 → 1 → 3 (root), returns 3
Find(2): 2 → 3 (root), returns 3
Find(4): 4 → 4 (root), returns 4
C++ implementation
#include <iostream>
#include <vector>
using namespace std;
class DisjointSets {
vector<int> parent;
public:
DisjointSets(int n) : parent(n) {
for (int i = 0; i < n; i++)
parent[i] = i; // MakeSet: each element is its own parent
}
int find(int x) {
while (parent[x] != x)
x = parent[x];
return x;
}
void unionSets(int x, int y) {
int rootX = find(x);
int rootY = find(y);
if (rootX != rootY)
parent[rootX] = rootY;
}
bool connected(int x, int y) {
return find(x) == find(y);
}
};
int main() {
DisjointSets ds(5);
ds.unionSets(0, 1);
ds.unionSets(2, 3);
ds.unionSets(0, 3);
cout << "Find(0) = " << ds.find(0) << endl;
cout << "Find(2) = " << ds.find(2) << endl;
cout << "Find(4) = " << ds.find(4) << endl;
cout << "Connected(0, 2) = " << (ds.connected(0, 2) ? "true" : "false") << endl;
cout << "Connected(0, 4) = " << (ds.connected(0, 4) ? "true" : "false") << endl;
return 0;
}
C implementation
#include <stdio.h>
#include <stdbool.h>
void make_set(int* parent, int n) {
for (int i = 0; i < n; i++)
parent[i] = i;
}
int find(int* parent, int x) {
while (parent[x] != x)
x = parent[x];
return x;
}
void union_sets(int* parent, int x, int y) {
int rootX = find(parent, x);
int rootY = find(parent, y);
if (rootX != rootY)
parent[rootX] = rootY;
}
bool connected(int* parent, int x, int y) {
return find(parent, x) == find(parent, y);
}
int main() {
int parent[5];
make_set(parent, 5);
union_sets(parent, 0, 1);
union_sets(parent, 2, 3);
union_sets(parent, 0, 3);
printf("Find(0) = %d\n", find(parent, 0));
printf("Find(2) = %d\n", find(parent, 2));
printf("Find(4) = %d\n", find(parent, 4));
printf("Connected(0, 2) = %s\n", connected(parent, 0, 2) ? "true" : "false");
printf("Connected(0, 4) = %s\n", connected(parent, 0, 4) ? "true" : "false");
return 0;
}
Python implementation
class DisjointSets:
    def __init__(self, n):
        self.parent = list(range(n))  # MakeSet: each element is its own parent

    def find(self, x):
        # Walk up the parent pointers until the root is reached
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.parent[root_x] = root_y

    def connected(self, x, y):
        return self.find(x) == self.find(y)


ds = DisjointSets(5)
ds.union(0, 1)
ds.union(2, 3)
ds.union(0, 3)
print(f"Find(0) = {ds.find(0)}")
print(f"Find(2) = {ds.find(2)}")
print(f"Find(4) = {ds.find(4)}")
# str(...).lower() prints "true"/"false" so the output matches the other languages
print(f"Connected(0, 2) = {str(ds.connected(0, 2)).lower()}")
print(f"Connected(0, 4) = {str(ds.connected(0, 4)).lower()}")
Go implementation
package main
import "fmt"
type DisjointSets struct {
parent []int
}
func NewDisjointSets(n int) *DisjointSets {
parent := make([]int, n)
for i := range parent {
parent[i] = i
}
return &DisjointSets{parent: parent}
}
func (ds *DisjointSets) Find(x int) int {
for ds.parent[x] != x {
x = ds.parent[x]
}
return x
}
func (ds *DisjointSets) Union(x, y int) {
rootX := ds.Find(x)
rootY := ds.Find(y)
if rootX != rootY {
ds.parent[rootX] = rootY
}
}
func (ds *DisjointSets) Connected(x, y int) bool {
return ds.Find(x) == ds.Find(y)
}
func main() {
ds := NewDisjointSets(5)
ds.Union(0, 1)
ds.Union(2, 3)
ds.Union(0, 3)
fmt.Printf("Find(0) = %d\n", ds.Find(0))
fmt.Printf("Find(2) = %d\n", ds.Find(2))
fmt.Printf("Find(4) = %d\n", ds.Find(4))
fmt.Printf("Connected(0, 2) = %v\n", ds.Connected(0, 2))
fmt.Printf("Connected(0, 4) = %v\n", ds.Connected(0, 4))
}
Running the program prints:
Find(0) = 3
Find(2) = 3
Find(4) = 4
Connected(0, 2) = true
Connected(0, 4) = false
The problem with the naive implementation is that in the worst case the tree degenerates into a linked list, so a single Find takes O(n) time.
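The degeneration is easy to reproduce. The sketch below, using illustrative helper names `naive_find` and `naive_union`, applies the naive rule to the sequence Union(0, 1), Union(1, 2), ..., which builds one long chain, and then counts how many parent pointers a single Find has to follow.

```python
def naive_find(parent, x):
    """Return (root, number of parent pointers followed)."""
    steps = 0
    while parent[x] != x:
        x = parent[x]
        steps += 1
    return x, steps

def naive_union(parent, x, y):
    """Naive union: always hang x's root under y's root."""
    root_x, _ = naive_find(parent, x)
    root_y, _ = naive_find(parent, y)
    if root_x != root_y:
        parent[root_x] = root_y

n = 1000
parent = list(range(n))
for i in range(n - 1):
    naive_union(parent, i, i + 1)  # produces the chain 0 -> 1 -> ... -> n-1

root, steps = naive_find(parent, 0)
print(root, steps)  # 999 999
```

With n elements, Find(0) follows n - 1 pointers, which is the O(n) worst case described above.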
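Union by rank
Union by rank is an optimization applied during Union: always hang the shorter tree under the taller one, which keeps the tree from degenerating into a linked list.
Each node maintains a rank value, an upper bound on the height of the tree rooted at that node:
- On MakeSet, rank = 0
- On Union, the root with the smaller rank is hung under the root with the larger rank
- Only when both trees have the same rank does the rank of the merged root increase by 1
Initially: every element has rank = 0
Union(0, 1): rank[0] == rank[1] == 0, hang 0 under 1
    rank[1] = 1
Union(2, 3): rank[2] == rank[3] == 0, hang 2 under 3
    rank[3] = 1
Union(0, 3): Find(0)=1, Find(3)=3
    rank[1] == rank[3] == 1, hang 1 under 3
    rank[3] = 2
Union(4, 0): Find(4)=4 (rank=0), Find(0)=3 (rank=2)
    rank[4] < rank[3], hang 4 under 3
    rank[3] is unchanged
(On a tie either root may become the parent. This trace hangs the first root under the second; the implementations later in this article do the reverse.)
Final tree (it stays balanced):
      3 (rank=2)
     /|\
    1 2 4
    |
    0
Unlike the naive implementation, union by rank keeps the tree height within O(log n) for any sequence of operations: a root of rank r always has at least 2^r nodes in its tree, so the rank, and with it the height, never exceeds ⌊log₂ n⌋.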
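The height bound can be checked on a small case. The following is a sketch under stated assumptions: the class name `RankOnlyDSU` and its `depth` helper are made up for this illustration, and path compression is deliberately left out so that the tree shape reflects union by rank alone. Merging equal-rank trees in tournament rounds is the sequence that makes the tree as tall as union by rank allows.

```python
import math

class RankOnlyDSU:
    """Union by rank without path compression (for measuring tree height)."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx              # make rx the root with the larger rank
        self.parent[ry] = rx             # hang the lower-rank root under it
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1           # equal ranks: the merged tree grows by one

    def depth(self, x):
        """Number of parent pointers between x and its root."""
        d = 0
        while self.parent[x] != x:
            x = self.parent[x]
            d += 1
        return d

n = 16
dsu = RankOnlyDSU(n)
step = 1
while step < n:                          # rounds: merge blocks of size 1, 2, 4, 8
    for i in range(0, n, 2 * step):
        dsu.union(i, i + step)
    step *= 2

max_depth = max(dsu.depth(x) for x in range(n))
print(max_depth, int(math.log2(n)))      # 4 4
```

Even on this adversarial sequence the deepest node is only log₂ 16 = 4 steps from the root, whereas the naive rule can reach depth 15 on 16 elements.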
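Path compression
Path compression is an optimization applied during Find: while searching for the root, every node on the search path is re-attached directly under the root.
Before path compression:      After path compression:
        3                             3
        |                            /|\
        1                           0 1 2
        |
        0
        |
        2
Find(2): follow 2→0→1→3 to reach the root 3
    and at the same time hang 2, 0, 1 on that path directly under 3
Path compression is very concise when written recursively:
Find(x):
    if parent[x] != x:
        parent[x] = Find(parent[x]) // recurse to the root, compressing the path on the way back
    return parent[x]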
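The recursive form uses one stack frame per node on the path, which can be a problem on very deep trees (CPython's default recursion limit is 1000, for example). A common alternative is an iterative two-pass version, sketched below: the first pass locates the root, the second re-points every node on the path at it. The standalone function `find` taking the `parent` list as an argument is an illustrative choice, not the article's API.

```python
def find(parent, x):
    """Iterative Find with full path compression (two passes)."""
    # Pass 1: locate the root.
    root = x
    while parent[root] != root:
        root = parent[root]
    # Pass 2: re-point every node on the path directly at the root.
    while parent[x] != root:
        next_x = parent[x]
        parent[x] = root
        x = next_x
    return root

# The chain from the diagram above: 2 -> 0 -> 1 -> 3
parent = [1, 3, 0, 3]
print(find(parent, 2))  # 3
print(parent)           # [3, 3, 3, 3]

# A chain far deeper than the default recursion limit is handled without error.
deep = list(range(1, 100_000)) + [99_999]
print(find(deep, 0))    # 99999
```

The result is the same tree shape the recursive version produces; only the traversal mechanism differs.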
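Full implementation: union by rank + path compression
With both optimizations combined, a sequence of m operations (Find and Union) on n elements takes O(m · α(n)) time in total, that is, O(α(n)) amortized per operation, where α is the inverse Ackermann function, which is at most 4 in practice.
As an example, consider a connectivity problem: decide whether the vertices of an undirected graph are connected.
Edges of the graph: (0,1), (1,2), (3,4), (2,3)
Initially: 5 singleton sets
Process (0,1): Union(0,1) → {0,1} {2} {3} {4}
Process (1,2): Union(1,2) → {0,1,2} {3} {4}
Process (3,4): Union(3,4) → {0,1,2} {3,4}
Process (2,3): Union(2,3) → {0,1,2,3,4}
All vertices are connected!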
C++ implementation
#include <iostream>
#include <vector>
using namespace std;
class DisjointSets {
vector<int> parent;
vector<int> rank_;
public:
DisjointSets(int n) : parent(n), rank_(n, 0) {
for (int i = 0; i < n; i++)
parent[i] = i;
}
// Find with path compression
int find(int x) {
if (parent[x] != x)
parent[x] = find(parent[x]); // Compress path
return parent[x];
}
// Union by rank
void unionSets(int x, int y) {
int rootX = find(x);
int rootY = find(y);
if (rootX == rootY) return;
if (rank_[rootX] < rank_[rootY]) {
parent[rootX] = rootY;
} else if (rank_[rootX] > rank_[rootY]) {
parent[rootY] = rootX;
} else {
parent[rootY] = rootX;
rank_[rootX]++;
}
}
bool connected(int x, int y) {
return find(x) == find(y);
}
};
int main() {
// Graph edges
int edges[][2] = {{0,1}, {1,2}, {3,4}, {2,3}};
int n = 5;
DisjointSets ds(n);
cout << "Processing edges:" << endl;
for (auto& e : edges) {
ds.unionSets(e[0], e[1]);
cout << " Union(" << e[0] << ", " << e[1] << ")"
<< " -> Connected 0-4: " << (ds.connected(0, 4) ? "yes" : "no") << endl;
}
cout << "\nFinal connectivity:" << endl;
cout << " 0-4: " << (ds.connected(0, 4) ? "connected" : "not connected") << endl;
cout << " 1-3: " << (ds.connected(1, 3) ? "connected" : "not connected") << endl;
return 0;
}
C implementation
#include <stdio.h>
#include <stdlib.h>   // malloc, calloc, free
#include <stdbool.h>
typedef struct {
int* parent;
int* rank;
int n;
} DisjointSets;
void ds_init(DisjointSets* ds, int n) {
ds->n = n;
ds->parent = (int*)malloc(n * sizeof(int));
ds->rank = (int*)calloc(n, sizeof(int));
for (int i = 0; i < n; i++)
ds->parent[i] = i;
}
void ds_free(DisjointSets* ds) {
free(ds->parent);
free(ds->rank);
}
// Find with path compression
int ds_find(DisjointSets* ds, int x) {
if (ds->parent[x] != x)
ds->parent[x] = ds_find(ds, ds->parent[x]);
return ds->parent[x];
}
// Union by rank
void ds_union(DisjointSets* ds, int x, int y) {
int rootX = ds_find(ds, x);
int rootY = ds_find(ds, y);
if (rootX == rootY) return;
if (ds->rank[rootX] < ds->rank[rootY]) {
ds->parent[rootX] = rootY;
} else if (ds->rank[rootX] > ds->rank[rootY]) {
ds->parent[rootY] = rootX;
} else {
ds->parent[rootY] = rootX;
ds->rank[rootX]++;
}
}
bool ds_connected(DisjointSets* ds, int x, int y) {
return ds_find(ds, x) == ds_find(ds, y);
}
int main() {
int edges[][2] = {{0,1}, {1,2}, {3,4}, {2,3}};
int n = 5;
DisjointSets ds;
ds_init(&ds, n);
printf("Processing edges:\n");
for (int i = 0; i < 4; i++) {
ds_union(&ds, edges[i][0], edges[i][1]);
printf(" Union(%d, %d) -> Connected 0-4: %s\n",
edges[i][0], edges[i][1],
ds_connected(&ds, 0, 4) ? "yes" : "no");
}
printf("\nFinal connectivity:\n");
printf(" 0-4: %s\n", ds_connected(&ds, 0, 4) ? "connected" : "not connected");
printf(" 1-3: %s\n", ds_connected(&ds, 1, 3) ? "connected" : "not connected");
ds_free(&ds);
return 0;
}
Python implementation
class DisjointSets:
    """Union-Find with path compression and union by rank."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # Path compression
        return self.parent[x]

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return
        # Union by rank
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1

    def connected(self, x, y):
        return self.find(x) == self.find(y)


edges = [(0, 1), (1, 2), (3, 4), (2, 3)]
n = 5
ds = DisjointSets(n)
print("Processing edges:")
for u, v in edges:
    ds.union(u, v)
    print(f" Union({u}, {v}) -> Connected 0-4: "
          f"{'yes' if ds.connected(0, 4) else 'no'}")
print("\nFinal connectivity:")
print(f" 0-4: {'connected' if ds.connected(0, 4) else 'not connected'}")
print(f" 1-3: {'connected' if ds.connected(1, 3) else 'not connected'}")
Go implementation
package main
import "fmt"
type DisjointSets struct {
parent []int
rank []int
}
func NewDisjointSets(n int) *DisjointSets {
parent := make([]int, n)
rank := make([]int, n)
for i := range parent {
parent[i] = i
}
return &DisjointSets{parent: parent, rank: rank}
}
// Find with path compression
func (ds *DisjointSets) Find(x int) int {
if ds.parent[x] != x {
ds.parent[x] = ds.Find(ds.parent[x])
}
return ds.parent[x]
}
// Union by rank
func (ds *DisjointSets) Union(x, y int) {
rootX := ds.Find(x)
rootY := ds.Find(y)
if rootX == rootY {
return
}
if ds.rank[rootX] < ds.rank[rootY] {
ds.parent[rootX] = rootY
} else if ds.rank[rootX] > ds.rank[rootY] {
ds.parent[rootY] = rootX
} else {
ds.parent[rootY] = rootX
ds.rank[rootX]++
}
}
func (ds *DisjointSets) Connected(x, y int) bool {
return ds.Find(x) == ds.Find(y)
}
func main() {
edges := [][2]int{{0, 1}, {1, 2}, {3, 4}, {2, 3}}
n := 5
ds := NewDisjointSets(n)
fmt.Println("Processing edges:")
for _, e := range edges {
ds.Union(e[0], e[1])
connected := "no"
if ds.Connected(0, 4) {
connected = "yes"
}
fmt.Printf(" Union(%d, %d) -> Connected 0-4: %s\n", e[0], e[1], connected)
}
fmt.Println("\nFinal connectivity:")
fmt.Printf(" 0-4: %s\n", map[bool]string{true: "connected", false: "not connected"}[ds.Connected(0, 4)])
fmt.Printf(" 1-3: %s\n", map[bool]string{true: "connected", false: "not connected"}[ds.Connected(1, 3)])
}
Running the program prints:
Processing edges:
Union(0, 1) -> Connected 0-4: no
Union(1, 2) -> Connected 0-4: no
Union(3, 4) -> Connected 0-4: no
Union(2, 3) -> Connected 0-4: yes
Final connectivity:
0-4: connected
1-3: connected
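Application: Kruskal's minimum spanning tree
The classic application of disjoint sets is Kruskal's algorithm, which computes a minimum spanning tree (MST) of a weighted undirected graph.
Steps of the algorithm:
- Sort all edges by weight in ascending order
- Take the edges one by one; if the two endpoints of an edge are not in the same set, add the edge to the tree and merge the two sets
- Repeat until n-1 edges have been selected
Graph: 4 vertices, 5 edges
    0-1: weight 1
    0-2: weight 3
    1-2: weight 2
    1-3: weight 4
    2-3: weight 5
Edges after sorting, written as (u, v, weight): (0,1,1), (1,2,2), (0,2,3), (1,3,4), (2,3,5)
Step 1: edge (0,1,1): Find(0)≠Find(1) → add it, Union(0,1)
Step 2: edge (1,2,2): Find(1)≠Find(2) → add it, Union(1,2)
Step 3: edge (0,2,3): Find(0)==Find(2) → skip (it would form a cycle)
Step 4: edge (1,3,4): Find(1)≠Find(3) → add it, Union(1,3)
Minimum spanning tree: (0,1,1) + (1,2,2) + (1,3,4), total weight 1 + 2 + 4 = 7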
C++ implementation
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
struct Edge {
int u, v, weight;
bool operator<(const Edge& other) const {
return weight < other.weight;
}
};
class DisjointSets {
vector<int> parent;
vector<int> rank_;
public:
DisjointSets(int n) : parent(n), rank_(n, 0) {
for (int i = 0; i < n; i++)
parent[i] = i;
}
int find(int x) {
if (parent[x] != x)
parent[x] = find(parent[x]);
return parent[x];
}
void unionSets(int x, int y) {
int rx = find(x), ry = find(y);
if (rx == ry) return;
if (rank_[rx] < rank_[ry]) parent[rx] = ry;
else if (rank_[rx] > rank_[ry]) parent[ry] = rx;
else { parent[ry] = rx; rank_[rx]++; }
}
bool connected(int x, int y) { return find(x) == find(y); }
};
int kruskal(int n, vector<Edge>& edges) {
sort(edges.begin(), edges.end());
DisjointSets ds(n);
int mstWeight = 0;
int edgeCount = 0;
for (auto& e : edges) {
if (!ds.connected(e.u, e.v)) {
ds.unionSets(e.u, e.v);
mstWeight += e.weight;
edgeCount++;
cout << " Edge (" << e.u << ", " << e.v << ") weight="
<< e.weight << " -> MST total: " << mstWeight << endl;
if (edgeCount == n - 1) break;
} else {
cout << " Edge (" << e.u << ", " << e.v << ") weight="
<< e.weight << " -> skipped (cycle)" << endl;
}
}
return mstWeight;
}
int main() {
int n = 4;
vector<Edge> edges = {
{0, 1, 1}, {0, 2, 3}, {1, 2, 2}, {1, 3, 4}, {2, 3, 5}
};
cout << "Kruskal's MST:" << endl;
int total = kruskal(n, edges);
cout << "Total MST weight: " << total << endl;
return 0;
}
C implementation
#include <stdio.h>
#include <stdlib.h>
typedef struct {
int u, v, weight;
} Edge;
typedef struct {
int* parent;
int* rank;
int n;
} DisjointSets;
void ds_init(DisjointSets* ds, int n) {
ds->n = n;
ds->parent = (int*)malloc(n * sizeof(int));
ds->rank = (int*)calloc(n, sizeof(int));
for (int i = 0; i < n; i++)
ds->parent[i] = i;
}
void ds_free(DisjointSets* ds) {
free(ds->parent);
free(ds->rank);
}
int ds_find(DisjointSets* ds, int x) {
if (ds->parent[x] != x)
ds->parent[x] = ds_find(ds, ds->parent[x]);
return ds->parent[x];
}
void ds_union(DisjointSets* ds, int x, int y) {
int rx = ds_find(ds, x), ry = ds_find(ds, y);
if (rx == ry) return;
if (ds->rank[rx] < ds->rank[ry]) ds->parent[rx] = ry;
else if (ds->rank[rx] > ds->rank[ry]) ds->parent[ry] = rx;
else { ds->parent[ry] = rx; ds->rank[rx]++; }
}
int ds_connected(DisjointSets* ds, int x, int y) {
return ds_find(ds, x) == ds_find(ds, y);
}
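int cmp_edge(const void* a, const void* b) {
    // Compare instead of subtracting, so extreme weights cannot overflow int
    int wa = ((const Edge*)a)->weight;
    int wb = ((const Edge*)b)->weight;
    return (wa > wb) - (wa < wb);
}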
int kruskal(int n, Edge* edges, int edgeCount) {
qsort(edges, edgeCount, sizeof(Edge), cmp_edge);
DisjointSets ds;
ds_init(&ds, n);
int mstWeight = 0;
int selected = 0;
for (int i = 0; i < edgeCount && selected < n - 1; i++) {
if (!ds_connected(&ds, edges[i].u, edges[i].v)) {
ds_union(&ds, edges[i].u, edges[i].v);
mstWeight += edges[i].weight;
selected++;
printf(" Edge (%d, %d) weight=%d -> MST total: %d\n",
edges[i].u, edges[i].v, edges[i].weight, mstWeight);
} else {
printf(" Edge (%d, %d) weight=%d -> skipped (cycle)\n",
edges[i].u, edges[i].v, edges[i].weight);
}
}
ds_free(&ds);
return mstWeight;
}
int main() {
Edge edges[] = {
{0, 1, 1}, {0, 2, 3}, {1, 2, 2}, {1, 3, 4}, {2, 3, 5}
};
int n = 4;
int edgeCount = sizeof(edges) / sizeof(edges[0]);
printf("Kruskal's MST:\n");
int total = kruskal(n, edges, edgeCount);
printf("Total MST weight: %d\n", total);
return 0;
}
Python implementation
class DisjointSets:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return
        if self.rank[rx] < self.rank[ry]:
            self.parent[rx] = ry
        elif self.rank[rx] > self.rank[ry]:
            self.parent[ry] = rx
        else:
            self.parent[ry] = rx
            self.rank[rx] += 1

    def connected(self, x, y):
        return self.find(x) == self.find(y)


def kruskal(n, edges):
    """Kruskal's MST algorithm. edges: list of (u, v, weight)."""
    edges = sorted(edges, key=lambda e: e[2])
    ds = DisjointSets(n)
    mst_weight = 0
    edge_count = 0
    for u, v, w in edges:
        if not ds.connected(u, v):
            ds.union(u, v)
            mst_weight += w
            edge_count += 1
            print(f" Edge ({u}, {v}) weight={w} -> MST total: {mst_weight}")
            if edge_count == n - 1:
                break
        else:
            print(f" Edge ({u}, {v}) weight={w} -> skipped (cycle)")
    return mst_weight


edges = [(0, 1, 1), (0, 2, 3), (1, 2, 2), (1, 3, 4), (2, 3, 5)]
n = 4
print("Kruskal's MST:")
total = kruskal(n, edges)
print(f"Total MST weight: {total}")
Go implementation
package main
import (
"fmt"
"sort"
)
type Edge struct {
U, V, Weight int
}
type DisjointSets struct {
parent []int
rank []int
}
func NewDisjointSets(n int) *DisjointSets {
parent := make([]int, n)
rank := make([]int, n)
for i := range parent {
parent[i] = i
}
return &DisjointSets{parent: parent, rank: rank}
}
func (ds *DisjointSets) Find(x int) int {
if ds.parent[x] != x {
ds.parent[x] = ds.Find(ds.parent[x])
}
return ds.parent[x]
}
func (ds *DisjointSets) Union(x, y int) {
rx, ry := ds.Find(x), ds.Find(y)
if rx == ry {
return
}
if ds.rank[rx] < ds.rank[ry] {
ds.parent[rx] = ry
} else if ds.rank[rx] > ds.rank[ry] {
ds.parent[ry] = rx
} else {
ds.parent[ry] = rx
ds.rank[rx]++
}
}
func (ds *DisjointSets) Connected(x, y int) bool {
return ds.Find(x) == ds.Find(y)
}
func kruskal(n int, edges []Edge) int {
sort.Slice(edges, func(i, j int) bool {
return edges[i].Weight < edges[j].Weight
})
ds := NewDisjointSets(n)
mstWeight := 0
edgeCount := 0
for _, e := range edges {
if !ds.Connected(e.U, e.V) {
ds.Union(e.U, e.V)
mstWeight += e.Weight
edgeCount++
fmt.Printf(" Edge (%d, %d) weight=%d -> MST total: %d\n",
e.U, e.V, e.Weight, mstWeight)
if edgeCount == n-1 {
break
}
} else {
fmt.Printf(" Edge (%d, %d) weight=%d -> skipped (cycle)\n",
e.U, e.V, e.Weight)
}
}
return mstWeight
}
func main() {
edges := []Edge{
{0, 1, 1}, {0, 2, 3}, {1, 2, 2}, {1, 3, 4}, {2, 3, 5},
}
n := 4
fmt.Println("Kruskal's MST:")
total := kruskal(n, edges)
fmt.Printf("Total MST weight: %d\n", total)
}
Running the program prints:
Kruskal's MST:
Edge (0, 1) weight=1 -> MST total: 1
Edge (1, 2) weight=2 -> MST total: 3
Edge (0, 2) weight=3 -> skipped (cycle)
Edge (1, 3) weight=4 -> MST total: 7
Total MST weight: 7
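Properties of disjoint sets
Complexity
| Implementation | Find | Union | m operations |
|---|---|---|---|
| Naive | O(n) worst case | O(n) worst case | O(m·n) |
| Union by rank | O(log n) | O(log n) | O(m·log n) |
| Path compression | O(log n) amortized | O(log n) amortized | O(m·log n) |
| Union by rank + path compression | O(α(n)) amortized | O(α(n)) amortized | O(m·α(n)) |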
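Here α(n) is the inverse Ackermann function, which grows extremely slowly. The exact thresholds depend on which variant of the Ackermann function is inverted; under the definition used in CLRS, α(n) = min{k : A_k(1) ≥ n}, they are:
Values of α(n):
    n ≤ 2 → α(n) = 0
    n = 3 → α(n) = 1
    4 ≤ n ≤ 7 → α(n) = 2
    8 ≤ n ≤ 2047 → α(n) = 3
    2048 ≤ n ≤ A_4(1) → α(n) = 4 (A_4(1) is far larger than 10^80, so larger values never occur in practice)
In practice α(n) is therefore at most 4, and the total time for m operations is very nearly O(m).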
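How slowly α grows can be seen by evaluating the function it inverts. The sketch below assumes the CLRS formulation, A_0(j) = j + 1 and A_k(j) = A_{k-1} applied j + 1 times to j, with α(n) the smallest k such that A_k(1) ≥ n; other textbooks use variants that shift the thresholds. The function names `A` and `alpha` are illustrative.

```python
def A(k, j):
    """A_k(j) in the CLRS formulation (an assumption of this sketch)."""
    if k == 0:
        return j + 1
    result = j
    for _ in range(j + 1):        # apply A_{k-1} a total of j + 1 times
        result = A(k - 1, result)
    return result

def alpha(n):
    """Smallest k with A_k(1) >= n."""
    if n > 2047:
        # A_4(1) is astronomically large and cannot be computed directly;
        # 4 is the answer for every n up to A_4(1).
        return 4
    k = 0
    while A(k, 1) < n:
        k += 1
    return k

print([A(k, 1) for k in range(4)])                    # [2, 3, 7, 2047]
print([alpha(n) for n in (2, 3, 7, 8, 2047, 2048)])   # [0, 1, 2, 3, 3, 4]
```

The jump from A_3(1) = 2047 to an A_4(1) that dwarfs the number of atoms in the observable universe is why α(n) never exceeds 4 for realistic inputs.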
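Union by rank vs. union by size
| Strategy | Meaning | Tree height bound |
|---|---|---|
| Union by rank | rank is an upper bound on the tree height; the lower-rank root is hung under the higher-rank root | O(log n) |
| Union by size | always hang the smaller set under the larger set | O(log n) |
Both strategies keep the tree height within O(log n), and combined with path compression both achieve the same O(α(n)) amortized bound. Union by size is more intuitive and makes the size of every set available for free; union by rank stores smaller numbers, since a rank never exceeds ⌊log₂ n⌋.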
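For comparison with the rank-based implementations earlier in the article, here is a minimal sketch of union by size with path compression. The class name `SizeDSU` and the `set_size` method are illustrative additions; the only structural change is that each root stores the number of elements in its tree instead of a rank.

```python
class SizeDSU:
    """Union-Find with union by size and iterative path compression."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n          # meaningful only at root indices

    def find(self, x):
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[x] != root:    # compress the path just walked
            next_x = self.parent[x]
            self.parent[x] = root
            x = next_x
        return root

    def union(self, x, y):
        """Merge the two sets; return False if they were already one set."""
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx              # make rx the root of the larger set
        self.parent[ry] = rx             # hang the smaller set under the larger
        self.size[rx] += self.size[ry]
        return True

    def set_size(self, x):
        return self.size[self.find(x)]

ds = SizeDSU(5)
for u, v in [(0, 1), (1, 2), (3, 4)]:
    ds.union(u, v)
print(ds.set_size(0), ds.set_size(4))    # 3 2
```

Because the size array is already maintained for balancing, queries such as "how large is the component containing x" need no extra bookkeeping.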
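Key properties
| Property | Description |
|---|---|
| Disjointness | At any moment, every element belongs to exactly one set |
| Unique representative | Find returns the same result for all elements of one set |
| No splitting | After a Union the two sets stay merged; the basic structure has no operation to split them again |
| Path compression preserves the sets | It changes only the shape of the trees, not the partition into sets |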
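Use cases
| Application | Description |
|---|---|
| Kruskal's minimum spanning tree | Decide whether adding an edge would form a cycle |
| Dynamic connectivity | Decide whether two vertices are connected as edges are added |
| Equivalence classes | Group equivalent elements into the same set |
| Image processing | Connected component labeling |
| Network protocols | Detect redundant connections in a network |
| Lowest common ancestor (LCA) | Core data structure of Tarjan's offline LCA algorithm |
| Parallel computing | Dynamic processor allocation |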
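Two of these use cases, detecting redundant connections and counting connected components, need only a few lines on top of Find. The sketch below is illustrative: the function names are made up, union by rank is omitted for brevity, and `find` uses path halving, a one-pass variant of path compression that points each visited node at its grandparent.

```python
def find(parent, x):
    """Find with path halving (a one-pass variant of path compression)."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]   # skip one level on the way up
        x = parent[x]
    return x

def redundant_edges(n, edges):
    """Edges whose endpoints were already connected when the edge arrived."""
    parent = list(range(n))
    redundant = []
    for u, v in edges:
        ru, rv = find(parent, u), find(parent, v)
        if ru == rv:
            redundant.append((u, v))    # adding it would close a cycle
        else:
            parent[ru] = rv
    return redundant

def count_components(n, edges):
    """Number of connected components of an undirected graph."""
    parent = list(range(n))
    components = n
    for u, v in edges:
        ru, rv = find(parent, u), find(parent, v)
        if ru != rv:
            parent[ru] = rv
            components -= 1             # each successful merge removes one component
    return components

# The Kruskal example graph, edges in input order
print(redundant_edges(4, [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]))  # [(1, 2), (2, 3)]
print(count_components(5, [(0, 1), (3, 4)]))                         # 3
```

The redundant-edge check is the same test Kruskal's algorithm uses to skip edges that would form a cycle.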
