make_blobs 数据生成与可视化代码详解

make_blobs 数据生成与可视化代码详解


完整代码

# 导入绘图工具
import matplotlib.pyplot as plt

# 导入numpy
import numpy as np

# 导入make_blobs函数(这是缺失的部分)
from sklearn.datasets import make_blobs

# 生成样本数为200、分类数为2的数据集
data = make_blobs(n_samples=200, centers=2, random_state=8)

# 大写的X表示数据的特征,小写的y表示数据对应的标签
X, y = data

# 将生成的数据集进行可视化
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.autumn, edgecolor='k')

# 添加图表标题
plt.title('Generated Data with make_blobs')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.colorbar(label='Class')

# 显示图表
plt.show()

代码逐行解析

1. 导入必要的库

import matplotlib.pyplot as plt  # 数据可视化库
import numpy as np              # 数值计算库
from sklearn.datasets import make_blobs  # 生成聚类数据集

说明

  • matplotlib.pyplot:Python最常用的绘图库
  • numpy:科学计算基础库
  • make_blobs:scikit-learn中用于生成聚类数据集的函数

2. 生成数据集

data = make_blobs(n_samples=200, centers=2, random_state=8)

参数解释

参数 说明
n_samples 200 生成的总样本数
centers 2 聚类中心数量(即类别数)
random_state 8 随机种子,保证每次生成相同的数据

返回值

  • data 是一个元组:(X, y)
  • X:特征矩阵,形状为 (200, 2)
  • y:标签数组,形状为 (200,)

3. 解包数据

X, y = data

变量含义

  • X:二维数组,每一行是一个样本,每一列是一个特征
  • y:一维数组,表示每个样本所属的类别(0或1)

数据结构示例

X.shape = (200, 2)  # 200个样本,每个样本2个特征
y.shape = (200,)    # 200个标签

X[0] = [2.5, 3.1]   # 第0个样本的特征值
y[0] = 0            # 第0个样本属于类别0

4. 可视化数据

plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.autumn, edgecolor='k')

参数详解

参数 作用
X[:, 0] 第一列特征 作为x坐标
X[:, 1] 第二列特征 作为y坐标
c y 根据标签设置点的颜色
cmap plt.cm.autumn 使用autumn颜色映射(红黄色系)
edgecolor 'k' 点的边缘颜色为黑色

颜色映射说明

  • autumn colormap:从黄色到红色的渐变
  • 类别0显示为黄色
  • 类别1显示为红色

输出效果

执行代码后,你将看到:

  1. 散点图:200个数据点分布在二维平面上
  2. 颜色区分
    • 黄色点:类别0
    • 红色点:类别1
  3. 聚集特征:同类别数据点会聚集在一起(形成2个簇)

预期效果

    ↑
 5  |    ●●●●●●●●●●●●
    |   ●●●●●●●●●●●●●●●
 4  |  ●●●●●●●●●●●●●●●●
    |  ●●●●●●●●●●●●●●●●●●
 3  | ●●●●●●●●●●●●●●●●●●●●
    | ●●●●●●●●●●●●●●●●●●●●
 2  |  ●●●●●●●●●●●●●●●●●●●
    |   ●●●●●●●●●●●●●●●●
 1  |    ●●●●●●●●●●●●●●●
    |     ●●●●●●●●●●●●●●
 0  |      ●●●●●●●●●●●●
    +--------------------------→
      0   2   4   6   8   10
    (黄色簇和红色簇明显分开)

扩展示例

示例1:生成3个聚类中心

# 生成3个类别
data = make_blobs(n_samples=300, centers=3, random_state=42)
X, y = data

# 可视化
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis', edgecolor='k')
plt.title('Three Clusters')
plt.colorbar(label='Class')
plt.show()

示例2:调整聚类分布

# 生成更分散或更紧凑的聚类
data = make_blobs(n_samples=200,
                   centers=2,
                   cluster_std=2.0,  # 聚类的标准差,越大越分散
                   random_state=8)
X, y = data

plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.autumn, edgecolor='k')
plt.title('Clusters with Different Spread')
plt.show()

示例3:生成高维数据

# 生成3维特征的数据
data = make_blobs(n_samples=200,
                   centers=3,
                   n_features=3,  # 3个特征
                   random_state=42)
X, y = data

