OpenCV入门(二十九)快速学会OpenCV 28 K均值聚类

作者:Xiou

1.K均值聚类概述

当我们要预测的是一个离散值时,做的工作就是“分类”。例如,要预测一个孩子能否成为优秀的运动员,其实就是要将他分到“好苗子”(能成为优秀的运动员)或“普通孩子”(不能成为优秀运动员)的类别。当我们要预测的是一个连续值时,做的工作就是“回归”。例如,预测一个孩子将来成为运动员的指数,计算得到的是0.99或者0.36之类的数值。

机器学习模型还可以将训练集中的数据划分为若干个组,每个组被称为一个“簇(cluster)”。这些自动形成的簇,可能对应着不同的潜在概念,例如“篮球苗子”、“长跑苗子”。这种学习方式被称为“聚类(clusting)”,它的重要特点是在学习过程中不需要用标签对训练样本进行标注。也就是说,学习过程能够根据现有训练集自动完成分类(聚类)。

根据训练数据是否有标签,我们可以将学习划分为监督学习和无监督学习。前面介绍的K近邻、支持向量机都是监督学习,提供有标签的数据给算法学习,然后对数据分类。而聚类是无监督学习,事先并不知道分类标签是什么,直接对数据分类。

举一个简单的例子,有100粒豆子,如果已知其中40粒为绿豆,40粒为大豆,根据上述标签,将剩下的20粒豆子划分为绿豆和大豆则是监督学习。针对上述问题可以使用K近邻算法,计算当前待分类豆子的大小,并找出距离其最近的5粒豆子的大小,判断这5粒豆子中哪种豆子最多,将当前豆子判定为数量最多的那一类豆子类别。

同样,有100粒豆子,我们仅仅知道这些豆子里有两个不同的品种,但并不知道到底是什么品种。此时,可以根据豆子的大小、颜色属性,或者根据大小和颜色的组合属性,将其划分为两个类型。在此过程中,我们没有使用已知标签,也同样完成了分类,此时的分类是一种无监督学习。

聚类是一种无监督学习,它能够将具有相似属性的对象划分到同一个集合(簇)中。聚类方法能够应用于所有对象,簇内的对象越相似,聚类算法的效果越好。

1.1 K均值聚类的基本步骤

K均值聚类是一种将输入数据划分为k个簇的简单的聚类算法,该算法不断提取当前分类的中心点(也称为质心或重心),并最终在分类稳定时完成聚类。从本质上说,K均值聚类是一种迭代算法。

K均值聚类算法的基本步骤如下:
1.随机选取k个点作为分类的中心点。
2.将每个数据点放到距离它最近的中心点所在的类中。
3.重新计算各个分类的数据点的平均值,将该平均值作为新的分类中心点。
4.重复步骤2和步骤3,直到分类稳定。

在第1步中,可以是随机选取k个点作为分类的中心点,也可以是随机生成k个并不存在于原始数据中的数据点作为分类中心点。在第3步中,提到的“距离最近”,说明要进行某种形式的距离计算。在具体实现时,可以根据需要采用不同形式的距离度量方法。当然,不同的计算方法会对算法的性能产生影响。

1.2 K均值聚类模块

OpenCV提供了函数cv2.kmeans()来实现K均值聚类。该函数的语法格式为:

        retval, bestLabels, centers=cv2.kmeans(data, K, bestLabels, criteria, attempts,
    flags)

式中各个参数的含义为:
● data:输入的待处理数据集合,应该是np.float32类型,每个特征放在单独的一列中。
● K:要分出的簇的个数,即分类的数目,最常见的是K=2,表示二分类。
● bestLabels:表示计算之后各个数据点的最终分类标签(索引)。实际调用时,参数bestLabels的值设置为None。
● criteria:算法迭代的终止条件。当达到最大循环数目或者指定的精度阈值时,算法停止继续分类迭代计算。该参数由3个子参数构成,分别为type、max_iter和eps。
type表示终止的类型,可以是三种情况,分别为:

	● cv2.TERM_CRITERIA_EPS:精度满足eps时,停止迭代。
	● cv2.TERM_CRITERIA_MAX_ITER:迭代次数超过阈值max_iter时,停止迭代。
	● cv2.TERM_CRITERIA_EPS +cv2.TERM_CRITERIA_MAX_ITER:上述两个条件中的任意一个满足时,停止迭代。

