Python 决策树算法原理及实现
决策树是机器学习中常见的算法,最著名的代表是ID3、C4.5和CART算法,也有其他的如OC1、ID4h和ITI等。它们选择最优划分属性的准则不尽相同:ID3决策树以信息增益为准则来选择划分属性,C4.5决策树以信息增益率来选择最优划分属性,CART则以基尼指数来选择划分属性;其余决策树算法此处不展开介绍。
(一)信息论中的相关概念
熵:熵原本是物理学中的一个定义,后来香农将其引申到了信息论领域,用来表示信息量的大小。信息量越大(分类越不“纯净”),对应的熵值就越大,反之亦然。信息熵的计算公式如下:
$$\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k$$
在实际应用中,会将概率 $p_k$ 的值用经验概率(频率)替换,所以经验信息熵可以表示为:
$$\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} \frac{|C_k|}{|D|} \log_2 \frac{|C_k|}{|D|}$$
条件熵:
$$\mathrm{Ent}(D \mid a) = \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)$$
信息增益:
$$\mathrm{Gain}(D, a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)$$
信息增益率:
$$\mathrm{Gain\_ratio}(D, a) = \frac{\mathrm{Gain}(D, a)}{\mathrm{IV}(a)}, \qquad \mathrm{IV}(a) = -\sum_{v=1}^{V} \frac{|D^v|}{|D|} \log_2 \frac{|D^v|}{|D|}$$
基尼指数:
$$\mathrm{Gini}(D) = 1 - \sum_{k=1}^{|\mathcal{Y}|} p_k^2$$
(二)代码实现
1 from math import log 2 import operator 3 4 def createDataSet(): 5 """创建数据集""" 6 dataSet = [[1, 1, 'yes'], 7 [1, 1, 'yes'], 8 [1, 0, 'no'], 9 [0, 1, 'no'], 10 [0, 1, 'no']] 11 labels = ['no surfacing','flippers'] #标签类别 12 return dataSet, labels
1 def calcShannonEnt(dataSet): 2 """计算信息熵""" 3 numEntries = len(dataSet) 4 labelCounts = {} 5 for featVec in dataSet: #the the number of unique elements and their occurance,遍历一次得到dataset的一行数据 6 currentLabel = featVec[-1] 7 if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 8 labelCounts[currentLabel] += 1 9 shannonEnt = 0.0 10 for key in labelCounts: 11 prob = float(labelCounts[key])/numEntries 12 shannonEnt -= prob * log(prob,2) #log base 2 13 return shannonEnt
1 def splitDataSet(dataSet, axis, value): 2 """返回axis特征下==value的数据样本 3 ,axis为数据集第几个特征 4 ,value为特征下的某个值""" 5 retDataSet = [] 6 for featVec in dataSet: 7 if featVec[axis] == value: 8 reducedFeatVec = featVec[:axis] #chop out axis used for splitting 9 reducedFeatVec.extend(featVec[axis+1:]) 10 retDataSet.append(reducedFeatVec) 11 return retDataSet
1 def chooseBestFeatureToSplit(dataSet): 2 """以信息增益度的划分方法来选择最优属性""" 3 numFeatures = len(dataSet[0]) - 1 #the last column is used for the labels 4 baseEntropy = calcShannonEnt(dataSet) 5 bestInfoGain = 0.0; bestFeature = -1 6 for i in range(numFeatures): #iterate over all the features 7 featList = [example[i] for example in dataSet]#create a list of all the examples of this feature 8 uniqueVals = set(featList) #get a set of unique values 9 newEntropy = 0.0 10 for value in uniqueVals: 11 subDataSet = splitDataSet(dataSet, i, value) 12 prob = len(subDataSet)/float(len(dataSet)) 13 newEntropy += prob * calcShannonEnt(subDataSet) 14 infoGain = baseEntropy - newEntropy #calculate the info gain; ie reduction in entropy 15 if (infoGain > bestInfoGain): #compare this to the best gain so far 16 bestInfoGain = infoGain #if better than current best, set to best 17 bestFeature = i 18 return bestFeature #returns an integer
1 def majorityCnt(classList): 2 """用来筛选数据最后一个特征时的划分,即投票表决法""" 3 classCount={} 4 for vote in classList: 5 if vote not in classCount.keys(): classCount[vote] = 0 6 classCount[vote] += 1 7 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) 8 return sortedClassCount[0][0]
1 def createTree(dataSet,labels): 2 """生成决策树""" 3 classList = [example[-1] for example in dataSet] 4 5 """迭代终止条件""" 6 if classList.count(classList[0]) == len(classList): 7 return classList[0]#stop splitting when all of the classes are equal 8 if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet 9 return majorityCnt(classList) 10 11 bestFeat = chooseBestFeatureToSplit(dataSet) 12 bestFeatLabel = labels[bestFeat] 13 myTree = {bestFeatLabel:{}} 14 del(labels[bestFeat]) 15 featValues = [example[bestFeat] for example in dataSet] 16 uniqueVals = set(featValues) 17 for value in uniqueVals: 18 subLabels = labels[:] #copy all of labels, so trees don't mess up existing labels 19 myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels) 20 return myTree
1 if __name__=='__main__': 2 """代码测试""" 3 dataset ,labels = createDataSet() 4 s = calcShannonEnt(dataset) 5 retdataset = splitDataSet(dataSet, 1, 1) 6 bestFeature = chooseBestFeatureToSplit(dataSet) 7 myTree = createTree(dataSet,labels) 8 print(myTree)
结果:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
二、生成树可视化
import matplotlib.pyplot as plt

# Drawing styles: decision (internal) nodes, leaf nodes, and branch arrows.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict tree (sets the plot width)."""
    numLeafs = 0
    firstStr = list(myTree.keys())[0]  # root feature name
    secondDict = myTree[firstStr]
    for key in secondDict:
        # Fix/idiom: isinstance replaces the fragile
        # type(...).__name__ == 'dict' string comparison.
        if isinstance(secondDict[key], dict):  # internal node -> recurse
            numLeafs += getNumLeafs(secondDict[key])
        else:  # anything that is not a dict is a leaf
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the depth of a nested-dict tree (sets the plot height)."""
    maxDepth = 0
    firstStr = list(myTree.keys())[0]  # root feature name
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated node box with an arrow from its parent."""
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    """Write the branch value halfway along the parent->child arrow."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center",
                        rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively plot a subtree rooted at the first key of myTree.

    Uses function attributes (totalW/totalD/xOff/yOff, set by createPlot)
    as shared layout state across recursive calls.
    """
    numLeafs = getNumLeafs(myTree)  # determines the x-width of this subtree
    firstStr = list(myTree.keys())[0]  # label for this decision node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # step down a level
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))  # recurse into subtree
        else:
            # leaf node: advance the running x offset and draw it
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # step back up


def createPlot(inTree):
    """Render a full decision tree (nested dicts) with matplotlib."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # start half a slot to the left so leaves are centered in their columns
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def retrieveTree(i):
    """Return one of two canned trees for testing the plotting code."""
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers':
            {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]
1 if __name__=='__main__': 2 """代码测试""" 3 dataset ,labels = createDataSet() 4 s = calcShannonEnt(dataset) 5 retdataset = splitDataSet(dataset, 1, 1) 6 bestFeature = chooseBestFeatureToSplit(dataset) 7 myTree = createTree(dataset,labels) 8 createPlot(myTree)
结果:

参考图书:机器学习实战【美】Peter Harrington 著 李锐 李鹏 曲亚东 王斌 译
浙公网安备 33010602011771号