鸢尾花数据集----决策树&&神经网络
为方便理解这两种不同的分类预测算法,下面的示例均使用 sklearn 中 datasets 提供的鸢尾花数据集。
决策树1(复杂):
"""Iris decision tree on two features: decision regions and depth-vs-error curve."""
import numpy as np

# Graphviz export targets, kept exactly as in the original listing
# (Windows-style paths; change them to suit your machine).
# Render an exported file with e.g.:  dot -Tpng -o 1.png 1.dot
DOT_PATH = '.\\iris_tree.dot'
DOT_PATH_ABSOLUTE = 'D:\\py_project\\iris_tree.dot'


def sample_grid(points, n_cols=100, n_rows=100):
    """Build a regular grid spanning the bounding box of 2-D *points*.

    Args:
        points: array whose first two columns are the plotted features.
        n_cols: number of samples along the first feature.
        n_rows: number of samples along the second feature.

    Returns:
        ``(x1, x2, grid)`` where ``x1``/``x2`` are the ``np.meshgrid`` outputs
        of shape ``(n_rows, n_cols)`` and ``grid`` is the flattened
        ``(n_rows * n_cols, 2)`` array of sample points.
    """
    t1 = np.linspace(points[:, 0].min(), points[:, 0].max(), n_cols)
    t2 = np.linspace(points[:, 1].min(), points[:, 1].max(), n_rows)
    x1, x2 = np.meshgrid(t1, t2)
    grid = np.stack((x1.ravel(), x2.ravel()), axis=1)
    return x1, x2, grid


def accuracy(y_true, y_pred):
    """Return the fraction of entries of *y_pred* equal to *y_true*.

    *y_true* is flattened first, so a column vector of labels is accepted.
    """
    y_true = np.asarray(y_true).reshape(-1)
    return float(np.mean(np.asarray(y_pred) == y_true))


def export_tree(model, path):
    """Write the pipeline's fitted 'DTC' step to *path* in Graphviz .dot format."""
    from sklearn import tree

    # `with` closes (and therefore flushes) the file. The step is looked up
    # through named_steps instead of get_params(), whose only argument is
    # the boolean `deep`.
    with open(path, 'w') as dot_file:
        tree.export_graphviz(model.named_steps['DTC'], out_file=dot_file)


