Mean shift 向量

中文名叫 均值偏移向量,定义如下:

在一个 n 维空间内,存在一个点 x,以该点为球心,以 h 为半径,生成一个球 Sh,计算球心到球内所有点生成的向量的均值,显然这个均值也是一个向量,这个向量就是 mean shift 向量

公式如下

          【用 球心 计算 质心】

图示如下 

 

 

shift point 

我给起个名字,叫 偏移点;

注意,几乎没有资料专门提到这个概念,我为什么要讲呢?因为我们需要把 shift point 和 mean shift 区分开,这俩可不是一回事;

而在我们写算法时,需要的是 shift point,而不是 mean shift

    def _shift_point(self, point, points, kernel_bandwidth):
        shift_x = 0.0
        shift_y = 0.0
        scale = 0.0
        for p in points:
            dist = distance(point, p)
            weight = self.kernel(dist, kernel_bandwidth)
            if dist == 0: print(weight)
            ### shift point
            shift_x += p[0] * weight
            shift_y += p[1] * weight
            ### 而不是 mean shift
            # shift_x += (p[0] - point[0]) * weight
            # shift_y += (p[1] - point[1]) * weight
            scale += weight
        shift_x = shift_x / scale
        shift_y = shift_y / scale
        return [shift_x, shift_y]

 

那 shift point 到底是什么,又该如何计算呢,直接上图

 

再来一张

 

 

Mean shift 算法 

基本思想

对于 样本中的每一个点 x,做如下操作

1. 以 x 为起点,计算他的 shift point  x‘,然后把 该点 “移动” 到 x’      【注意不是真的移动点,而是把 x 标记成 x’

2. 以 x’ 为新起点,计算他的 shift point

3. 重复 前两步,直至 前后两次 的 mean shift 向量满足条件,如 距离很近  【这一步才用到 mean shift,也就是 前后两个 shift point 相减得到 向量,再计算向量的模】

4. 把 x 标记为 最终的 shift point,即为对应的类

5. 遍历计算所有点

 

过程大致如下图

从上图可以看到,mean shift 向量指向了更密集的区域,也就是说 mean shift 算法是在寻找 最密集 的区域,作为最后的类别

 

存在问题

在计算 mean shift 向量时,圆圈内所有点的贡献是一样的 即1/k,而实际上离圆心越远可能贡献越小,

为此 mean shift 算法引入核函数来表达这种贡献,代替 1/k

 

引入核函数 

核函数 参考 我的博客

 

此处以 高斯核函数 为例

其中 h 代表核函数的带宽(bandwidth)      【这个 h 和 高斯分布 里的 标准差σ 类似,但它不是 标准差,而是 人工指定的,但是起到的作用和 标准差一样】

 

不同带宽的核函数表示如下

 

在 h 一定时, x 离 均值(圆心)越远,函数值越小,体现到 mean shift 向量中,就是 贡献越小;    【高斯滤波不也是这样吗,那高斯滤波也可以引入核函数了】

h 越小,衰减为 0 的速度就越快,也就是说 mean shift 向量对应的球 S越小,稍微远点就没有贡献了

 

于是,引入 核函数 的 mean shift 向量变成如下样子,此时的 S可以为整个数据集(原因为上句)

 

Mean shift VS KMeans

1. KMeans 需要设置 k,mean shift 无需

2. 实际工作中复杂数据用 mean shift 无法控制 k 个值,可能会产生过多的类而导致聚类失去意义

 

示例代码

import random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs, make_moons


# 定义 预先设定 的阈值
STOP_THRESHOLD = 1e-4
CLUSTER_THRESHOLD = 1e-1

# 定义度量函数
def distance(a, b):
    return np.linalg.norm(np.array(a) - np.array(b))

# 定义高斯核函数
def gaussian_kernel(distance, bandwidth):
    return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((distance / bandwidth)) ** 2)