● max_iter:最大迭代次数。
● eps:精确度的阈值。
● attempts:在具体实现时,为了获得最佳分类效果,可能需要使用不同的初始分类值进行多次尝试。指定attempts的值,可以让算法使用不同的初始值进行多次(attempts次)尝试。
● flags:表示选择初始中心点的方法,主要有以下3种。

	● cv2.KMEANS_RANDOM_CENTERS:随机选取中心点。
	● cv2.KMEANS_PP_CENTERS:基于中心化算法选取中心点。
	●cv2.KMEANS_USE_INITIAL_LABELS:使用用户输入的数据作为第一次分类中心点;如果算法需要尝试多次(attempts 值大于1时),后续尝试都是使用随机值或者半随机值作为第一次分类中心点。

返回值的含义为:
● retval:距离值(也称密度值或紧密度),返回每个点到相应中心点距离的平方和。
● bestLabels:各个数据点的最终分类标签(索引)。
● centers:每个分类的中心点数据。

2.操作实例

随机生成一组数据,使用函数cv2.kmeans()对其分类。

为了方便理解,假设有两种豆子,其中一种是“xiaoMI”,另外一种是“daMI”。它们的直径不一样,xiaoMI的直径在[0, 50]区间;daMI的直径在[200, 250]区间。用随机数模拟两种豆子的直径,并使用函数cv2.kmeans()对它们分类。

2.1 数据预处理

使用随机函数随机生成两组豆子的直径数据,并将它们转换为函数cv2.kmeans()可以处理的格式。

2.2 设置参数

设置函数cv2.kmeans()的参数形式。将参数criteria的值设置为“(cv2.TERM_CRITERIA_EPS+cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)”,在达到一定次数或者满足一定精度时终止迭代。

2.3 调用函数cv2.kmeans()

调用函数cv2.kmeans(),获取返回值,用于后续步骤的操作。

2.4 确定分类

根据函数cv2.kmeans()返回的标签(“0”和“1”),将原始数据分为两组。

2.5 显示结果

绘制经过分类的数据及中心点,观察分类结果。

代码(1)如下:

        import numpy as np
        import cv2
        from matplotlib import pyplot as plt
        # 随机生成两组数组
        # 生成60个值在[0,50]内的xiaoMI直径数据
        xiaoMI = np.random.randint(0,50,60)
        # 生成60个值在[200,250]内的daMI直径数据
        daMI = np.random.randint(200,250,60)
        # 将xiaoMI和daMI组合为MI
        MI = np.hstack((xiaoMI, daMI))
        # 使用reshape函数将其转换为(120,1)
        MI = MI.reshape((120,1))
        # 将MI转换为float32类型
        MI = np.float32(MI)
        # 调用kmeans模块
        # 设置参数criteria的值
        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
        # 设置参数flags的值
        flags = cv2.KMEANS_RANDOM_CENTERS
        # 调用函数kmeans
        retval, bestLabels, centers = cv2.kmeans(MI,2, None, criteria,10, flags)
        '''
        # 打印返回值
        print(retval)
        print(bestLabels)
        print(centers)
        '''
        # 获取分类结果
        XM = MI[bestLabels==0]
        DM = MI[bestLabels==1]
        # 绘制分类结果
        # 绘制原始数据
        plt.plot(XM, 'ro')
        plt.plot(DM, 'bo')
        # 绘制中心点
        plt.plot(centers[0], 'rx')
        plt.plot(centers[1], 'bx')
        plt.show()