def main():
    """Train, export, plot and evaluate the two-feature iris decision tree."""
    # scikit-learn and matplotlib are imported here rather than at module
    # level so the NumPy-only helpers above stay importable without them.
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.tree import DecisionTreeClassifier

    # Keep the Chinese labels and the minus sign from rendering as boxes.
    mpl.rcParams['font.sans-serif'] = [u'SimHei']
    mpl.rcParams['axes.unicode_minus'] = False

    # Data preparation: load_iris() returns a dict-like object in which the
    # features and the integer labels (0-2) are already separated.
    dataset = datasets.load_iris()
    x = np.array(dataset['data'])
    y = np.array(dataset['target'])
    x = x[:, :2]  # 150 rows x 4 columns; keep the first two features for plotting
    # 70 / 30 train / test split.
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=1)

    model = Pipeline([
        ('ss', StandardScaler()),
        ('DTC', DecisionTreeClassifier(criterion='entropy', max_depth=3))])
    model = model.fit(x_train, y_train)
    y_test_hat = model.predict(x_test)  # predictions for the test samples
    # print(y_test)      true labels of the 45 test samples
    # [0 1 1 0 2 1 2 0 0 2 1 0 2 1 1 0 1 1 0 0 1 1 1 0 2 1 0 0 1 2 1 2 1 2 2 0 1 0 1 2 2 0 2 2 1]
    # print(y_test_hat)  predicted labels of the 45 test samples
    # [0 1 2 0 2 2 2 0 0 2 1 0 2 2 1 0 1 1 0 0 1 0 2 0 2 1 0 0 1 2 1 2 1 2 1 0 1 0 2 2 2 0 1 2 2]

    # Save the fitted tree.
    export_tree(model, DOT_PATH)

    # Plot the decision regions on an N x M grid of sample points.
    N, M = 100, 100
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()  # range of column 0
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()  # range of column 1
    x1, x2, x_show = sample_grid(x, N, M)

    # Optional four-feature variant: pad the two unused dimensions with their
    # means. Remove the `x = x[:, :2]` slice above before enabling it.
    # x3 = np.ones(x1.size) * np.average(x[:, 2])
    # x4 = np.ones(x1.size) * np.average(x[:, 3])
    # x_test = np.stack((x1.flat, x2.flat, x3, x4), axis=1)

    cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
    cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
    # Predicted label for every grid point, reshaped back to the grid shape.
    y_show_hat = model.predict(x_show).reshape(x1.shape)
    plt.figure(facecolor='w')
    plt.pcolormesh(x1, x2, y_show_hat, cmap=cm_light)  # predicted regions
    plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test.ravel(), edgecolors='k', s=100,
                cmap=cm_dark, marker='o')  # test samples
    plt.scatter(x[:, 0], x[:, 1], c=y.ravel(), edgecolors='k', s=40, cmap=cm_dark)  # all samples
    plt.xlabel("花萼长度", fontsize=15)  # sepal length
    plt.ylabel("花萼宽度", fontsize=15)  # sepal width
    plt.xlim(x1_min, x1_max)
    plt.ylim(x2_min, x2_max)
    plt.grid(True)
    plt.title(u'鸢尾花数据的决策树分类', fontsize=17)
    plt.show()

    # Accuracy on the held-out test set.
    print('准确度: %.2f%%' % (100 * accuracy(y_test, y_test_hat)))

    # Over-fitting check: test error rate as a function of tree depth.
    depth = np.arange(1, 45)
    err_list = []
    for d in depth:
        clf = DecisionTreeClassifier(criterion='entropy', max_depth=d)
        clf.fit(x_train, y_train)
        err = 1 - accuracy(y_test, clf.predict(x_test))
        err_list.append(err)
        # The value printed is the error rate, so it is labelled as such.
        print(d, ' 错误率: %.2f%%' % (100 * err))
    plt.figure(facecolor='w')
    plt.plot(depth, err_list, 'ro-', lw=2)
    plt.xlabel(u'决策树深度', fontsize=15)
    plt.ylabel(u'错误率', fontsize=15)
    plt.title(u'决策树深度与过拟合', fontsize=17)
    plt.grid(True)

    plt.show()

    # Second export of the same tree, to the absolute path used in the
    # original listing.
    export_tree(model, DOT_PATH_ABSOLUTE)


if __name__ == "__main__":
    main()


决策树2:
数据集从本地文件导入,内容与 from sklearn import datasets 提供的鸢尾花数据集基本相同(sklearn 版本修正了其中两条记录)。
"""Iris decision tree trained from a local comma-separated text file."""
import numpy as np

# Location of the data file, kept exactly as in the original listing.
DATA_PATH = r'D:\py_project\8.iris.txt'


def parse_iris_lines(lines):
    """Split raw ``f1,f2,f3,f4,species`` lines into features and species names.

    Blank lines are skipped and whitespace around each field is ignored.

    Args:
        lines: iterable of text lines.

    Returns:
        ``(X, names)``: ``X`` is a float array of shape ``(n, 4)`` and
        ``names`` the list of the ``n`` species strings.

    Raises:
        ValueError: if a non-blank line has fewer than five fields or a
            measurement is not numeric.
    """
    rows = []
    names = []
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue
        fields = [field.strip() for field in line.split(',')]
        if len(fields) < 5:
            raise ValueError(
                'line {}: expected 5 comma-separated fields, got {}'.format(
                    line_no, len(fields)))
        try:
            rows.append([float(value) for value in fields[:4]])
        except ValueError as err:
            raise ValueError(
                'line {}: non-numeric measurement in {!r}'.format(line_no, line)
            ) from err
        names.append(fields[4])
    # reshape keeps the (n, 4) shape even when there are no rows.
    return np.array(rows, dtype=float).reshape(-1, 4), names


def iris_type(lable):
    """Map species names to integer class codes.

    Args:
        lable: iterable of species names ('Iris-setosa', 'Iris-versicolor'
            or 'Iris-virginica').

    Returns:
        List of codes (0, 1, 2) in the same order.

    Raises:
        KeyError: if a name is not one of the three known species.
    """
    it = {'Iris-setosa': 0,
          'Iris-versicolor': 1,
          'Iris-virginica': 2}
    return [it[name] for name in lable]