# mean_shift类
class mean_shift(object):
    def __init__(self, kernel=gaussian_kernel):
        self.kernel = kernel

    def fit(self, points, kernel_bandwidth):

        shift_points = np.array(points)     # 初始化偏移点
        shifting = [True] * points.shape[0] # 所有点都需要偏移

        ######## while 循环用来执行针对所有点的 多轮偏移
        while True:
            max_dist = 0
            ######## for 循环用来执行针对所有点的 一轮偏移
            for i in range(0, len(shift_points)):
                if not shifting[i]:     # 是否需要偏移
                    continue
                ##### 如果需要偏移,先移动一次,   while 循环保证一直会移动
                p_shift_init = shift_points[i].copy()   # 获取一个点
                ### 下面这句已经把原来的点 偏移到 偏移点了
                shift_points[i] = self._shift_point(shift_points[i], points, kernel_bandwidth)
                dist = distance(shift_points[i], p_shift_init)  # 偏移点 和 原来的点 的距离

                max_dist = max(max_dist, dist)      # 取最大距离
                # 距离大于停止条件,继续移动, 距离 小于 停止条件,结束移动
                shifting[i] = dist > STOP_THRESHOLD 
                ### 至此 一轮偏移 结束
            
            ### 每轮 偏移后 取 所有偏移向量得 最大值,
            ### 如果 小于 停止条件,说明所有点都 偏移到 最后的 点了,多伦偏移可以结束了
            if (max_dist < STOP_THRESHOLD):     
                break
        
        ## shift_points 就是 每个点对应的 最终 shift point
        cluster_ids = self._cluster_points(shift_points.tolist())
        return shift_points, cluster_ids

    def _shift_point(self, point, points, kernel_bandwidth):
        # point 球心, points 球内点,计算 均值偏移向量
        shift_x = 0.0
        shift_y = 0.0
        scale = 0.0
        for p in points:
            print(point)
            dist = distance(point, p)
            weight = self.kernel(dist, kernel_bandwidth)
            # shift point 
            shift_x += p[0] * weight
            shift_y += p[1] * weight
            scale += weight
        shift_x = shift_x / scale
        shift_y = shift_y / scale
        return [shift_x, shift_y]

    def _cluster_points(self, points):
        cluster_ids = []
        cluster_idx = 0
        cluster_centers = []

        for i, point in enumerate(points):
            if (len(cluster_ids) == 0):
                cluster_ids.append(cluster_idx)
                cluster_centers.append(point)
                cluster_idx += 1
            else:
                for center in cluster_centers:
                    dist = distance(point, center)
                    if (dist < CLUSTER_THRESHOLD):
                        cluster_ids.append(cluster_centers.index(center))
                if (len(cluster_ids) < i + 1):
                    cluster_ids.append(cluster_idx)
                    cluster_centers.append(point)
                    cluster_idx += 1
        return cluster_ids

def colors(n):
    ret = []
    for i in range(n):
        ret.append((random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)))
    return ret

def main():
    centers = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
    X, _ = make_blobs(n_samples=300, centers=centers, cluster_std=0.4)  # h=0.5
    # X, y = make_moons(n_samples=200, noise=0.05, random_state=0)        # h

    mean_shifter = mean_shift()
    _, mean_shift_result = mean_shifter.fit(X, kernel_bandwidth=0.2)

    np.set_printoptions(precision=3)
    print('input: {}'.format(X))
    print('assined clusters: {}'.format(mean_shift_result))
    color = colors(np.unique(mean_shift_result).size)

    for i in range(len(mean_shift_result)):
        plt.scatter(X[i, 0], X[i, 1], color=color[mean_shift_result[i]])
    plt.show()


if __name__ == '__main__':
    main()

输出

 

 

 

 

 

 

参考资料:

https://zhuanlan.zhihu.com/p/81629406  机器学习-Mean Shift聚类算法

https://www.biaodianfu.com/mean-shift.html  机器学习聚类算法之Mean Shift

https://www.cnblogs.com/liqizhou/archive/2012/05/12/2497220.html  Meanshift,聚类算法

https://blog.csdn.net/u014661698/article/details/84979979  聚类算法之meanshift

https://blog.csdn.net/moge19/article/details/85346528 

https://www.jb51.net/article/188375.htm  python实现mean-shift聚类算法