print(f"数据形状: {X.shape}")
print(f"标签形状: {y.shape}")
print(f"类别数量: {len(np.unique(y))}")

常用参数汇总

make_blobs 完整参数

make_blobs(
    n_samples=100,         # 总样本数
    n_features=2,          # 每个样本的特征数
    centers=None,          # 聚类中心数量或中心坐标
    cluster_std=1.0,       # 每个聚类的标准差
    center_box=(-10.0, 10.0),  # 聚类中心的边界
    shuffle=True,          # 是否打乱样本顺序
    random_state=None,     # 随机种子
    return_centers=False   # 是否返回聚类中心
)

scatter 常用参数

plt.scatter(
    x,                     # x坐标
    y,                     # y坐标
    s=None,                # 点的大小
    c=None,                # 点的颜色或标签
    marker='o',            # 点的形状
    cmap=None,             # 颜色映射
    alpha=None,            # 透明度
    edgecolors=None,       # 边缘颜色
    linewidths=None        # 边缘线宽
)

实用技巧

1. 查看生成的数据

data = make_blobs(n_samples=200, centers=2, random_state=8)
X, y = data

print(f"特征矩阵 X:\n{X[:5]}")  # 打印前5个样本
print(f"\n标签 y:\n{y[:5]}")     # 打印前5个标签
print(f"\n类别分布: {np.bincount(y)}")  # 统计各类别数量

2. 添加图例

plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.autumn, edgecolor='k')
plt.title('Cluster Visualization')

# 添加图例
from matplotlib.colors import ListedColormap
colors = ['red', 'yellow']
cmap = ListedColormap(colors)
for i, color in enumerate(colors):
    plt.scatter([], [], c=color, label=f'Class {i}')

plt.legend()
plt.show()

3. 保存图片

plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.autumn, edgecolor='k')
plt.savefig('clusters.png', dpi=300, bbox_inches='tight')
plt.show()

应用场景

make_blobs 常用于:

  1. 算法测试:测试聚类算法(如K-Means)
  2. 分类器评估:测试分类算法性能
  3. 数据可视化教学:演示数据分布
  4. 原型开发:快速生成测试数据
  5. 机器学习实验:对比不同算法效果

完整可运行示例

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_blobs

# 设置中文字体(可选)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 生成数据
print("正在生成数据集...")
data = make_blobs(n_samples=200, centers=2, random_state=8)
X, y = data

print(f"数据集信息:")
print(f"- 样本数:{X.shape[0]}")
print(f"- 特征数:{X.shape[1]}")
print(f"- 类别数:{len(np.unique(y))}")
print(f"- 类别分布:{np.bincount(y)}")

# 创建图形
plt.figure(figsize=(10, 6))

# 绘制散点图
scatter = plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.autumn,
                      edgecolor='k', s=60, alpha=0.8)

# 添加图表元素
plt.title('make_blobs生成的聚类数据', fontsize=14, pad=20)
plt.xlabel('特征 1 (Feature 1)', fontsize=12)
plt.ylabel('特征 2 (Feature 2)', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.3)

# 添加颜色条
cbar = plt.colorbar(scatter)
cbar.set_label('类别标签', fontsize=12)

# 添加文本说明
plt.text(0.02, 0.98,
         f'总样本数: {len(y)}\n类别0: {np.sum(y==0)}\n类别1: {np.sum(y==1)}',
         transform=plt.gca().transAxes,
         verticalalignment='top',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# 显示图表
plt.tight_layout()
plt.savefig('make_blobs_example.png', dpi=300, bbox_inches='tight')
print("\n图表已保存为 'make_blobs_example.png'")
plt.show()

总结

这段代码的核心流程:

  1. 导入库 → 导入必要的工具包
  2. 生成数据 → 使用make_blobs创建聚类数据
  3. 解包数据 → 将特征和标签分离
  4. 可视化 → 用散点图展示数据分布

关键点

  • make_blobs是生成合成数据的强大工具
  • 通过random_state保证结果可重现
  • c参数根据标签自动着色
  • cmap控制颜色方案

尝试

  • 实验不同参数观察数据分布变化
  • 尝试不同的colormap(如'viridis', 'plasma', 'coolwarm')
  • 结合聚类算法(如K-Means)进行实践

posted @ 2026-02-27 16:52  kkman2000  阅读(6)  评论(0)    收藏  举报