In [8]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.preprocessing import StandardScaler
from sklearn.naive_bayes import GaussianNB, MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsClassifier
def iris_type(s):
    it = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}
    return it[s]
In [15]:
data = np.loadtxt('D:\\mlInAction\\8.iris.data', encoding='utf-8', dtype=float, delimiter=',',
                  converters={4: iris_type})
data
Out[15]:
array([[5.1, 3.5, 1.4, 0.2, 0. ],
[4.9, 3. , 1.4, 0.2, 0. ],
[4.7, 3.2, 1.3, 0.2, 0. ],
[4.6, 3.1, 1.5, 0.2, 0. ],
[5. , 3.6, 1.4, 0.2, 0. ],
[5.4, 3.9, 1.7, 0.4, 0. ],
[4.6, 3.4, 1.4, 0.3, 0. ],
[5. , 3.4, 1.5, 0.2, 0. ],
[4.4, 2.9, 1.4, 0.2, 0. ],
[4.9, 3.1, 1.5, 0.1, 0. ],
[5.4, 3.7, 1.5, 0.2, 0. ],
[4.8, 3.4, 1.6, 0.2, 0. ],
[4.8, 3. , 1.4, 0.1, 0. ],
[4.3, 3. , 1.1, 0.1, 0. ],
[5.8, 4. , 1.2, 0.2, 0. ],
[5.7, 4.4, 1.5, 0.4, 0. ],
[5.4, 3.9, 1.3, 0.4, 0. ],
[5.1, 3.5, 1.4, 0.3, 0. ],
[5.7, 3.8, 1.7, 0.3, 0. ],
[5.1, 3.8, 1.5, 0.3, 0. ],
[5.4, 3.4, 1.7, 0.2, 0. ],
[5.1, 3.7, 1.5, 0.4, 0. ],
[4.6, 3.6, 1. , 0.2, 0. ],
[5.1, 3.3, 1.7, 0.5, 0. ],
[4.8, 3.4, 1.9, 0.2, 0. ],
[5. , 3. , 1.6, 0.2, 0. ],
[5. , 3.4, 1.6, 0.4, 0. ],
[5.2, 3.5, 1.5, 0.2, 0. ],
[5.2, 3.4, 1.4, 0.2, 0. ],
[4.7, 3.2, 1.6, 0.2, 0. ],
[4.8, 3.1, 1.6, 0.2, 0. ],
[5.4, 3.4, 1.5, 0.4, 0. ],
[5.2, 4.1, 1.5, 0.1, 0. ],
[5.5, 4.2, 1.4, 0.2, 0. ],
[4.9, 3.1, 1.5, 0.1, 0. ],
[5. , 3.2, 1.2, 0.2, 0. ],
[5.5, 3.5, 1.3, 0.2, 0. ],
[4.9, 3.1, 1.5, 0.1, 0. ],
[4.4, 3. , 1.3, 0.2, 0. ],
[5.1, 3.4, 1.5, 0.2, 0. ],
[5. , 3.5, 1.3, 0.3, 0. ],
[4.5, 2.3, 1.3, 0.3, 0. ],
[4.4, 3.2, 1.3, 0.2, 0. ],
[5. , 3.5, 1.6, 0.6, 0. ],
[5.1, 3.8, 1.9, 0.4, 0. ],
[4.8, 3. , 1.4, 0.3, 0. ],
[5.1, 3.8, 1.6, 0.2, 0. ],
[4.6, 3.2, 1.4, 0.2, 0. ],
[5.3, 3.7, 1.5, 0.2, 0. ],
[5. , 3.3, 1.4, 0.2, 0. ],
[7. , 3.2, 4.7, 1.4, 1. ],
[6.4, 3.2, 4.5, 1.5, 1. ],
[6.9, 3.1, 4.9, 1.5, 1. ],
[5.5, 2.3, 4. , 1.3, 1. ],
[6.5, 2.8, 4.6, 1.5, 1. ],
[5.7, 2.8, 4.5, 1.3, 1. ],
[6.3, 3.3, 4.7, 1.6, 1. ],
[4.9, 2.4, 3.3, 1. , 1. ],
[6.6, 2.9, 4.6, 1.3, 1. ],
[5.2, 2.7, 3.9, 1.4, 1. ],
[5. , 2. , 3.5, 1. , 1. ],
[5.9, 3. , 4.2, 1.5, 1. ],
[6. , 2.2, 4. , 1. , 1. ],
[6.1, 2.9, 4.7, 1.4, 1. ],
[5.6, 2.9, 3.6, 1.3, 1. ],
[6.7, 3.1, 4.4, 1.4, 1. ],
[5.6, 3. , 4.5, 1.5, 1. ],
[5.8, 2.7, 4.1, 1. , 1. ],
[6.2, 2.2, 4.5, 1.5, 1. ],
[5.6, 2.5, 3.9, 1.1, 1. ],
[5.9, 3.2, 4.8, 1.8, 1. ],
[6.1, 2.8, 4. , 1.3, 1. ],
[6.3, 2.5, 4.9, 1.5, 1. ],
[6.1, 2.8, 4.7, 1.2, 1. ],
[6.4, 2.9, 4.3, 1.3, 1. ],
[6.6, 3. , 4.4, 1.4, 1. ],
[6.8, 2.8, 4.8, 1.4, 1. ],
[6.7, 3. , 5. , 1.7, 1. ],
[6. , 2.9, 4.5, 1.5, 1. ],
[5.7, 2.6, 3.5, 1. , 1. ],
[5.5, 2.4, 3.8, 1.1, 1. ],
[5.5, 2.4, 3.7, 1. , 1. ],
[5.8, 2.7, 3.9, 1.2, 1. ],
[6. , 2.7, 5.1, 1.6, 1. ],
[5.4, 3. , 4.5, 1.5, 1. ],
[6. , 3.4, 4.5, 1.6, 1. ],
[6.7, 3.1, 4.7, 1.5, 1. ],
[6.3, 2.3, 4.4, 1.3, 1. ],
[5.6, 3. , 4.1, 1.3, 1. ],
[5.5, 2.5, 4. , 1.3, 1. ],
[5.5, 2.6, 4.4, 1.2, 1. ],
[6.1, 3. , 4.6, 1.4, 1. ],
[5.8, 2.6, 4. , 1.2, 1. ],
[5. , 2.3, 3.3, 1. , 1. ],
[5.6, 2.7, 4.2, 1.3, 1. ],
[5.7, 3. , 4.2, 1.2, 1. ],
[5.7, 2.9, 4.2, 1.3, 1. ],
[6.2, 2.9, 4.3, 1.3, 1. ],
[5.1, 2.5, 3. , 1.1, 1. ],
[5.7, 2.8, 4.1, 1.3, 1. ],
[6.3, 3.3, 6. , 2.5, 2. ],
[5.8, 2.7, 5.1, 1.9, 2. ],
[7.1, 3. , 5.9, 2.1, 2. ],
[6.3, 2.9, 5.6, 1.8, 2. ],
[6.5, 3. , 5.8, 2.2, 2. ],
[7.6, 3. , 6.6, 2.1, 2. ],
[4.9, 2.5, 4.5, 1.7, 2. ],
[7.3, 2.9, 6.3, 1.8, 2. ],
[6.7, 2.5, 5.8, 1.8, 2. ],
[7.2, 3.6, 6.1, 2.5, 2. ],
[6.5, 3.2, 5.1, 2. , 2. ],
[6.4, 2.7, 5.3, 1.9, 2. ],
[6.8, 3. , 5.5, 2.1, 2. ],
[5.7, 2.5, 5. , 2. , 2. ],
[5.8, 2.8, 5.1, 2.4, 2. ],
[6.4, 3.2, 5.3, 2.3, 2. ],
[6.5, 3. , 5.5, 1.8, 2. ],
[7.7, 3.8, 6.7, 2.2, 2. ],
[7.7, 2.6, 6.9, 2.3, 2. ],
[6. , 2.2, 5. , 1.5, 2. ],
[6.9, 3.2, 5.7, 2.3, 2. ],
[5.6, 2.8, 4.9, 2. , 2. ],
[7.7, 2.8, 6.7, 2. , 2. ],
[6.3, 2.7, 4.9, 1.8, 2. ],
[6.7, 3.3, 5.7, 2.1, 2. ],
[7.2, 3.2, 6. , 1.8, 2. ],
[6.2, 2.8, 4.8, 1.8, 2. ],
[6.1, 3. , 4.9, 1.8, 2. ],
[6.4, 2.8, 5.6, 2.1, 2. ],
[7.2, 3. , 5.8, 1.6, 2. ],
[7.4, 2.8, 6.1, 1.9, 2. ],
[7.9, 3.8, 6.4, 2. , 2. ],
[6.4, 2.8, 5.6, 2.2, 2. ],
[6.3, 2.8, 5.1, 1.5, 2. ],
[6.1, 2.6, 5.6, 1.4, 2. ],
[7.7, 3. , 6.1, 2.3, 2. ],
[6.3, 3.4, 5.6, 2.4, 2. ],
[6.4, 3.1, 5.5, 1.8, 2. ],
[6. , 3. , 4.8, 1.8, 2. ],
[6.9, 3.1, 5.4, 2.1, 2. ],
[6.7, 3.1, 5.6, 2.4, 2. ],
[6.9, 3.1, 5.1, 2.3, 2. ],
[5.8, 2.7, 5.1, 1.9, 2. ],
[6.8, 3.2, 5.9, 2.3, 2. ],
[6.7, 3.3, 5.7, 2.5, 2. ],
[6.7, 3. , 5.2, 2.3, 2. ],
[6.3, 2.5, 5. , 1.9, 2. ],
[6.5, 3. , 5.2, 2. , 2. ],
[6.2, 3.4, 5.4, 2.3, 2. ],
[5.9, 3. , 5.1, 1.8, 2. ]])
In [16]:
x, y = np.split(data, (4,), axis=1)  # the first four columns are the features x, the last column is the label y
x
Out[16]:
array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.1, 1.5, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]])
In [17]:
y
Out[17]:
array([[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.]])
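As a side note, np.split with a tuple of indices cuts the array at those column positions. A minimal sketch on a made-up array (not part of the notebook) shows the same call:
demo = np.arange(10).reshape(2, 5)          # 2 rows, 5 columns
left, right = np.split(demo, (4,), axis=1)  # cut before column index 4
# left.shape == (2, 4), right.shape == (2, 1) -- the same split applied to the iris data above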
In [18]:
x = x[:, :2]  # keep only the first two columns (sepal length and sepal width) as x
x
Out[18]:
array([[5.1, 3.5],
[4.9, 3. ],
[4.7, 3.2],
[4.6, 3.1],
[5. , 3.6],
[5.4, 3.9],
[4.6, 3.4],
[5. , 3.4],
[4.4, 2.9],
[4.9, 3.1],
[5.4, 3.7],
[4.8, 3.4],
[4.8, 3. ],
[4.3, 3. ],
[5.8, 4. ],
[5.7, 4.4],
[5.4, 3.9],
[5.1, 3.5],
[5.7, 3.8],
[5.1, 3.8],
[5.4, 3.4],
[5.1, 3.7],
[4.6, 3.6],
[5.1, 3.3],
[4.8, 3.4],
[5. , 3. ],
[5. , 3.4],
[5.2, 3.5],
[5.2, 3.4],
[4.7, 3.2],
[4.8, 3.1],
[5.4, 3.4],
[5.2, 4.1],
[5.5, 4.2],
[4.9, 3.1],
[5. , 3.2],
[5.5, 3.5],
[4.9, 3.1],
[4.4, 3. ],
[5.1, 3.4],
[5. , 3.5],
[4.5, 2.3],
[4.4, 3.2],
[5. , 3.5],
[5.1, 3.8],
[4.8, 3. ],
[5.1, 3.8],
[4.6, 3.2],
[5.3, 3.7],
[5. , 3.3],
[7. , 3.2],
[6.4, 3.2],
[6.9, 3.1],
[5.5, 2.3],
[6.5, 2.8],
[5.7, 2.8],
[6.3, 3.3],
[4.9, 2.4],
[6.6, 2.9],
[5.2, 2.7],
[5. , 2. ],
[5.9, 3. ],
[6. , 2.2],
[6.1, 2.9],
[5.6, 2.9],
[6.7, 3.1],
[5.6, 3. ],
[5.8, 2.7],
[6.2, 2.2],
[5.6, 2.5],
[5.9, 3.2],
[6.1, 2.8],
[6.3, 2.5],
[6.1, 2.8],
[6.4, 2.9],
[6.6, 3. ],
[6.8, 2.8],
[6.7, 3. ],
[6. , 2.9],
[5.7, 2.6],
[5.5, 2.4],
[5.5, 2.4],
[5.8, 2.7],
[6. , 2.7],
[5.4, 3. ],
[6. , 3.4],
[6.7, 3.1],
[6.3, 2.3],
[5.6, 3. ],
[5.5, 2.5],
[5.5, 2.6],
[6.1, 3. ],
[5.8, 2.6],
[5. , 2.3],
[5.6, 2.7],
[5.7, 3. ],
[5.7, 2.9],
[6.2, 2.9],
[5.1, 2.5],
[5.7, 2.8],
[6.3, 3.3],
[5.8, 2.7],
[7.1, 3. ],
[6.3, 2.9],
[6.5, 3. ],
[7.6, 3. ],
[4.9, 2.5],
[7.3, 2.9],
[6.7, 2.5],
[7.2, 3.6],
[6.5, 3.2],
[6.4, 2.7],
[6.8, 3. ],
[5.7, 2.5],
[5.8, 2.8],
[6.4, 3.2],
[6.5, 3. ],
[7.7, 3.8],
[7.7, 2.6],
[6. , 2.2],
[6.9, 3.2],
[5.6, 2.8],
[7.7, 2.8],
[6.3, 2.7],
[6.7, 3.3],
[7.2, 3.2],
[6.2, 2.8],
[6.1, 3. ],
[6.4, 2.8],
[7.2, 3. ],
[7.4, 2.8],
[7.9, 3.8],
[6.4, 2.8],
[6.3, 2.8],
[6.1, 2.6],
[7.7, 3. ],
[6.3, 3.4],
[6.4, 3.1],
[6. , 3. ],
[6.9, 3.1],
[6.7, 3.1],
[6.9, 3.1],
[5.8, 2.7],
[6.8, 3.2],
[6.7, 3.3],
[6.7, 3. ],
[6.3, 2.5],
[6.5, 3. ],
[6.2, 3.4],
[5.9, 3. ]])
In [19]:
gnb = Pipeline([
    ('sc', StandardScaler()),   # standardize the data to zero mean and unit variance
    ('clf', GaussianNB())])     # assume each feature follows a Gaussian distribution
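GaussianNB fits a per-feature mean and variance for each class and scores a sample with the product of the resulting univariate normal densities times the class prior. A minimal sketch of that scoring step (illustrative only, not taken from the notebook):
def gaussian_log_likelihood(x_row, mu, var):
    # sum over features of log N(x_j; mu_j, var_j); adding log P(class) gives the class score
    return np.sum(-0.5 * np.log(2 * np.pi * var) - (x_row - mu) ** 2 / (2 * var))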
In [20]:
y.ravel()  # flatten y into a 1-D array
Out[20]:
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])
In [21]:
gnb.fit(x, y.ravel())
Out[21]:
Pipeline(memory=None,
steps=[('sc', StandardScaler(copy=True, with_mean=True, with_std=True)), ('clf', GaussianNB(priors=None, var_smoothing=1e-09))])
In [23]:
y_hat = gnb.predict(x)
y_hat
Out[23]:
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 2.,
2., 2., 1., 2., 1., 2., 1., 2., 1., 1., 1., 1., 1., 1., 2., 1., 1.,
1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 1.,
2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1.,
2., 1., 2., 2., 1., 2., 2., 2., 2., 1., 2., 1., 1., 2., 2., 2., 2.,
1., 2., 1., 2., 1., 2., 2., 1., 1., 1., 2., 2., 2., 1., 1., 1., 2.,
2., 2., 1., 2., 2., 2., 1., 2., 2., 2., 1., 2., 2., 1.])
In [24]:
y = y.reshape(-1)  # equivalent to y.ravel()
y
Out[24]:
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])
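Both calls return the same 1-D array, which is what scikit-learn expects as a target; ravel() and reshape(-1) avoid copying when they can. A quick check (sketch):
a = np.zeros((150, 1))
assert a.ravel().shape == (150,)
assert a.reshape(-1).shape == (150,)  # identical result to ravel() here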
In [25]:
result = y_hat == y
result
Out[25]:
array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, False, True, True, True,
True, True, True, True, True, False, False, False, True,
False, True, False, True, False, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True, True, False, False, False, False, True, True, True,
True, True, True, True, False, False, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, False, True, False, True, True, False, True,
True, True, True, False, True, False, False, True, True,
True, True, False, True, False, True, False, True, True,
False, False, False, True, True, True, False, False, False,
True, True, True, False, True, True, True, False, True,
True, True, False, True, True, False])
In [27]:
acc = np.mean(result)  # True counts as 1 and False as 0, so the mean of the boolean array is the accuracy
acc
Out[27]:
0.78
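The mean of the boolean comparison is the same number scikit-learn's accuracy_score reports; a quick equivalence check (sketch):
from sklearn.metrics import accuracy_score
print(accuracy_score(y, y_hat))  # should also print 0.78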
Version 2 follows.
In [29]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.preprocessing import StandardScaler, MinMaxScaler, PolynomialFeatures
from sklearn.naive_bayes import GaussianNB, MultinomialNB  # Gaussian and multinomial naive Bayes
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# set matplotlib font properties so Chinese labels are not garbled
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False
# sepal length, sepal width, petal length, petal width
iris_feature_E = 'sepal length', 'sepal width', 'petal length', 'petal width'
iris_feature_C = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度'
iris_class = 'Iris-setosa', 'Iris-versicolor', 'Iris-virginica'
features = [2, 3]
# read the data
path = 'D:\\mlInAction\\8.iris.data'  # path to the data file
data = pd.read_csv(path, header=None)
data
Out[29]:
| | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | Iris-setosa |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | Iris-setosa |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | Iris-setosa |
| 5 | 5.4 | 3.9 | 1.7 | 0.4 | Iris-setosa |
| 6 | 4.6 | 3.4 | 1.4 | 0.3 | Iris-setosa |
| 7 | 5.0 | 3.4 | 1.5 | 0.2 | Iris-setosa |
| 8 | 4.4 | 2.9 | 1.4 | 0.2 | Iris-setosa |
| 9 | 4.9 | 3.1 | 1.5 | 0.1 | Iris-setosa |
| 10 | 5.4 | 3.7 | 1.5 | 0.2 | Iris-setosa |
| 11 | 4.8 | 3.4 | 1.6 | 0.2 | Iris-setosa |
| 12 | 4.8 | 3.0 | 1.4 | 0.1 | Iris-setosa |
| 13 | 4.3 | 3.0 | 1.1 | 0.1 | Iris-setosa |
| 14 | 5.8 | 4.0 | 1.2 | 0.2 | Iris-setosa |
| 15 | 5.7 | 4.4 | 1.5 | 0.4 | Iris-setosa |
| 16 | 5.4 | 3.9 | 1.3 | 0.4 | Iris-setosa |
| 17 | 5.1 | 3.5 | 1.4 | 0.3 | Iris-setosa |
| 18 | 5.7 | 3.8 | 1.7 | 0.3 | Iris-setosa |
| 19 | 5.1 | 3.8 | 1.5 | 0.3 | Iris-setosa |
| 20 | 5.4 | 3.4 | 1.7 | 0.2 | Iris-setosa |
| 21 | 5.1 | 3.7 | 1.5 | 0.4 | Iris-setosa |
| 22 | 4.6 | 3.6 | 1.0 | 0.2 | Iris-setosa |
| 23 | 5.1 | 3.3 | 1.7 | 0.5 | Iris-setosa |
| 24 | 4.8 | 3.4 | 1.9 | 0.2 | Iris-setosa |
| 25 | 5.0 | 3.0 | 1.6 | 0.2 | Iris-setosa |
| 26 | 5.0 | 3.4 | 1.6 | 0.4 | Iris-setosa |
| 27 | 5.2 | 3.5 | 1.5 | 0.2 | Iris-setosa |
| 28 | 5.2 | 3.4 | 1.4 | 0.2 | Iris-setosa |
| 29 | 4.7 | 3.2 | 1.6 | 0.2 | Iris-setosa |
| ... | ... | ... | ... | ... | ... |
| 120 | 6.9 | 3.2 | 5.7 | 2.3 | Iris-virginica |
| 121 | 5.6 | 2.8 | 4.9 | 2.0 | Iris-virginica |
| 122 | 7.7 | 2.8 | 6.7 | 2.0 | Iris-virginica |
| 123 | 6.3 | 2.7 | 4.9 | 1.8 | Iris-virginica |
| 124 | 6.7 | 3.3 | 5.7 | 2.1 | Iris-virginica |
| 125 | 7.2 | 3.2 | 6.0 | 1.8 | Iris-virginica |
| 126 | 6.2 | 2.8 | 4.8 | 1.8 | Iris-virginica |
| 127 | 6.1 | 3.0 | 4.9 | 1.8 | Iris-virginica |
| 128 | 6.4 | 2.8 | 5.6 | 2.1 | Iris-virginica |
| 129 | 7.2 | 3.0 | 5.8 | 1.6 | Iris-virginica |
| 130 | 7.4 | 2.8 | 6.1 | 1.9 | Iris-virginica |
| 131 | 7.9 | 3.8 | 6.4 | 2.0 | Iris-virginica |
| 132 | 6.4 | 2.8 | 5.6 | 2.2 | Iris-virginica |
| 133 | 6.3 | 2.8 | 5.1 | 1.5 | Iris-virginica |
| 134 | 6.1 | 2.6 | 5.6 | 1.4 | Iris-virginica |
| 135 | 7.7 | 3.0 | 6.1 | 2.3 | Iris-virginica |
| 136 | 6.3 | 3.4 | 5.6 | 2.4 | Iris-virginica |
| 137 | 6.4 | 3.1 | 5.5 | 1.8 | Iris-virginica |
| 138 | 6.0 | 3.0 | 4.8 | 1.8 | Iris-virginica |
| 139 | 6.9 | 3.1 | 5.4 | 2.1 | Iris-virginica |
| 140 | 6.7 | 3.1 | 5.6 | 2.4 | Iris-virginica |
| 141 | 6.9 | 3.1 | 5.1 | 2.3 | Iris-virginica |
| 142 | 5.8 | 2.7 | 5.1 | 1.9 | Iris-virginica |
| 143 | 6.8 | 3.2 | 5.9 | 2.3 | Iris-virginica |
| 144 | 6.7 | 3.3 | 5.7 | 2.5 | Iris-virginica |
| 145 | 6.7 | 3.0 | 5.2 | 2.3 | Iris-virginica |
| 146 | 6.3 | 2.5 | 5.0 | 1.9 | Iris-virginica |
| 147 | 6.5 | 3.0 | 5.2 | 2.0 | Iris-virginica |
| 148 | 6.2 | 3.4 | 5.4 | 2.3 | Iris-virginica |
| 149 | 5.9 | 3.0 | 5.1 | 1.8 | Iris-virginica |
150 rows × 5 columns
In [35]:
x = data[list(range(4))]  # x is a pandas DataFrame, so select the first four columns by label rather than numpy-style slicing
x
Out[35]:
| | 0 | 1 | 2 | 3 |
|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 |
| 5 | 5.4 | 3.9 | 1.7 | 0.4 |
| 6 | 4.6 | 3.4 | 1.4 | 0.3 |
| 7 | 5.0 | 3.4 | 1.5 | 0.2 |
| 8 | 4.4 | 2.9 | 1.4 | 0.2 |
| 9 | 4.9 | 3.1 | 1.5 | 0.1 |
| 10 | 5.4 | 3.7 | 1.5 | 0.2 |
| 11 | 4.8 | 3.4 | 1.6 | 0.2 |
| 12 | 4.8 | 3.0 | 1.4 | 0.1 |
| 13 | 4.3 | 3.0 | 1.1 | 0.1 |
| 14 | 5.8 | 4.0 | 1.2 | 0.2 |
| 15 | 5.7 | 4.4 | 1.5 | 0.4 |
| 16 | 5.4 | 3.9 | 1.3 | 0.4 |
| 17 | 5.1 | 3.5 | 1.4 | 0.3 |
| 18 | 5.7 | 3.8 | 1.7 | 0.3 |
| 19 | 5.1 | 3.8 | 1.5 | 0.3 |
| 20 | 5.4 | 3.4 | 1.7 | 0.2 |
| 21 | 5.1 | 3.7 | 1.5 | 0.4 |
| 22 | 4.6 | 3.6 | 1.0 | 0.2 |
| 23 | 5.1 | 3.3 | 1.7 | 0.5 |
| 24 | 4.8 | 3.4 | 1.9 | 0.2 |
| 25 | 5.0 | 3.0 | 1.6 | 0.2 |
| 26 | 5.0 | 3.4 | 1.6 | 0.4 |
| 27 | 5.2 | 3.5 | 1.5 | 0.2 |
| 28 | 5.2 | 3.4 | 1.4 | 0.2 |
| 29 | 4.7 | 3.2 | 1.6 | 0.2 |
| ... | ... | ... | ... | ... |
| 120 | 6.9 | 3.2 | 5.7 | 2.3 |
| 121 | 5.6 | 2.8 | 4.9 | 2.0 |
| 122 | 7.7 | 2.8 | 6.7 | 2.0 |
| 123 | 6.3 | 2.7 | 4.9 | 1.8 |
| 124 | 6.7 | 3.3 | 5.7 | 2.1 |
| 125 | 7.2 | 3.2 | 6.0 | 1.8 |
| 126 | 6.2 | 2.8 | 4.8 | 1.8 |
| 127 | 6.1 | 3.0 | 4.9 | 1.8 |
| 128 | 6.4 | 2.8 | 5.6 | 2.1 |
| 129 | 7.2 | 3.0 | 5.8 | 1.6 |
| 130 | 7.4 | 2.8 | 6.1 | 1.9 |
| 131 | 7.9 | 3.8 | 6.4 | 2.0 |
| 132 | 6.4 | 2.8 | 5.6 | 2.2 |
| 133 | 6.3 | 2.8 | 5.1 | 1.5 |
| 134 | 6.1 | 2.6 | 5.6 | 1.4 |
| 135 | 7.7 | 3.0 | 6.1 | 2.3 |
| 136 | 6.3 | 3.4 | 5.6 | 2.4 |
| 137 | 6.4 | 3.1 | 5.5 | 1.8 |
| 138 | 6.0 | 3.0 | 4.8 | 1.8 |
| 139 | 6.9 | 3.1 | 5.4 | 2.1 |
| 140 | 6.7 | 3.1 | 5.6 | 2.4 |
| 141 | 6.9 | 3.1 | 5.1 | 2.3 |
| 142 | 5.8 | 2.7 | 5.1 | 1.9 |
| 143 | 6.8 | 3.2 | 5.9 | 2.3 |
| 144 | 6.7 | 3.3 | 5.7 | 2.5 |
| 145 | 6.7 | 3.0 | 5.2 | 2.3 |
| 146 | 6.3 | 2.5 | 5.0 | 1.9 |
| 147 | 6.5 | 3.0 | 5.2 | 2.0 |
| 148 | 6.2 | 3.4 | 5.4 | 2.3 |
| 149 | 5.9 | 3.0 | 5.1 | 1.8 |
150 rows × 4 columns
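Because x is a DataFrame, the selection above goes by column label; positional slicing would use .iloc instead. Both forms below are equivalent here (sketch), since the column labels happen to be the integers 0..3:
x_by_label = data[[0, 1, 2, 3]]      # select columns by label
x_by_position = data.iloc[:, :4]     # select columns by position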
In [36]:
x = x[features]
x
Out[36]:
| | 2 | 3 |
|---|---|---|
| 0 | 1.4 | 0.2 |
| 1 | 1.4 | 0.2 |
| 2 | 1.3 | 0.2 |
| 3 | 1.5 | 0.2 |
| 4 | 1.4 | 0.2 |
| 5 | 1.7 | 0.4 |
| 6 | 1.4 | 0.3 |
| 7 | 1.5 | 0.2 |
| 8 | 1.4 | 0.2 |
| 9 | 1.5 | 0.1 |
| 10 | 1.5 | 0.2 |
| 11 | 1.6 | 0.2 |
| 12 | 1.4 | 0.1 |
| 13 | 1.1 | 0.1 |
| 14 | 1.2 | 0.2 |
| 15 | 1.5 | 0.4 |
| 16 | 1.3 | 0.4 |
| 17 | 1.4 | 0.3 |
| 18 | 1.7 | 0.3 |
| 19 | 1.5 | 0.3 |
| 20 | 1.7 | 0.2 |
| 21 | 1.5 | 0.4 |
| 22 | 1.0 | 0.2 |
| 23 | 1.7 | 0.5 |
| 24 | 1.9 | 0.2 |
| 25 | 1.6 | 0.2 |
| 26 | 1.6 | 0.4 |
| 27 | 1.5 | 0.2 |
| 28 | 1.4 | 0.2 |
| 29 | 1.6 | 0.2 |
| ... | ... | ... |
| 120 | 5.7 | 2.3 |
| 121 | 4.9 | 2.0 |
| 122 | 6.7 | 2.0 |
| 123 | 4.9 | 1.8 |
| 124 | 5.7 | 2.1 |
| 125 | 6.0 | 1.8 |
| 126 | 4.8 | 1.8 |
| 127 | 4.9 | 1.8 |
| 128 | 5.6 | 2.1 |
| 129 | 5.8 | 1.6 |
| 130 | 6.1 | 1.9 |
| 131 | 6.4 | 2.0 |
| 132 | 5.6 | 2.2 |
| 133 | 5.1 | 1.5 |
| 134 | 5.6 | 1.4 |
| 135 | 6.1 | 2.3 |
| 136 | 5.6 | 2.4 |
| 137 | 5.5 | 1.8 |
| 138 | 4.8 | 1.8 |
| 139 | 5.4 | 2.1 |
| 140 | 5.6 | 2.4 |
| 141 | 5.1 | 2.3 |
| 142 | 5.1 | 1.9 |
| 143 | 5.9 | 2.3 |
| 144 | 5.7 | 2.5 |
| 145 | 5.2 | 2.3 |
| 146 | 5.0 | 1.9 |
| 147 | 5.2 | 2.0 |
| 148 | 5.4 | 2.3 |
| 149 | 5.1 | 1.8 |
150 rows × 2 columns
In [37]:
y = pd.Categorical(data[4]).codes  # encode the class labels directly as integers 0, 1, 2
y
Out[37]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int8)
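pd.Categorical assigns integer codes following the sorted order of the labels, so the mapping here is Iris-setosa -> 0, Iris-versicolor -> 1, Iris-virginica -> 2. A small sketch that prints the mapping:
cats = pd.Categorical(data[4])
print(dict(zip(cats.categories, range(len(cats.categories)))))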
In [38]:
print("总样本数目:%d;特征属性数目:%d" % x.shape)
总样本数目:150;特征属性数目:2
In [40]:
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=14)
print("训练数据集样本数目:%d, 测试数据集样本数目:%d" % (x_train.shape[0], x_test.shape[0]))
训练数据集样本数目:120, 测试数据集样本数目:30
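train_test_split draws the split at random; if keeping the 50/50/50 class proportions in both subsets matters, the stratify argument can be passed as well (a variant, not what this notebook used):
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=14, stratify=y)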
In [41]:
clf = Pipeline([
    ('sc', StandardScaler()),                  # standardize to zero mean and unit variance
    ('poly', PolynomialFeatures(degree=1)),
    ('clf', GaussianNB())])                    # note: MultinomialNB requires non-negative feature values
# train the model
clf.fit(x_train, y_train)
Out[41]:
Pipeline(memory=None,
steps=[('sc', StandardScaler(copy=True, with_mean=True, with_std=True)), ('poly', PolynomialFeatures(degree=1, include_bias=True, interaction_only=False)), ('clf', GaussianNB(priors=None, var_smoothing=1e-09))])
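The comment above notes that MultinomialNB requires non-negative feature values, which StandardScaler would violate. One way to try it would be to scale into [0, 1] with the already-imported MinMaxScaler (a sketch, not run in this notebook):
mnb = Pipeline([
    ('mms', MinMaxScaler()),    # maps each feature into [0, 1], so all values are non-negative
    ('clf', MultinomialNB())])
mnb.fit(x_train, y_train)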
In [42]:
y_train_hat = clf.predict(x_train)
print('Training accuracy: %.2f%%' % (100 * accuracy_score(y_train, y_train_hat)))
y_test_hat = clf.predict(x_test)
print('Test accuracy: %.2f%%' % (100 * accuracy_score(y_test, y_test_hat)))
Training accuracy: 95.83%
Test accuracy: 96.67%
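KNeighborsClassifier is imported at the top but never used; for comparison, a minimal KNN pipeline on the same split could look like this (a sketch with an assumed n_neighbors=5, results not verified here):
knn = Pipeline([
    ('sc', StandardScaler()),
    ('clf', KNeighborsClassifier(n_neighbors=5))])
knn.fit(x_train, y_train)
print('KNN test accuracy: %.2f%%' % (100 * accuracy_score(y_test, knn.predict(x_test))))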