输出结果:
在这里插入图片描述

在图中,上面的小方块是标签为“0”的数据点,下方的圆点是标签为“1”的数据点。上方的“x”标记是标签为“0”的数据组的中心点,其值大概在225左右;下方的“x”是标签为“1”的数据组的中心点,其值大概在25左右。在需要时,可以通过print语句打印centers[0]和centers[1]获取两个中心点的值。

代码(2)如下:

        import numpy as np
        import cv2
        from matplotlib import pyplot as plt
        # 随机生成两组数值
        # xiaomi组,长和宽都在[0,20]内
        xiaomi = np.random.randint(0,20, (30,2))
        #dami组,长和宽的大小都在[40,60]
        dami = np.random.randint(40,60, (30,2))
        # 组合数据
        MI = np.vstack((xiaomi, dami))
        # 转换为float32类型
        MI = np.float32(MI)
        # 调用kmeans模块
        # 设置参数criteria值
        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
        # 调用kmeans函数
        ret, label, center=cv2.kmeans(MI,2, None, criteria,10, cv2.KMEANS_RANDOM_CENTERS)
        '''
        #打印返回值
        print(ret)
        print(label)
        print(center)
        '''
        # 根据kmeans的处理结果,将数据分类,分为XM和DM两大类
        XM = MI[label.ravel()==0]
        DM = MI[label.ravel()==1]
        # 绘制分类结果数据及中心点
        plt.scatter(XM[:,0], XM[:,1], c = 'g', marker = 's')
        plt.scatter(DM[:,0], DM[:,1], c = 'r', marker = 'o')
        plt.scatter(center[0,0], center[0,1], s = 200, c = 'b', marker = 'o')
        plt.scatter(center[1,0], center[1,1], s = 200, c = 'b', marker = 's')
        plt.xlabel('Height'), plt.ylabel('Width')
        plt.show()

输出结果:
在这里插入图片描述

在图中,右上方的小方块是标签为“0”的数据点,左下方的圆点是标签为“1”的数据点。右上方稍大的圆点是标签“0”的数据组的中心点;左下方稍大的方块是标签为“1”的数据组的中心点。

在程序中,“#打印返回值”下面3行的打印语句被注释掉了。在需要时,可以去掉注释,通过print语句打印对应的值。例如,语句“print(center)”可以打印center所表示的两个中心点的值。

代码(3)实例:使用函数cv2.kmeans()将灰度图像处理为只有两个灰度级的二值图像。

        import numpy as np
        import cv2
        import matplotlib.pyplot as plt
        # 读取待处理图像
        img = cv2.imread('lena.bmp')
        # 使用reshape将一个像素点的RGB值作为一个单元处理
        data = img.reshape((-1,3))
        # 转换为kmeans可以处理的类型
        data = np.float32(data)
        # 调用kmeans模块
        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
        K =2
        ret, label, center=cv2.kmeans(data, K, None, criteria,10, cv2.KMEANS_RANDOM_CENTER
    S)
        # 转换为uint8数据类型,将每个像素点都赋值为当前分类的中心点像素值
        # 将center的值转换为uint8
        center = np.uint8(center)
        # 使用center内的值替换原像素点的值
        res1 = center[label.flatten()]
        # 使用reshape调整替换后的图像
        res2 = res1.reshape((img.shape))
        # 显示处理结果
        plt.subplot(121)
        plt.imshow(img)
        plt.axis('off')
        plt.subplot(122)
        plt.imshow(res2)
        plt.axis('off')

输出结果:
在这里插入图片描述

其中,左图是原始图像,右图是二值化图像。调整程序中的K值,就能改变图像的显示结果。例如,K=8,则可以让图像显示8个灰度级。

posted @ 2023-04-05 13:57  小幽余生不加糖  阅读(99)  评论(0)    收藏  举报  来源