KNN笔记

                                                                                        KNN笔记

先简单加载一下sklearn里的数据集,然后再来讲KNN。

1 import numpy as np
2 import matplotlib as mpl
3 import matplotlib.pyplot as plt
4 from sklearn import datasets
5 iris=datasets.load_iris()

 

看一下鸢尾花的keys:

iris.keys()

结果是:

dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])

 

看一下文档:

print(iris.DESCR) #看看文档

文档结果:

Iris Plants Database
====================

Notes
-----
Data Set Characteristics:
    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20  0.76     0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

This is a copy of UCI ML iris datasets.
http://archive.ics.uci.edu/ml/datasets/Iris

The famous Iris database, first used by Sir R.A Fisher

This is perhaps the best known database to be found in the
pattern recognition literature.  Fisher's paper is a classic in the field and
is referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant.  One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

References
----------
   - Fisher,R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...
文档

 

看一下数据data:

iris.data #看看数据

数据为:

  1 array([[ 5.1,  3.5,  1.4,  0.2],
  2        [ 4.9,  3. ,  1.4,  0.2],
  3        [ 4.7,  3.2,  1.3,  0.2],
  4        [ 4.6,  3.1,  1.5,  0.2],
  5        [ 5. ,  3.6,  1.4,  0.2],
  6        [ 5.4,  3.9,  1.7,  0.4],
  7        [ 4.6,  3.4,  1.4,  0.3],
  8        [ 5. ,  3.4,  1.5,  0.2],
  9        [ 4.4,  2.9,  1.4,  0.2],
 10        [ 4.9,  3.1,  1.5,  0.1],
 11        [ 5.4,  3.7,  1.5,  0.2],
 12        [ 4.8,  3.4,  1.6,  0.2],
 13        [ 4.8,  3. ,  1.4,  0.1],
 14        [ 4.3,  3. ,  1.1,  0.1],
 15        [ 5.8,  4. ,  1.2,  0.2],
 16        [ 5.7,  4.4,  1.5,  0.4],
 17        [ 5.4,  3.9,  1.3,  0.4],
 18        [ 5.1,  3.5,  1.4,  0.3],
 19        [ 5.7,  3.8,  1.7,  0.3],
 20        [ 5.1,  3.8,  1.5,  0.3],
 21        [ 5.4,  3.4,  1.7,  0.2],
 22        [ 5.1,  3.7,  1.5,  0.4],
 23        [ 4.6,  3.6,  1. ,  0.2],
 24        [ 5.1,  3.3,  1.7,  0.5],
 25        [ 4.8,  3.4,  1.9,  0.2],
 26        [ 5. ,  3. ,  1.6,  0.2],
 27        [ 5. ,  3.4,  1.6,  0.4],
 28        [ 5.2,  3.5,  1.5,  0.2],
 29        [ 5.2,  3.4,  1.4,  0.2],
 30        [ 4.7,  3.2,  1.6,  0.2],
 31        [ 4.8,  3.1,  1.6,  0.2],
 32        [ 5.4,  3.4,  1.5,  0.4],
 33        [ 5.2,  4.1,  1.5,  0.1],
 34        [ 5.5,  4.2,  1.4,  0.2],
 35        [ 4.9,  3.1,  1.5,  0.1],
 36        [ 5. ,  3.2,  1.2,  0.2],
 37        [ 5.5,  3.5,  1.3,  0.2],
 38        [ 4.9,  3.1,  1.5,  0.1],
 39        [ 4.4,  3. ,  1.3,  0.2],
 40        [ 5.1,  3.4,  1.5,  0.2],
 41        [ 5. ,  3.5,  1.3,  0.3],
 42        [ 4.5,  2.3,  1.3,  0.3],
 43        [ 4.4,  3.2,  1.3,  0.2],
 44        [ 5. ,  3.5,  1.6,  0.6],
 45        [ 5.1,  3.8,  1.9,  0.4],
 46        [ 4.8,  3. ,  1.4,  0.3],
 47        [ 5.1,  3.8,  1.6,  0.2],
 48        [ 4.6,  3.2,  1.4,  0.2],
 49        [ 5.3,  3.7,  1.5,  0.2],
 50        [ 5. ,  3.3,  1.4,  0.2],
 51        [ 7. ,  3.2,  4.7,  1.4],
 52        [ 6.4,  3.2,  4.5,  1.5],
 53        [ 6.9,  3.1,  4.9,  1.5],
 54        [ 5.5,  2.3,  4. ,  1.3],
 55        [ 6.5,  2.8,  4.6,  1.5],
 56        [ 5.7,  2.8,  4.5,  1.3],
 57        [ 6.3,  3.3,  4.7,  1.6],
 58        [ 4.9,  2.4,  3.3,  1. ],
 59        [ 6.6,  2.9,  4.6,  1.3],
 60        [ 5.2,  2.7,  3.9,  1.4],
 61        [ 5. ,  2. ,  3.5,  1. ],
 62        [ 5.9,  3. ,  4.2,  1.5],
 63        [ 6. ,  2.2,  4. ,  1. ],
 64        [ 6.1,  2.9,  4.7,  1.4],
 65        [ 5.6,  2.9,  3.6,  1.3],
 66        [ 6.7,  3.1,  4.4,  1.4],
 67        [ 5.6,  3. ,  4.5,  1.5],
 68        [ 5.8,  2.7,  4.1,  1. ],
 69        [ 6.2,  2.2,  4.5,  1.5],
 70        [ 5.6,  2.5,  3.9,  1.1],
 71        [ 5.9,  3.2,  4.8,  1.8],
 72        [ 6.1,  2.8,  4. ,  1.3],
 73        [ 6.3,  2.5,  4.9,  1.5],
 74        [ 6.1,  2.8,  4.7,  1.2],
 75        [ 6.4,  2.9,  4.3,  1.3],
 76        [ 6.6,  3. ,  4.4,  1.4],
 77        [ 6.8,  2.8,  4.8,  1.4],
 78        [ 6.7,  3. ,  5. ,  1.7],
 79        [ 6. ,  2.9,  4.5,  1.5],
 80        [ 5.7,  2.6,  3.5,  1. ],
 81        [ 5.5,  2.4,  3.8,  1.1],
 82        [ 5.5,  2.4,  3.7,  1. ],
 83        [ 5.8,  2.7,  3.9,  1.2],
 84        [ 6. ,  2.7,  5.1,  1.6],
 85        [ 5.4,  3. ,  4.5,  1.5],
 86        [ 6. ,  3.4,  4.5,  1.6],
 87        [ 6.7,  3.1,  4.7,  1.5],
 88        [ 6.3,  2.3,  4.4,  1.3],
 89        [ 5.6,  3. ,  4.1,  1.3],
 90        [ 5.5,  2.5,  4. ,  1.3],
 91        [ 5.5,  2.6,  4.4,  1.2],
 92        [ 6.1,  3. ,  4.6,  1.4],
 93        [ 5.8,  2.6,  4. ,  1.2],
 94        [ 5. ,  2.3,  3.3,  1. ],
 95        [ 5.6,  2.7,  4.2,  1.3],
 96        [ 5.7,  3. ,  4.2,  1.2],
 97        [ 5.7,  2.9,  4.2,  1.3],
 98        [ 6.2,  2.9,  4.3,  1.3],
 99        [ 5.1,  2.5,  3. ,  1.1],
100        [ 5.7,  2.8,  4.1,  1.3],
101        [ 6.3,  3.3,  6. ,  2.5],
102        [ 5.8,  2.7,  5.1,  1.9],
103        [ 7.1,  3. ,  5.9,  2.1],
104        [ 6.3,  2.9,  5.6,  1.8],
105        [ 6.5,  3. ,  5.8,  2.2],
106        [ 7.6,  3. ,  6.6,  2.1],
107        [ 4.9,  2.5,  4.5,  1.7],
108        [ 7.3,  2.9,  6.3,  1.8],
109        [ 6.7,  2.5,  5.8,  1.8],
110        [ 7.2,  3.6,  6.1,  2.5],
111        [ 6.5,  3.2,  5.1,  2. ],
112        [ 6.4,  2.7,  5.3,  1.9],
113        [ 6.8,  3. ,  5.5,  2.1],
114        [ 5.7,  2.5,  5. ,  2. ],
115        [ 5.8,  2.8,  5.1,  2.4],
116        [ 6.4,  3.2,  5.3,  2.3],
117        [ 6.5,  3. ,  5.5,  1.8],
118        [ 7.7,  3.8,  6.7,  2.2],
119        [ 7.7,  2.6,  6.9,  2.3],
120        [ 6. ,  2.2,  5. ,  1.5],
121        [ 6.9,  3.2,  5.7,  2.3],
122        [ 5.6,  2.8,  4.9,  2. ],
123        [ 7.7,  2.8,  6.7,  2. ],
124        [ 6.3,  2.7,  4.9,  1.8],
125        [ 6.7,  3.3,  5.7,  2.1],
126        [ 7.2,  3.2,  6. ,  1.8],
127        [ 6.2,  2.8,  4.8,  1.8],
128        [ 6.1,  3. ,  4.9,  1.8],
129        [ 6.4,  2.8,  5.6,  2.1],
130        [ 7.2,  3. ,  5.8,  1.6],
131        [ 7.4,  2.8,  6.1,  1.9],
132        [ 7.9,  3.8,  6.4,  2. ],
133        [ 6.4,  2.8,  5.6,  2.2],
134        [ 6.3,  2.8,  5.1,  1.5],
135        [ 6.1,  2.6,  5.6,  1.4],
136        [ 7.7,  3. ,  6.1,  2.3],
137        [ 6.3,  3.4,  5.6,  2.4],
138        [ 6.4,  3.1,  5.5,  1.8],
139        [ 6. ,  3. ,  4.8,  1.8],
140        [ 6.9,  3.1,  5.4,  2.1],
141        [ 6.7,  3.1,  5.6,  2.4],
142        [ 6.9,  3.1,  5.1,  2.3],
143        [ 5.8,  2.7,  5.1,  1.9],
144        [ 6.8,  3.2,  5.9,  2.3],
145        [ 6.7,  3.3,  5.7,  2.5],
146        [ 6.7,  3. ,  5.2,  2.3],
147        [ 6.3,  2.5,  5. ,  1.9],
148        [ 6.5,  3. ,  5.2,  2. ],
149        [ 6.2,  3.4,  5.4,  2.3],
150        [ 5.9,  3. ,  5.1,  1.8]])
数据data

