KNN原理以及代码实现——以iris为例
相关的实验代码在我的github上👉QYHcrossover/ML-numpy: 机器学习算法numpy实现 (github.com) 欢迎star⭐
当今的机器学习模型需要处理各种类型的数据,如文本、图像、数值等。KNN(K-Nearest Neighbor)是一种经典的分类算法,常用于处理数值型数据。本文将介绍KNN的原理和代码实现,并使用Iris数据集演示其使用。
KNN原理
KNN是一种基于实例的学习算法,其原理是根据数据集中所有数据之间的距离,找到距离最近的k个数据点,并将它们的标签赋给待分类的样本。一般情况下,k的取值为奇数,这样可以避免出现平局的情况。
KNN算法的实现步骤如下:
- 计算待分类数据点和所有训练数据点之间的距离(通常使用欧氏距离);
- 选出距离最近的k个数据点;
- 统计k个数据点的标签,并选出出现次数最多的标签作为待分类数据点的预测标签。
KNN代码实现
接下来,我们将通过Python代码实现KNN算法。首先,我们需要导入numpy、collections、sklearn等必要的库。
import numpy as np
from collections import Counter
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
然后,我们创建一个名为KNNclassifier的类,该类包含以下三个方法:
__init__(self,k):构造函数,需要传入一个参数k,表示选取k个最近邻居。它也对k的值进行了检验,确保k的值大于等于1。fit(self,X_train,y_train):X_train是训练数据集的特征值,y_train是训练数据集的标签;注意这里仅仅是将训练数据集的特征数据和标签数据存储在self.X_train和self.y_train中。_predict(self,X):使用模型对新数据进行预测,其中X是待分类数据点的特征值。首先计算待分类数据点和所有训练数据点之间的距离;其次选出距离最近的k个数据点;最后统计k个数据点的标签,并选出出现次数最多的标签作为待分类数据点的预测标签。
class KNNclassifier:
def __init__(self,k):
assert k>=1,"k必须是大于1的整数"
self.k=k
self.X_train=None
self.y_train=None
def fit(self,X_train,y_train):
self.X_train=X_train
self.y_train=y_train
return self
def _predict(self,X):
distance = [np.sum((X_i-X)**2) for X_i in self.X_train]
arg = np.argsort(distance)
top_k = [self.y_train[i] for i in arg[:self.k]]
counter={}
for i in top_k:
counter.setdefault(i,0)
counter[i]+=1
sort=sorted(counter.items(),key=lambda x:x[1])
return sort[0][0]
最后,我们定义两个方法:predict和score,用于使用训练好的模型进行预测和评估预测结果的准确性。
def predict(self, X_test):
y_predict = [self._predict(i) for i in X_test]
return np.array(y_predict)
def score(self, X_test ,y_test):
y_predict = self.predict(X_test)
return np.sum(y_predict==y_test)/len(X_test)
在主函数部分,首先加载iris数据集,并将其分为训练集和测试集。然后,构造了一个KNN分类器,并用训练数据集拟合模型。接着,用测试数据集对模型进行预测并将预测结果和真实结果对比,得到准确率。
总结
KNN算法是一个简单且有效的分类算法。它的原理是基于距离度量的分类方法,通过计算新数据点与已知数据点之间的距离来进行分类。在实现KNN算法时,需要确定K值和距离度量方法。K值通常是通过交叉验证来确定的。距离度量方法通常使用欧氏距离或曼哈顿距离等。
在本文中,我们使用Python实现了KNN算法,并使用iris数据集对其进行了测试。我们还介绍了KNN算法的原理和代码解释,并提供了KNN算法的代码实现。在下一篇文章中,我们将使用KDTree的数据结构和算法来优化查找最近邻的过程。
相关的实验代码在我的github上,👉QYHcrossover/ML-numpy: 机器学习算法numpy实现 (github.com) 欢迎star⭐