10-1 Disjoint Sets

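Disjoint Sets (Union-Find)

A disjoint-set structure, also known as union-find, is a data structure that maintains a collection of pairwise disjoint sets. It supports two core operations:

  • Find: determine which set an element belongs to (returns the set's representative element)
  • Union: merge two sets into one

Both operations look simple, but with two optimizations, path compression and union by rank, the amortized cost of a single operation drops to O(α(n)), where α is the inverse Ackermann function. α(n) is at most 4 for any input size that can occur in practice, so it can be treated as a constant.

Disjoint sets are the core dependency of Kruskal's minimum spanning tree algorithm in graph theory, and they are also widely used for connectivity queries, equivalence-class partitioning, and dynamic connectivity problems in which edges are only added.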


Basic Concepts

Representing Sets

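Each set is represented by a rooted tree:

  • The root of the tree is the set's representative
  • Every node has a parent pointer to its parent node
  • The root's parent pointer points to itself

Representation of {0, 1, 2}:     Representation of {3, 4}:

    0 (representative)            3 (representative)
   / \                            |
  1   2                           4

parent[0] = 0  (the root points to itself)
parent[1] = 0
parent[2] = 0
parent[3] = 3
parent[4] = 3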
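
As a small illustration of this representation, the sketch below starts from the parent array in the diagram and recovers the sets by following parent pointers up to each root. The helper names `root_of` and `groups` are hypothetical and exist only for this example.

```python
from collections import defaultdict

# Parent array from the diagram: roots point to themselves.
parent = [0, 0, 0, 3, 3]


def root_of(parent, x):
    """Follow parent pointers until reaching a node that is its own parent."""
    while parent[x] != x:
        x = parent[x]
    return x


def groups(parent):
    """Group every element under its representative (hypothetical helper)."""
    result = defaultdict(list)
    for x in range(len(parent)):
        result[root_of(parent, x)].append(x)
    return dict(result)


print(groups(parent))  # {0: [0, 1, 2], 3: [3, 4]}
```

The dictionary keys are the representatives, so two elements are in the same set exactly when they land under the same key.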

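Core Operations

MakeSet(x): create a set containing only x
  parent[x] = x

Find(x): return the representative (root) of the set containing x
  follow parent pointers upward until parent[x] == x

Union(x, y): merge the sets containing x and y
  rootX = Find(x)
  rootY = Find(y)
  parent[rootX] = rootY  (attach the root of one tree under the root of the other)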

Naive Implementation

The simplest implementation applies no optimization. Take 5 elements as an example:

Initial state (MakeSet):
parent = [0, 1, 2, 3, 4]  (every element is its own singleton set)

Union(0, 1):
  Find(0) = 0, Find(1) = 1
  parent[0] = 1
  parent = [1, 1, 2, 3, 4]

Union(2, 3):
  Find(2) = 2, Find(3) = 3
  parent[2] = 3
  parent = [1, 1, 3, 3, 4]

Union(0, 3):
  Find(0): 0 → 1 (root)    rootX = 1
  Find(3): 3 → 3 (root)    rootY = 3
  parent[1] = 3
  parent = [1, 3, 3, 3, 4]

Resulting sets:
  {0, 1, 2, 3}  (representative 3)
  {4}           (representative 4)

Find(0): 0 → 1 → 3 (root)   returns 3
Find(2): 2 → 3 (root)        returns 3
Find(4): 4 → 4 (root)        returns 4

C++ implementation

#include <iostream>
#include <vector>
using namespace std;

class DisjointSets {
    vector<int> parent;

public:
    DisjointSets(int n) : parent(n) {
        for (int i = 0; i < n; i++)
            parent[i] = i;  // MakeSet: each element is its own parent
    }

    int find(int x) {
        while (parent[x] != x)
            x = parent[x];
        return x;
    }

    void unionSets(int x, int y) {
        int rootX = find(x);
        int rootY = find(y);
        if (rootX != rootY)
            parent[rootX] = rootY;
    }

    bool connected(int x, int y) {
        return find(x) == find(y);
    }
};

int main() {
    DisjointSets ds(5);

    ds.unionSets(0, 1);
    ds.unionSets(2, 3);
    ds.unionSets(0, 3);

    cout << "Find(0) = " << ds.find(0) << endl;
    cout << "Find(2) = " << ds.find(2) << endl;
    cout << "Find(4) = " << ds.find(4) << endl;
    cout << "Connected(0, 2) = " << (ds.connected(0, 2) ? "true" : "false") << endl;
    cout << "Connected(0, 4) = " << (ds.connected(0, 4) ? "true" : "false") << endl;

    return 0;
}

C implementation

#include <stdio.h>
#include <stdbool.h>

void make_set(int* parent, int n) {
    for (int i = 0; i < n; i++)
        parent[i] = i;
}

int find(int* parent, int x) {
    while (parent[x] != x)
        x = parent[x];
    return x;
}

void union_sets(int* parent, int x, int y) {
    int rootX = find(parent, x);
    int rootY = find(parent, y);
    if (rootX != rootY)
        parent[rootX] = rootY;
}

bool connected(int* parent, int x, int y) {
    return find(parent, x) == find(parent, y);
}

int main() {
    int parent[5];
    make_set(parent, 5);

    union_sets(parent, 0, 1);
    union_sets(parent, 2, 3);
    union_sets(parent, 0, 3);

    printf("Find(0) = %d\n", find(parent, 0));
    printf("Find(2) = %d\n", find(parent, 2));
    printf("Find(4) = %d\n", find(parent, 4));
    printf("Connected(0, 2) = %s\n", connected(parent, 0, 2) ? "true" : "false");
    printf("Connected(0, 4) = %s\n", connected(parent, 0, 4) ? "true" : "false");

    return 0;
}

Python implementation

class DisjointSets:
    def __init__(self, n):
        self.parent = list(range(n))  # MakeSet: each element is its own parent

    def find(self, x):
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.parent[root_x] = root_y

    def connected(self, x, y):
        return self.find(x) == self.find(y)


ds = DisjointSets(5)

ds.union(0, 1)
ds.union(2, 3)
ds.union(0, 3)

print(f"Find(0) = {ds.find(0)}")
print(f"Find(2) = {ds.find(2)}")
print(f"Find(4) = {ds.find(4)}")
print(f"Connected(0, 2) = {ds.connected(0, 2)}")
print(f"Connected(0, 4) = {ds.connected(0, 4)}")

Go implementation

package main

import "fmt"

type DisjointSets struct {
	parent []int
}

func NewDisjointSets(n int) *DisjointSets {
	parent := make([]int, n)
	for i := range parent {
		parent[i] = i
	}
	return &DisjointSets{parent: parent}
}

func (ds *DisjointSets) Find(x int) int {
	for ds.parent[x] != x {
		x = ds.parent[x]
	}
	return x
}

func (ds *DisjointSets) Union(x, y int) {
	rootX := ds.Find(x)
	rootY := ds.Find(y)
	if rootX != rootY {
		ds.parent[rootX] = rootY
	}
}

func (ds *DisjointSets) Connected(x, y int) bool {
	return ds.Find(x) == ds.Find(y)
}

func main() {
	ds := NewDisjointSets(5)

	ds.Union(0, 1)
	ds.Union(2, 3)
	ds.Union(0, 3)

	fmt.Printf("Find(0) = %d\n", ds.Find(0))
	fmt.Printf("Find(2) = %d\n", ds.Find(2))
	fmt.Printf("Find(4) = %d\n", ds.Find(4))
	fmt.Printf("Connected(0, 2) = %v\n", ds.Connected(0, 2))
	fmt.Printf("Connected(0, 4) = %v\n", ds.Connected(0, 4))
}

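Running any of the programs above prints the following (the Python version prints True and False with a capital initial):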

Find(0) = 3
Find(2) = 3
Find(4) = 4
Connected(0, 2) = true
Connected(0, 4) = false

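The problem with the naive implementation is that, in the worst case, the tree degenerates into a linked list and a single Find takes O(n) time.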
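
To make the worst case concrete, here is a minimal sketch of an adversarial union order that turns the naive structure into a chain. The function names `find_with_steps` and `naive_union` are hypothetical; the step counter is added only to show how many parent links a Find has to follow.

```python
def find_with_steps(parent, x):
    """Naive find that also counts how many parent links were followed."""
    steps = 0
    while parent[x] != x:
        x = parent[x]
        steps += 1
    return x, steps


def naive_union(parent, x, y):
    """Naive union: always attach the root of x under the root of y."""
    root_x, _ = find_with_steps(parent, x)
    root_y, _ = find_with_steps(parent, y)
    if root_x != root_y:
        parent[root_x] = root_y


n = 8
parent = list(range(n))
# Adversarial order: each union hangs the current root under a new singleton.
for i in range(1, n):
    naive_union(parent, 0, i)

print(parent)                      # [1, 2, 3, 4, 5, 6, 7, 7]
print(find_with_steps(parent, 0))  # (7, 7)
```

With n elements, Find(0) follows n-1 links, which is the O(n) behavior described above.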


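Union by Rank

Union by rank is an optimization applied during Union: always attach the shorter tree under the taller one, which prevents the tree from degenerating into a linked list.

Each node keeps a rank value, an upper bound on the height of the tree rooted at that node:

  • On MakeSet, rank = 0
  • On Union, the root with the smaller rank is attached under the root with the larger rank
  • The rank of the new root increases by 1 only when the two trees have equal rank

Initially: every element has rank = 0

Union(0, 1): rank[0] == rank[1] == 0, attach 0 under 1
  rank[1] = 1

Union(2, 3): rank[2] == rank[3] == 0, attach 2 under 3
  rank[3] = 1

Union(0, 3): Find(0)=1, Find(3)=3
  rank[1] == rank[3] == 1, attach 1 under 3
  rank[3] = 2

Union(4, 0): Find(4)=4 (rank=0), Find(0)=3 (rank=2)
  rank[4] < rank[3], attach 4 under 3
  rank[3] is unchanged

Final tree (its height stays logarithmic):

        3 (rank=2)
       /|\
      1 2 4
      |
      0

When the ranks are equal, either root may become the new root. This walkthrough attaches the first root under the second, whereas the complete implementations later in this article do the opposite; both choices are valid. Compared with the naive implementation on the same sequence of operations, union by rank guarantees that the tree height never exceeds ⌊log₂ n⌋, that is O(log n).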
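
The walkthrough above can be reproduced with a short sketch that uses union by rank only, without path compression, so the tree shape stays visible. The function name `union_by_rank` and the tie-breaking rule (the first root goes under the second, as in the walkthrough) are choices made for this example.

```python
def find(parent, x):
    """Plain find without path compression, so tree shapes stay visible."""
    while parent[x] != x:
        x = parent[x]
    return x


def union_by_rank(parent, rank, x, y):
    """Attach the lower-rank root under the higher-rank root.

    On a tie, x's root goes under y's root, matching the walkthrough.
    """
    root_x, root_y = find(parent, x), find(parent, y)
    if root_x == root_y:
        return
    if rank[root_x] > rank[root_y]:
        parent[root_y] = root_x
    else:
        parent[root_x] = root_y
        if rank[root_x] == rank[root_y]:
            rank[root_y] += 1  # equal ranks: the merged tree may be one taller


parent = list(range(5))
rank = [0] * 5
for x, y in [(0, 1), (2, 3), (0, 3), (4, 0)]:
    union_by_rank(parent, rank, x, y)

print(parent)  # [1, 3, 3, 3, 3]
print(rank)    # [0, 1, 0, 2, 0]
```

The resulting parent array describes the same tree as the diagram: 3 is the root with children 1, 2, and 4, and 0 hangs under 1.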


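Path Compression

Path compression is an optimization performed inside Find: while looking up the root, every node on the search path is reattached directly under the root.

Before path compression:    After path compression:

    3                           3
    |                          /|\
    1                         0 1 2
    |
    0
    |
    2

Find(2): follow 2→0→1→3 to reach the root 3
  and attach 2, 0, and 1 on the path directly under 3

A recursive implementation of path compression is very concise:

Find(x):
  if parent[x] != x:
    parent[x] = Find(parent[x])  // recurse to the root, compressing the path on the way back
  return parent[x]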
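
As an alternative to the recursive form, the sketch below performs the same full path compression iteratively in two passes, which avoids deep recursion on long chains. The name `find_compress` is hypothetical; the example array encodes the chain 2 → 0 → 1 → 3 from the diagram.

```python
def find_compress(parent, x):
    """Two-pass find: locate the root, then repoint every node on the path."""
    # Pass 1: walk up to the root without modifying anything.
    root = x
    while parent[root] != root:
        root = parent[root]
    # Pass 2: walk the same path again and attach each node to the root.
    while parent[x] != root:
        next_x = parent[x]
        parent[x] = root
        x = next_x
    return root


parent = [1, 3, 0, 3]  # chain 2 -> 0 -> 1 -> 3, with 3 as the root
print(find_compress(parent, 2))  # 3
print(parent)                    # [3, 3, 3, 3]
```

After the call, nodes 0, 1, and 2 all point directly at 3, as in the "after" diagram.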

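Complete Implementation: Union by Rank + Path Compression

With both optimizations combined, any sequence of m operations (Find and Union) on n elements takes O(m · α(n)) time in total, that is O(α(n)) amortized per operation, where α is the inverse Ackermann function and is at most 4 in practice.

Take a connectivity problem as an example: decide whether the nodes of an undirected graph are connected.

Graph edges: (0,1), (1,2), (3,4), (2,3)

Initially: 5 separate sets

Process (0,1): Union(0,1) → {0,1} {2} {3} {4}
Process (1,2): Union(1,2) → {0,1,2} {3} {4}
Process (3,4): Union(3,4) → {0,1,2} {3,4}
Process (2,3): Union(2,3) → {0,1,2,3,4}

All nodes are connected!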

C++ implementation

#include <iostream>
#include <vector>
using namespace std;

class DisjointSets {
    vector<int> parent;
    vector<int> rank_;

public:
    DisjointSets(int n) : parent(n), rank_(n, 0) {
        for (int i = 0; i < n; i++)
            parent[i] = i;
    }

    // Find with path compression
    int find(int x) {
        if (parent[x] != x)
            parent[x] = find(parent[x]);  // Compress path
        return parent[x];
    }

    // Union by rank
    void unionSets(int x, int y) {
        int rootX = find(x);
        int rootY = find(y);
        if (rootX == rootY) return;

        if (rank_[rootX] < rank_[rootY]) {
            parent[rootX] = rootY;
        } else if (rank_[rootX] > rank_[rootY]) {
            parent[rootY] = rootX;
        } else {
            parent[rootY] = rootX;
            rank_[rootX]++;
        }
    }

    bool connected(int x, int y) {
        return find(x) == find(y);
    }
};

int main() {
    // Graph edges
    int edges[][2] = {{0,1}, {1,2}, {3,4}, {2,3}};
    int n = 5;

    DisjointSets ds(n);

    cout << "Processing edges:" << endl;
    for (auto& e : edges) {
        ds.unionSets(e[0], e[1]);
        cout << "  Union(" << e[0] << ", " << e[1] << ")"
             << " -> Connected 0-4: " << (ds.connected(0, 4) ? "yes" : "no") << endl;
    }

    cout << "\nFinal connectivity:" << endl;
    cout << "  0-4: " << (ds.connected(0, 4) ? "connected" : "not connected") << endl;
    cout << "  1-3: " << (ds.connected(1, 3) ? "connected" : "not connected") << endl;

    return 0;
}

C implementation

#include <stdio.h>
#include <stdlib.h>   // malloc, calloc, free
#include <stdbool.h>

typedef struct {
    int* parent;
    int* rank;
    int n;
} DisjointSets;

void ds_init(DisjointSets* ds, int n) {
    ds->n = n;
    ds->parent = (int*)malloc(n * sizeof(int));
    ds->rank = (int*)calloc(n, sizeof(int));
    for (int i = 0; i < n; i++)
        ds->parent[i] = i;
}

void ds_free(DisjointSets* ds) {
    free(ds->parent);
    free(ds->rank);
}

// Find with path compression
int ds_find(DisjointSets* ds, int x) {
    if (ds->parent[x] != x)
        ds->parent[x] = ds_find(ds, ds->parent[x]);
    return ds->parent[x];
}

// Union by rank
void ds_union(DisjointSets* ds, int x, int y) {
    int rootX = ds_find(ds, x);
    int rootY = ds_find(ds, y);
    if (rootX == rootY) return;

    if (ds->rank[rootX] < ds->rank[rootY]) {
        ds->parent[rootX] = rootY;
    } else if (ds->rank[rootX] > ds->rank[rootY]) {
        ds->parent[rootY] = rootX;
    } else {
        ds->parent[rootY] = rootX;
        ds->rank[rootX]++;
    }
}

bool ds_connected(DisjointSets* ds, int x, int y) {
    return ds_find(ds, x) == ds_find(ds, y);
}

int main() {
    int edges[][2] = {{0,1}, {1,2}, {3,4}, {2,3}};
    int n = 5;

    DisjointSets ds;
    ds_init(&ds, n);

    printf("Processing edges:\n");
    for (int i = 0; i < 4; i++) {
        ds_union(&ds, edges[i][0], edges[i][1]);
        printf("  Union(%d, %d) -> Connected 0-4: %s\n",
               edges[i][0], edges[i][1],
               ds_connected(&ds, 0, 4) ? "yes" : "no");
    }

    printf("\nFinal connectivity:\n");
    printf("  0-4: %s\n", ds_connected(&ds, 0, 4) ? "connected" : "not connected");
    printf("  1-3: %s\n", ds_connected(&ds, 1, 3) ? "connected" : "not connected");

    ds_free(&ds);
    return 0;
}

Python implementation

class DisjointSets:
    """Union-Find with path compression and union by rank."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # Path compression
        return self.parent[x]

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return

        # Union by rank
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1

    def connected(self, x, y):
        return self.find(x) == self.find(y)


edges = [(0, 1), (1, 2), (3, 4), (2, 3)]
n = 5

ds = DisjointSets(n)

print("Processing edges:")
for u, v in edges:
    ds.union(u, v)
    print(f"  Union({u}, {v}) -> Connected 0-4: "
          f"{'yes' if ds.connected(0, 4) else 'no'}")

print("\nFinal connectivity:")
print(f"  0-4: {'connected' if ds.connected(0, 4) else 'not connected'}")
print(f"  1-3: {'connected' if ds.connected(1, 3) else 'not connected'}")

Go implementation

package main

import "fmt"

type DisjointSets struct {
	parent []int
	rank   []int
}

func NewDisjointSets(n int) *DisjointSets {
	parent := make([]int, n)
	rank := make([]int, n)
	for i := range parent {
		parent[i] = i
	}
	return &DisjointSets{parent: parent, rank: rank}
}

// Find with path compression
func (ds *DisjointSets) Find(x int) int {
	if ds.parent[x] != x {
		ds.parent[x] = ds.Find(ds.parent[x])
	}
	return ds.parent[x]
}

// Union by rank
func (ds *DisjointSets) Union(x, y int) {
	rootX := ds.Find(x)
	rootY := ds.Find(y)
	if rootX == rootY {
		return
	}

	if ds.rank[rootX] < ds.rank[rootY] {
		ds.parent[rootX] = rootY
	} else if ds.rank[rootX] > ds.rank[rootY] {
		ds.parent[rootY] = rootX
	} else {
		ds.parent[rootY] = rootX
		ds.rank[rootX]++
	}
}

func (ds *DisjointSets) Connected(x, y int) bool {
	return ds.Find(x) == ds.Find(y)
}

func main() {
	edges := [][2]int{{0, 1}, {1, 2}, {3, 4}, {2, 3}}
	n := 5

	ds := NewDisjointSets(n)

	fmt.Println("Processing edges:")
	for _, e := range edges {
		ds.Union(e[0], e[1])
		connected := "no"
		if ds.Connected(0, 4) {
			connected = "yes"
		}
		fmt.Printf("  Union(%d, %d) -> Connected 0-4: %s\n", e[0], e[1], connected)
	}

	fmt.Println("\nFinal connectivity:")
	fmt.Printf("  0-4: %s\n", map[bool]string{true: "connected", false: "not connected"}[ds.Connected(0, 4)])
	fmt.Printf("  1-3: %s\n", map[bool]string{true: "connected", false: "not connected"}[ds.Connected(1, 3)])
}

Running the program prints:

Processing edges:
  Union(0, 1) -> Connected 0-4: no
  Union(1, 2) -> Connected 0-4: no
  Union(3, 4) -> Connected 0-4: no
  Union(2, 3) -> Connected 0-4: yes

Final connectivity:
  0-4: connected
  1-3: connected

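Application: Kruskal's Minimum Spanning Tree

The classic application of disjoint sets is Kruskal's algorithm, which computes a minimum spanning tree (MST) of a weighted undirected graph.

Steps:

  1. Sort all edges by weight in ascending order
  2. Take the edges in order; if the two endpoints of an edge are in different sets, add the edge and merge the two sets
  3. Repeat until n-1 edges have been selected (or the edges run out, which means the graph is not connected)

Graph: 4 nodes, 5 edges
  0-1: weight 1
  0-2: weight 3
  1-2: weight 2
  1-3: weight 4
  2-3: weight 5

Sorted edges: (0,1,1), (1,2,2), (0,2,3), (1,3,4), (2,3,5)

Step 1: edge (0,1,1): Find(0)≠Find(1) → add, Union(0,1)
Step 2: edge (1,2,2): Find(1)≠Find(2) → add, Union(1,2)
Step 3: edge (0,2,3): Find(0)==Find(2) → skip (would create a cycle)
Step 4: edge (1,3,4): Find(1)≠Find(3) → add, Union(1,3)

Three edges (n-1) are now selected, so the algorithm stops before examining (2,3,5).

Minimum spanning tree: (0,1,1) + (1,2,2) + (1,3,4), total weight 1 + 2 + 4 = 7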

C++ implementation

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

struct Edge {
    int u, v, weight;
    bool operator<(const Edge& other) const {
        return weight < other.weight;
    }
};

class DisjointSets {
    vector<int> parent;
    vector<int> rank_;

public:
    DisjointSets(int n) : parent(n), rank_(n, 0) {
        for (int i = 0; i < n; i++)
            parent[i] = i;
    }

    int find(int x) {
        if (parent[x] != x)
            parent[x] = find(parent[x]);
        return parent[x];
    }

    void unionSets(int x, int y) {
        int rx = find(x), ry = find(y);
        if (rx == ry) return;
        if (rank_[rx] < rank_[ry]) parent[rx] = ry;
        else if (rank_[rx] > rank_[ry]) parent[ry] = rx;
        else { parent[ry] = rx; rank_[rx]++; }
    }

    bool connected(int x, int y) { return find(x) == find(y); }
};

int kruskal(int n, vector<Edge>& edges) {
    sort(edges.begin(), edges.end());
    DisjointSets ds(n);

    int mstWeight = 0;
    int edgeCount = 0;

    for (auto& e : edges) {
        if (!ds.connected(e.u, e.v)) {
            ds.unionSets(e.u, e.v);
            mstWeight += e.weight;
            edgeCount++;
            cout << "  Edge (" << e.u << ", " << e.v << ") weight="
                 << e.weight << " -> MST total: " << mstWeight << endl;
            if (edgeCount == n - 1) break;
        } else {
            cout << "  Edge (" << e.u << ", " << e.v << ") weight="
                 << e.weight << " -> skipped (cycle)" << endl;
        }
    }

    return mstWeight;
}

int main() {
    int n = 4;
    vector<Edge> edges = {
        {0, 1, 1}, {0, 2, 3}, {1, 2, 2}, {1, 3, 4}, {2, 3, 5}
    };

    cout << "Kruskal's MST:" << endl;
    int total = kruskal(n, edges);
    cout << "Total MST weight: " << total << endl;

    return 0;
}

C implementation

#include <stdio.h>
#include <stdlib.h>

typedef struct {
    int u, v, weight;
} Edge;

typedef struct {
    int* parent;
    int* rank;
    int n;
} DisjointSets;

void ds_init(DisjointSets* ds, int n) {
    ds->n = n;
    ds->parent = (int*)malloc(n * sizeof(int));
    ds->rank = (int*)calloc(n, sizeof(int));
    for (int i = 0; i < n; i++)
        ds->parent[i] = i;
}

void ds_free(DisjointSets* ds) {
    free(ds->parent);
    free(ds->rank);
}

int ds_find(DisjointSets* ds, int x) {
    if (ds->parent[x] != x)
        ds->parent[x] = ds_find(ds, ds->parent[x]);
    return ds->parent[x];
}

void ds_union(DisjointSets* ds, int x, int y) {
    int rx = ds_find(ds, x), ry = ds_find(ds, y);
    if (rx == ry) return;
    if (ds->rank[rx] < ds->rank[ry]) ds->parent[rx] = ry;
    else if (ds->rank[rx] > ds->rank[ry]) ds->parent[ry] = rx;
    else { ds->parent[ry] = rx; ds->rank[rx]++; }
}

int ds_connected(DisjointSets* ds, int x, int y) {
    return ds_find(ds, x) == ds_find(ds, y);
}

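int cmp_edge(const void* a, const void* b) {
    // Compare instead of subtracting so that large weights cannot overflow int.
    int wa = ((const Edge*)a)->weight, wb = ((const Edge*)b)->weight;
    return (wa > wb) - (wa < wb);
}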

int kruskal(int n, Edge* edges, int edgeCount) {
    qsort(edges, edgeCount, sizeof(Edge), cmp_edge);

    DisjointSets ds;
    ds_init(&ds, n);

    int mstWeight = 0;
    int selected = 0;

    for (int i = 0; i < edgeCount && selected < n - 1; i++) {
        if (!ds_connected(&ds, edges[i].u, edges[i].v)) {
            ds_union(&ds, edges[i].u, edges[i].v);
            mstWeight += edges[i].weight;
            selected++;
            printf("  Edge (%d, %d) weight=%d -> MST total: %d\n",
                   edges[i].u, edges[i].v, edges[i].weight, mstWeight);
        } else {
            printf("  Edge (%d, %d) weight=%d -> skipped (cycle)\n",
                   edges[i].u, edges[i].v, edges[i].weight);
        }
    }

    ds_free(&ds);
    return mstWeight;
}

int main() {
    Edge edges[] = {
        {0, 1, 1}, {0, 2, 3}, {1, 2, 2}, {1, 3, 4}, {2, 3, 5}
    };
    int n = 4;
    int edgeCount = sizeof(edges) / sizeof(edges[0]);

    printf("Kruskal's MST:\n");
    int total = kruskal(n, edges, edgeCount);
    printf("Total MST weight: %d\n", total);

    return 0;
}

Python implementation

class DisjointSets:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return
        if self.rank[rx] < self.rank[ry]:
            self.parent[rx] = ry
        elif self.rank[rx] > self.rank[ry]:
            self.parent[ry] = rx
        else:
            self.parent[ry] = rx
            self.rank[rx] += 1

    def connected(self, x, y):
        return self.find(x) == self.find(y)


def kruskal(n, edges):
    """Kruskal's MST algorithm. edges: list of (u, v, weight)."""
    edges = sorted(edges, key=lambda e: e[2])
    ds = DisjointSets(n)
    mst_weight = 0
    edge_count = 0

    for u, v, w in edges:
        if not ds.connected(u, v):
            ds.union(u, v)
            mst_weight += w
            edge_count += 1
            print(f"  Edge ({u}, {v}) weight={w} -> MST total: {mst_weight}")
            if edge_count == n - 1:
                break
        else:
            print(f"  Edge ({u}, {v}) weight={w} -> skipped (cycle)")

    return mst_weight


edges = [(0, 1, 1), (0, 2, 3), (1, 2, 2), (1, 3, 4), (2, 3, 5)]
n = 4

print("Kruskal's MST:")
total = kruskal(n, edges)
print(f"Total MST weight: {total}")

Go implementation

package main

import (
	"fmt"
	"sort"
)

type Edge struct {
	U, V, Weight int
}

type DisjointSets struct {
	parent []int
	rank   []int
}

func NewDisjointSets(n int) *DisjointSets {
	parent := make([]int, n)
	rank := make([]int, n)
	for i := range parent {
		parent[i] = i
	}
	return &DisjointSets{parent: parent, rank: rank}
}

func (ds *DisjointSets) Find(x int) int {
	if ds.parent[x] != x {
		ds.parent[x] = ds.Find(ds.parent[x])
	}
	return ds.parent[x]
}

func (ds *DisjointSets) Union(x, y int) {
	rx, ry := ds.Find(x), ds.Find(y)
	if rx == ry {
		return
	}
	if ds.rank[rx] < ds.rank[ry] {
		ds.parent[rx] = ry
	} else if ds.rank[rx] > ds.rank[ry] {
		ds.parent[ry] = rx
	} else {
		ds.parent[ry] = rx
		ds.rank[rx]++
	}
}

func (ds *DisjointSets) Connected(x, y int) bool {
	return ds.Find(x) == ds.Find(y)
}

func kruskal(n int, edges []Edge) int {
	sort.Slice(edges, func(i, j int) bool {
		return edges[i].Weight < edges[j].Weight
	})

	ds := NewDisjointSets(n)
	mstWeight := 0
	edgeCount := 0

	for _, e := range edges {
		if !ds.Connected(e.U, e.V) {
			ds.Union(e.U, e.V)
			mstWeight += e.Weight
			edgeCount++
			fmt.Printf("  Edge (%d, %d) weight=%d -> MST total: %d\n",
				e.U, e.V, e.Weight, mstWeight)
			if edgeCount == n-1 {
				break
			}
		} else {
			fmt.Printf("  Edge (%d, %d) weight=%d -> skipped (cycle)\n",
				e.U, e.V, e.Weight)
		}
	}

	return mstWeight
}

func main() {
	edges := []Edge{
		{0, 1, 1}, {0, 2, 3}, {1, 2, 2}, {1, 3, 4}, {2, 3, 5},
	}
	n := 4

	fmt.Println("Kruskal's MST:")
	total := kruskal(n, edges)
	fmt.Printf("Total MST weight: %d\n", total)
}

Running the program prints:

Kruskal's MST:
  Edge (0, 1) weight=1 -> MST total: 1
  Edge (1, 2) weight=2 -> MST total: 3
  Edge (0, 2) weight=3 -> skipped (cycle)
  Edge (1, 3) weight=4 -> MST total: 7
Total MST weight: 7

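Properties of Disjoint Sets

Complexity

Implementation                     Find                  Union                 m operations
Naive                              O(n) worst case       O(n) worst case       O(m·n)
Union by rank                      O(log n)              O(log n)              O(m·log n)
Path compression                   O(log n) amortized    O(log n) amortized    O(m·log n)
Union by rank + path compression   O(α(n)) amortized     O(α(n)) amortized     O(m·α(n))

Here α(n) is the inverse Ackermann function, which grows extremely slowly: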

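Values of α(n) (the exact thresholds depend on which variant of the Ackermann function is used; these follow the definition in CLRS, Introduction to Algorithms):
  0 ≤ n ≤ 2           → α(n) = 0
  n = 3               → α(n) = 1
  4 ≤ n ≤ 7           → α(n) = 2
  8 ≤ n ≤ 2047        → α(n) = 3
  2048 ≤ n ≤ A_4(1)   → α(n) = 4  (A_4(1) is far larger than 10^80, so larger inputs cannot occur in practice)

In practice α(n) is therefore at most 4 and can be treated as a constant, so the total time for m operations is almost O(m).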

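Union by Rank vs. Union by Size

Strategy         Meaning                                               Tree height bound
Union by rank    rank is an upper-bound estimate of the tree height    O(log n)
Union by size    always attach the smaller set under the larger one    O(log n)

The two strategies are equally effective: each keeps the tree height within O(log n), and combined with path compression each gives the same O(α(n)) amortized bound. Union by size is more intuitive and keeps the exact size of every set available; union by rank stores only a small integer per node, which after path compression is an upper bound on the height rather than the exact height.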
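
Since the implementations in this article all use union by rank, here is a minimal sketch of the union-by-size variant for comparison. The class name `SizedDisjointSets`, the `set_size` method, and the boolean return value of `union` are assumptions made for this example, not part of the code above.

```python
class SizedDisjointSets:
    """Union-find with union by size and path compression (illustrative)."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n  # size[r] is meaningful only while r is a root

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x, y):
        """Merge the two sets; return False if they were already merged."""
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        # Make root_x the root of the larger set, then attach the smaller one.
        if self.size[root_x] < self.size[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        return True

    def set_size(self, x):
        """Number of elements in the set containing x."""
        return self.size[self.find(x)]


ds = SizedDisjointSets(5)
ds.union(0, 1)
ds.union(2, 3)
ds.union(0, 3)
print(ds.set_size(2))  # 4
print(ds.set_size(4))  # 1
```

One practical reason to pick this variant is that set sizes come for free, which is handy when a problem asks for the size of the largest component.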

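Key Properties

Property                               Description
Disjointness                           At any time, every element belongs to exactly one set
Unique representative                  Find returns the same result for all elements of a set
No splitting                           After a Union the two sets are merged permanently and cannot be split
Path compression preserves the sets    It changes only the tree structure, not the partition into sets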

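Applications

Application                        Description
Kruskal's minimum spanning tree    Check whether adding an edge would create a cycle
Dynamic connectivity               Check whether two vertices are connected as edges are added (edge deletion is not supported)
Equivalence-class partitioning     Group equivalent elements into the same set
Image processing                   Connected-component labeling
Network protocols                  Detect redundant connections in a network
Lowest common ancestor (LCA)       Core data structure of Tarjan's offline LCA algorithm
Parallel computing                 Dynamic processor allocation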
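
As one illustration of the redundant-connection use case, the sketch below reports every edge whose endpoints were already connected when the edge was processed. The function name `find_redundant_edges` is hypothetical, and its Find uses path halving (each visited node is pointed at its grandparent), a simpler one-pass relative of the path compression described earlier.

```python
def find_redundant_edges(n, edges):
    """Return edges whose endpoints were already connected when processed."""
    parent = list(range(n))

    def find(x):
        # Path halving: point each visited node at its grandparent.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    redundant = []
    for u, v in edges:
        root_u, root_v = find(u), find(v)
        if root_u == root_v:
            redundant.append((u, v))  # this edge would close a cycle
        else:
            parent[root_u] = root_v
    return redundant


print(find_redundant_edges(4, [(0, 1), (1, 2), (0, 2), (2, 3), (1, 3)]))
# [(0, 2), (1, 3)]
```

This is the same cycle test that Kruskal's algorithm applies to each edge, used here on an unweighted edge list.
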
posted @ 2026-04-18 00:33  游翔