机器学习之K近邻算法实现
import operator
from numpy import array, tile
def create_dataset():
_dataset = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
_labels = ['A', 'A', 'B', 'B']
return _dataset, _labels
def classify(x: list, dataset: array, labels, k):
"""
训练步骤:
(1)计算输入点与样本数据点的距离
(2)按照距离排序
(3)获取距离最小的前k个点
(4)确定前k个点在类别中出现的频率
(5)返回前k个点出现频率最高的类别作为输入点的预测分类
:param x: 用于训练的的数据,两个元素的列表
:param dataset: 样本数据集
:param labels: 标签向量
:param k: 最近邻居数
:return:
"""
# 获取数据形状
dataset_size = dataset.shape[0]
# 求差
diff_mat = tile(x, (dataset_size, 1)) - dataset
# 求平方差
sq_diff_mat = diff_mat ** 2
# 求平方差的和
sq_distance = sq_diff_mat.sum(axis=1)
# 求距离
distances = sq_distance ** 0.5
# 排序
sorted_distances = distances.argsort()
# 分类统计,用于计算前k个最近距离
class_count = {}
for i in range(k):
# 获取标签
vote_label = labels[sorted_distances[i]]
# 累加统计标签个数
class_count[vote_label] = class_count.get(vote_label, 0) + 1
# 对标签出现的次数排序
sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
return sorted_class_count[0][0]
if __name__ == '__main__':
dataset, labels = create_dataset()
print(classify([0, 0], dataset, labels, 3))
其他knn示例或者基于主流机器学习框架实现的knn代码地址:
https://gitee.com/navysummer/machine-learning/tree/master/knn

浙公网安备 33010602011771号