Decision Tree
# Building the decision tree
from math import log
import operator
import pickle
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
def calcShannonEnt(data_set):
    num_entries = len(data_set)  # number of rows
    label_counts = {}
    for feat_vec in data_set:  # iterate over each row
        current_label = feat_vec[-1]  # the last column is the class label
        if current_label not in label_counts:  # label not seen yet
            label_counts[current_label] = 0  # store it as key: 0
        label_counts[current_label] += 1  # increment the count for this label
    shannon_ent = 0.0
    for key in label_counts:
        prob = float(label_counts[key]) / num_entries  # probability of each label
        shannon_ent -= prob * log(prob, 2)  # entropy H = -sum(p(x) * log2(p(x)))
    return shannon_ent  # return the Shannon entropy
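A quick sanity check (the expected value assumes the sample data above, with its 2/5 'yes' vs. 3/5 'no' class split):

my_dat, labels = createDataSet()
print(calcShannonEnt(my_dat))  # -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.9710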
def splitDataSet(data_set, axis, value):
    '''
    Filter out the subset obtained by splitting on axis and value.
    :param data_set: the data set
    :param axis: the feature column to remove after splitting
    :param value: keep only rows whose value in that column equals this
    :return: the split data set
    '''
    ret_data_set = []  # build a new list
    for feat_vec in data_set:  # iterate over each row
        if feat_vec[axis] == value:
            reduced_feat_vec = feat_vec[:axis]
            reduced_feat_vec.extend(feat_vec[axis + 1:])  # the two slices together drop feat_vec[axis]; done in two steps so the original data is not modified
            ret_data_set.append(reduced_feat_vec)
    return ret_data_set  # return the split data set
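A small usage sketch (the printed subsets assume the sample data from createDataSet):

my_dat, labels = createDataSet()
print(splitDataSet(my_dat, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitDataSet(my_dat, 0, 0))  # [[1, 'no'], [1, 'no']]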
def chooseBestFeatureToSplit(data_set):  # find the feature index with the highest information gain
    num_features = len(data_set[0]) - 1  # the last column of data_set is the class label
    base_entropy = calcShannonEnt(data_set)  # entropy of the whole data set, the baseline disorder
    best_info_gain = 0.0
    best_feature = -1
    for i in range(num_features):
        feat_list = [example[i] for example in data_set]  # list comprehension collecting column i
        unique_vals = set(feat_list)  # deduplicate with a set
        new_entropy = 0.0
        for value in unique_vals:  # iterate over the unique feature values
            sub_data_set = splitDataSet(data_set, i, value)  # split the data set on this value
            prob = len(sub_data_set) / float(len(data_set))  # weight of this subset
            new_entropy += prob * calcShannonEnt(sub_data_set)  # weighted entropy after the split
        info_gain = base_entropy - new_entropy
        if info_gain > best_info_gain:  # is this gain better than the best so far?
            best_info_gain = info_gain
            best_feature = i
    return best_feature  # return the index of the feature with the highest information gain
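On the sample data, feature 0 ('no surfacing') gives the larger gain (≈0.42 vs. ≈0.17 for 'flippers'):

my_dat, labels = createDataSet()
print(chooseBestFeatureToSplit(my_dat))  # 0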
def majorityCnt(class_list):
    class_count = {}
    for vote in class_list:
        if vote not in class_count:
            class_count[vote] = 0
        class_count[vote] += 1
    sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]  # return the most frequent label
def createTree(data_set, labels):
    class_list = [example[-1] for example in data_set]  # list of class labels
    if class_list.count(class_list[0]) == len(class_list):  # all class labels are identical
        return class_list[0]  # return the label; return ends this recursion level
    if len(data_set[0]) == 1:  # all features used up but the set is still not a single class
        return majorityCnt(class_list)  # return the most frequent label
    best_feat = chooseBestFeatureToSplit(data_set)  # index of the feature with the highest information gain
    best_feat_label = labels[best_feat]  # the matching feature name
    my_tree = {best_feat_label: {}}  # build the tree as a nested dict
    del (labels[best_feat])
    feat_values = [example[best_feat] for example in data_set]  # all values of the best feature column
    unique_vals = set(feat_values)  # deduplicate the feature values
    for value in unique_vals:
        sub_labels = labels[:]  # copy labels so the recursive calls do not mutate the list shared across loop iterations
        my_tree[best_feat_label][value] = createTree(splitDataSet(data_set, best_feat, value),
                                                     sub_labels)  # add key=value with the subtree (or leaf label) as its value
    return my_tree
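Putting it together on the toy data (createTree deletes entries from labels, so pass a copy if you still need the original list):

my_dat, labels = createDataSet()
my_tree = createTree(my_dat, labels[:])
print(my_tree)  # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}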
def classify(input_tree, feat_labels, test_vec):
    first_str = list(input_tree.keys())[0]  # key at the top of the tree
    second_dict = input_tree[first_str]  # the second-level dict
    feat_index = feat_labels.index(first_str)  # index of the feature used for this split
    class_label = None  # stays None if test_vec holds a value the tree has never seen
    for key in second_dict.keys():
        if test_vec[feat_index] == key:
            if isinstance(second_dict[key], dict):  # internal node: recurse
                class_label = classify(second_dict[key], feat_labels, test_vec)
            else:  # leaf node: this is the label
                class_label = second_dict[key]
    return class_label
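Classifying with the tree (the feature-label list must be the original, un-mutated one, which is why createTree was given a copy above):

my_dat, labels = createDataSet()
my_tree = createTree(my_dat, labels[:])
print(classify(my_tree, labels, [1, 0]))  # 'no'
print(classify(my_tree, labels, [1, 1]))  # 'yes'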
def storeTree(input_tree, filename):
    with open(filename, 'wb') as fw:  # the with block closes the file automatically
        pickle.dump(input_tree, fw)
def grabTree(filename):
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
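Persisting the tree to disk and loading it back (the filename here is just an illustration):

storeTree(my_tree, 'classifier_storage.txt')
print(grabTree('classifier_storage.txt'))  # same dict as my_tree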
Plotting the Decision Tree
import matplotlib.pyplot as plt
# pass a decision-tree dict to createPlot(dict) to draw it
decision_node = dict(boxstyle='sawtooth', fc='0.8')  # box style and fill color for decision nodes
leaf_node = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')
def plotNode(node_text, center_pt, parent_pt, node_type):  # draw a node with an arrow annotation
    createPlot.ax1.annotate(node_text, xy=parent_pt, xycoords='axes fraction',
                            xytext=center_pt, textcoords='axes fraction', va='center',
                            ha='center', bbox=node_type, arrowprops=arrow_args)
def getNumLeafs(my_tree):  # count the leaf nodes of the tree
    num_leafs = 0
    first_str = list(my_tree.keys())[0]  # the splitting feature at the root of this subtree
    second_dict = my_tree[first_str]  # the second-level dict
    for key in second_dict.keys():  # iterate over the dict keys
        if isinstance(second_dict[key], dict):  # value is a dict: an internal node, recurse
            num_leafs += getNumLeafs(second_dict[key])
        else:
            num_leafs += 1  # otherwise it is a leaf
    return num_leafs
def getTreeDepth(my_tree):  # count the tree's decision levels; the root split counts as depth 1
    max_depth = 0
    first_str = list(my_tree.keys())[0]  # the splitting feature at the root of this subtree
    second_dict = my_tree[first_str]  # the second-level dict
    for key in second_dict.keys():  # iterate over the dict keys
        if isinstance(second_dict[key], dict):
            this_depth = 1 + getTreeDepth(second_dict[key])
        else:
            this_depth = 1
        if this_depth > max_depth:
            max_depth = this_depth
    return max_depth
def retrieveTree(i):  # canned trees for testing the plotting code
    list_of_trees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                     {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return list_of_trees[i]
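Checking the two counters against the first canned tree:

my_tree = retrieveTree(0)
print(getNumLeafs(my_tree))   # 3
print(getTreeDepth(my_tree))  # 2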
def plotMidText(cntr_pt, parent_pt, txt_string):  # write text on the line between two nodes
    x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
    y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
    createPlot.ax1.text(x_mid, y_mid, txt_string)
def createPlot(in_tree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()  # clear the figure
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    global plotTree_w
    global plotTree_d
    plotTree_w = float(getNumLeafs(in_tree))  # total number of leaves, used to spread nodes along x
    plotTree_d = float(getTreeDepth(in_tree))  # total depth, used to space levels along y
    plotTree.xOff = -0.5 / plotTree_w
    plotTree.yOff = 1.0
    plotTree(in_tree, (0.5, 1.0), '')
    plt.show()
def plotTree(my_tree, parent_pt, node_txt):  # plot the tree recursively
    num_leafs = getNumLeafs(my_tree)  # width of this subtree in leaves
    first_str = list(my_tree.keys())[0]
    cntr_pt = (plotTree.xOff + (1.0 + float(num_leafs)) / 2.0 / plotTree_w, plotTree.yOff)
    plotMidText(cntr_pt, parent_pt, node_txt)
    plotNode(first_str, cntr_pt, parent_pt, decision_node)
    second_dict = my_tree[first_str]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree_d  # move down one level: the total depth splits the y axis evenly, and the tree is drawn top-down
    for key in second_dict.keys():
        if isinstance(second_dict[key], dict):
            plotTree(second_dict[key], cntr_pt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree_w
            plotNode(second_dict[key], (plotTree.xOff, plotTree.yOff), cntr_pt, leaf_node)
            plotMidText((plotTree.xOff, plotTree.yOff), cntr_pt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree_d  # restore y before returning to the parent level
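To render a tree, pass its dict straight to createPlot (plt.show() needs a display backend):

createPlot(retrieveTree(0))  # opens a matplotlib window with the toy tree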
