从0开始的机器学习——knn算法篇(4)
本次实验采用另一个数据集——手写字母数据集
首先引入必要的库:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn import datasets
digits = datasets.load_digits()
digits.keys()
print(digits.DESCR) //看一下这个数据集的描述
.. _digits_dataset:
Optical recognition of handwritten digits dataset
--------------------------------------------------
**Data Set Characteristics:**
:Number of Instances: 5620
:Number of Attributes: 64
:Attribute Information: 8x8 image of integer pixels in the range 0..16.
:Missing Attribute Values: None
:Creator: E. Alpaydin (alpaydin '@' boun.edu.tr)
:Date: July; 1998
This is a copy of the test set of the UCI ML hand-written digits datasets
http://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits
The data set contains images of hand-written digits: 10 classes where
each class refers to a digit.
Preprocessing programs made available by NIST were used to extract
normalized bitmaps of handwritten digits from a preprinted form. From a
total of 43 people, 30 contributed to the training set and different 13
to the test set. 32x32 bitmaps are divided into nonoverlapping blocks of
4x4 and the number of on pixels are counted in each block. This generates
an input matrix of 8x8 where each element is an integer in the range
0..16. This reduces dimensionality and gives invariance to small
distortions.
For info on NIST preprocessing routines, see M. D. Garris, J. L. Blue, G.
T. Candela, D. L. Dimmick, J. Geist, P. J. Grother, S. A. Janet, and C.
L. Wilson, NIST Form-Based Handprint Recognition System, NISTIR 5469,
1994.
.. topic:: References
- C. Kaynak (1995) Methods of Combining Multiple Classifiers and Their
Applications to Handwritten Digit Recognition, MSc Thesis, Institute of
Graduate Studies in Science and Engineering, Bogazici University.
- E. Alpaydin, C. Kaynak (1998) Cascading Classifiers, Kybernetika.
- Ken Tang and Ponnuthurai N. Suganthan and Xi Yao and A. Kai Qin.
Linear dimensionalityreduction using relevance weighted LDA. School of
Electrical and Electronic Engineering Nanyang Technological University.
2005.
- Claudio Gentile. A New Approximate Maximal Margin Classification
Algorithm. NIPS. 2000.
X = digits.data
X.shape //这个数据集是简化的数据集,所以并没有5620个数据,有1797个数据 每个数据有64个属性,是一个8x8的矩阵

查看一下前100个数据的属性:

可以发现这个数据集和鸢尾花的数据集分布不一样,这个是没有规律的。
随意选一个数据看一下:

基本看出来是一个数字 8
接下来调用封装好的knn算法来测试一下:
from sklearn.model_selection import train_test_split //引入分割数据集的方法
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2) //分割数据集
from sklearn.neighbors import KNeighborsClassifier //引入KNN算法
my_knn_clf = KNeighborsClassifier(n_neighbors=3) //k值为3
my_knn_clf.fit(X_train,y_train) //传入训练样本集
y_predict = my_knn_clf.predict(X_test)//获得预测样本数据
y_predict
sum(y_predict == y_test) / len(y_test) # y_predict向量与y_test向量进行比较,如果对应的数值相等,就返回true值,用sum()统计true值的个数,然后比上所有的测试数值个数,就可以获得预测的精确度
如果不想写这个逻辑,可以直接调用sklearn库中的方法:
from sklearn.metrics import accuracy_score
accuracy_score(y_test,y_predict)
my_knn_clf.fit(X_test,y_test)
my_knn_clf.score(X_test,y_test)

浙公网安备 33010602011771号