可见data为150行,每行4列的数据。

 

看一下target:

iris.target #看看对应的目标值

target结果为:

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

 

看一下target_names:

iris.target_names #看看目标值对应的目标名称

arget_names结果为:

array(['setosa', 'versicolor', 'virginica'],
      dtype='<U10')

也就是target的0,1,2分别对应的鸢尾花的名称就是这三个。

 

看一下4列数据(也就是data)分别是指什么

iris.feature_names #看看四个数据对应的是什么

可以看到结果为:

['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

也就是4列数据分别代表花萼的长,花萼的宽,花瓣的长,花瓣的宽。

 

看一下花萼的数据,也就是前两列的数据:

1 #看一下花萼的散点图
2 X=iris.data[:,:2] 
3 plt.scatter(X[:,0],X[:,1])
4 plt.xlabel("sepal length")
5 plt.ylabel("sepal width")
6 plt.title("DU's plot about speal")
7 plt.show() 

  

 

把三种花的散点图区分一下:

1 #把三种花的花萼的散点图画出来
2 y=iris.target
3 plt.scatter(X[y==0,0],X[y==0,1],color='b')
4 plt.scatter(X[y==1,0],X[y==1,1],color='r')
5 plt.scatter(X[y==2,0],X[y==2,1],color='g')
6 plt.xlabel("sepal length")
7 plt.ylabel("sepal width")
8 plt.title("DU's plot about speal")
9 plt.show()

 

再看一下花瓣的散点图:

1 Petal=iris.data[:,2:]
2 y=iris.target
3 plt.scatter(Petal[y==0,0],Petal[y==0,1],color='b')
4 plt.scatter(Petal[y==1,0],Petal[y==1,1],color='r')
5 plt.scatter(Petal[y==2,0],Petal[y==2,1],color='g')
6 plt.xlabel("Petal length")
7 plt.ylabel("Petal width")
8 plt.title("DU's plot about Petal")
9 plt.show()

 

 

看到花瓣的散点图,那么就说一下KNN,那现在假设,花瓣散点图里来了一个长度为2cm,宽度主0.5cm的一个点,那么这个点代表的是哪个鸢尾呢?一般的人就能推出这个点应该是跟蓝色点是一类的,因为新进来的点是离蓝色的区域最近的,而离其他的红色或者绿色区域都很远。那么,这就是KNN的一个思想了。

 

比如现假设有如下场景,模拟有如下数据:

 1 raw_X=[[1,2],
 2        [2.8,2.5],
 3        [4,3.2],
 4        [2,1.5],
 5        [6,7.8],
 6        [8,5],
 7        [9,7],
 8        [7,8.5],
 9        [10,9.7],       
10       ]
11 raw_y=[0,0,0,0,1,1,1,1,1]
12 X_train=np.array(raw_X)
13 y_train=np.array(raw_y)

 

现在有一个数据x(设置为绿色的点)进来了,要判断这个数据是属于哪一类的:

1 x=np.array([7.5,6.5])
2 plt.scatter(X_train[y_train==0,0],X_train[y_train==0,1])
3 plt.scatter(X_train[y_train==1,0],X_train[y_train==1,1],color='r')
4 plt.scatter(x[0],x[1],color='g')
5 plt.show()

 

那么,按照KNN的思路就需求,求出这个里面,所有点离这个绿色点的距离了,看这个绿色的点离哪些是最近的。

那么,根据欧拉距离,一般程序员就可以写出这样的代码了:

1 from math import sqrt
2 distances=[]
3 for x_train in X_train:
4     d=sqrt(np.sum(x_train-x)**2)
5     distances.append(d)

 

当然,根据欧拉距离,不一般的程序员是会这么写:

distances=[sqrt(np.sum(x_train-x)**2) for x_train in X_train]

 

而结果distances都会是:

[11.0, 8.7, 6.8, 10.5, 0.20000000000000018, 1.0, 2.0, 1.5, 5.699999999999999]

 

接着,算出距离最近元素的索引,进而拿到距离最近的值:

1 nearest=np.argsort(distances)
2 topK_y=[y_train[i] for neighbor in nearest[:5]]
3 from collections import Counter
4 votes=Counter(topK_y)
5 predict_y=votes.most_common(1)[0][0]
6 predict_y

结果明显是1。

  

 

 

  

 

posted @ 2018-01-28 01:05  公子若不胖天下谁胖  阅读(258)  评论(0编辑  收藏  举报