def main():
    """Load the local file, fit a decision tree and print the test accuracy."""
    # scikit-learn is imported here rather than at module level so the
    # parsing helpers above stay importable without it.
    from sklearn import tree
    from sklearn.model_selection import train_test_split

    with open(DATA_PATH, "r", encoding='UTF-8') as fp:
        X, names = parse_iris_lines(fp.read().splitlines())
    print(X)

    Y = np.array(iris_type(names))

    # 70 / 30 split; no random_state is set, so the split (and therefore the
    # accuracy) differs from run to run.
    x_train, x_test, y_train, y_test = train_test_split(X, Y, train_size=0.7)
    clf = tree.DecisionTreeClassifier().fit(x_train, y_train)
    y_test_hat = clf.predict(x_test)

    print("正确率ACC:", float(np.mean(y_test_hat == y_test)))


if __name__ == "__main__":
    main()
神经网络:
"""Iris classification with a one-hidden-layer feed-forward network (PyTorch)."""
from collections import Counter

import numpy as np
import torch
import torch.nn.functional as Fun
from torch.autograd import Variable


class Net(torch.nn.Module):
    """Feed-forward classifier: Linear -> ReLU -> Linear.

    Args:
        n_feature: number of input features.
        n_hidden: width of the hidden layer.
        n_output: number of classes (size of the returned score vector).
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # hidden layer
        self.out = torch.nn.Linear(n_hidden, n_output)  # output layer

    def forward(self, x):
        """Return raw class scores for *x*.

        The scores are not normalised; apply ``Fun.softmax(scores, dim=1)``
        to turn them into probabilities.
        """
        x = Fun.relu(self.hidden(x))  # ReLU activation on the hidden layer
        return self.out(x)


def train(net, features, labels, epochs=500, lr=0.02):
    """Fit *net* with SGD and cross-entropy loss.

    Every step uses the whole of *features* as a single batch.

    Args:
        net: module mapping *features* to per-class scores.
        features: float tensor of inputs.
        labels: long tensor of class indices, one per row of *features*.
        epochs: number of optimisation steps.
        lr: learning rate passed to ``torch.optim.SGD``.

    Returns:
        List with the loss value (float) of each step, in order.
    """
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss_func = torch.nn.CrossEntropyLoss()  # loss for classification
    losses = []
    for _ in range(epochs):
        out = net(features)
        loss = loss_func(out, labels)
        optimizer.zero_grad()  # clear gradients left over from the last step
        loss.backward()  # back-propagate to compute gradients
        optimizer.step()  # apply gradients
        losses.append(loss.item())
    return losses


def accuracy(net, features, labels):
    """Return the fraction of rows of *features* that *net* classifies as *labels*."""
    # Evaluation needs no gradients, so skip building the autograd graph.
    with torch.no_grad():
        predicted = net(features).argmax(dim=1)
    return (predicted == labels).sum().item() / labels.numel()


def main():
    """Train on the full iris data set and print the accuracy on that same data."""
    # scikit-learn and matplotlib are imported here rather than at module
    # level so Net, train and accuracy stay importable with only torch.
    import matplotlib.pyplot as plt  # not used below; retained from the original listing
    from sklearn import datasets

    dataset = datasets.load_iris()
    features = torch.as_tensor(dataset['data'], dtype=torch.float32)
    labels = torch.as_tensor(dataset['target'], dtype=torch.long)

    net = Net(n_feature=4, n_hidden=20, n_output=3)
    train(net, features, labels, epochs=500, lr=0.02)

    # NOTE: this is accuracy on the training data; there is no held-out split.
    print("莺尾花预测准确率", accuracy(net, features, labels))


if __name__ == "__main__":
    main()
鸢尾花数据集:
共 150 个样本,分为三种类别:setosa、versicolor、virginica
每行字段依次为:花萼长度、花萼宽度、花瓣长度、花瓣宽度、种类
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica

浙公网安备 33010602011771号