Python 决策树算法原理及实现

  决策树是机器学习常见的算法,最著名的代表是ID3、C4.5和CART算法,也有其他的如:OC1、ID4和ITI等算法。它们选择最优划分属性的方式不尽相同:ID3决策树以信息增益为准则来选择划分属性,C4.5决策树以信息增益率来选择最优划分属性,CART则以基尼指数来选择划分属性。本文主要讨论前三种经典算法。

(1)信息论中的相关概念

熵:熵原本是物理学中的一个定义,后来香农将其引申到了信息论领域,用来表示信息量的大小。信息量越大(分类越不“纯净”),对应的熵值就越大,反之亦然。信息熵的计算公式如下:

$$\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k$$

其中 $p_k$ 表示数据集 $D$ 中第 $k$ 类样本所占的比例,$|\mathcal{Y}|$ 为类别总数。

在实际应用中,会将概率 $p_k$ 的值用经验概率(即频率)替换,所以经验信息熵可以表示为:

$$H(D) = -\sum_{k=1}^{K} \frac{|C_k|}{|D|} \log_2 \frac{|C_k|}{|D|}$$

其中 $|C_k|$ 为属于第 $k$ 类的样本个数,$|D|$ 为样本总数。

条件熵:按特征 $A$ 的取值把 $D$ 划分为子集 $D_1,\dots,D_n$ 后,各子集熵的加权平均:$H(D\mid A) = \sum_{i=1}^{n} \frac{|D_i|}{|D|} H(D_i)$。

信息增益:$g(D,A) = H(D) - H(D\mid A)$,即按特征 $A$ 划分所带来的熵的减少量(ID3 的划分准则)。

信息增益率:$g_R(D,A) = \dfrac{g(D,A)}{H_A(D)}$,其中 $H_A(D) = -\sum_{i=1}^{n} \frac{|D_i|}{|D|}\log_2\frac{|D_i|}{|D|}$,用于抵消信息增益偏好取值多的特征的问题(C4.5 的划分准则)。

基尼指数:$\mathrm{Gini}(D) = 1 - \sum_{k=1}^{K} p_k^2$,表示从 $D$ 中随机抽取两个样本类别不一致的概率,越小越“纯”(CART 的划分准则)。

 (二)代码实现

 1 from math import log
 2 import operator
 3 
 4 def createDataSet():
 5     """创建数据集"""
 6     dataSet = [[1, 1, 'yes'],
 7                [1, 1, 'yes'],
 8                [1, 0, 'no'],
 9                [0, 1, 'no'],
10                [0, 1, 'no']]
11     labels = ['no surfacing','flippers']  #标签类别
12     return dataSet, labels
 1 def calcShannonEnt(dataSet):
 2     """计算信息熵"""
 3     numEntries = len(dataSet)
 4     labelCounts = {}
 5     for featVec in dataSet: #the the number of unique elements and their occurance,遍历一次得到dataset的一行数据
 6         currentLabel = featVec[-1]
 7         if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
 8         labelCounts[currentLabel] += 1
 9     shannonEnt = 0.0
10     for key in labelCounts:
11         prob = float(labelCounts[key])/numEntries
12         shannonEnt -= prob * log(prob,2) #log base 2
13     return shannonEnt
 1 def splitDataSet(dataSet, axis, value):
 2     """返回axis特征下==value的数据样本
 3     ,axis为数据集第几个特征
 4     ,value为特征下的某个值"""
 5     retDataSet = []
 6     for featVec in dataSet:
 7         if featVec[axis] == value:
 8             reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
 9             reducedFeatVec.extend(featVec[axis+1:])
10             retDataSet.append(reducedFeatVec)
11     return retDataSet
 1 def chooseBestFeatureToSplit(dataSet):
 2     """以信息增益度的划分方法来选择最优属性"""
 3     numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels
 4     baseEntropy = calcShannonEnt(dataSet)
 5     bestInfoGain = 0.0; bestFeature = -1
 6     for i in range(numFeatures):        #iterate over all the features
 7         featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
 8         uniqueVals = set(featList)       #get a set of unique values
 9         newEntropy = 0.0
10         for value in uniqueVals:
11             subDataSet = splitDataSet(dataSet, i, value)
12             prob = len(subDataSet)/float(len(dataSet))
13             newEntropy += prob * calcShannonEnt(subDataSet)     
14         infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
15         if (infoGain > bestInfoGain):       #compare this to the best gain so far
16             bestInfoGain = infoGain         #if better than current best, set to best
17             bestFeature = i
18     return bestFeature                      #returns an integer
def majorityCnt(classList):
    """Return the most common class label in classList (majority vote).

    Used as the leaf label when all features are exhausted but the remaining
    samples still carry mixed labels.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # BUG FIX: dict.iteritems() is Python 2 only and raises AttributeError on
    # Python 3 (which this file targets, see print()); use .items() instead.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
 1 def createTree(dataSet,labels):
 2     """生成决策树"""
 3     classList = [example[-1] for example in dataSet]
 4     
 5     """迭代终止条件"""
 6     if classList.count(classList[0]) == len(classList): 
 7         return classList[0]#stop splitting when all of the classes are equal
 8     if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
 9         return majorityCnt(classList)
10     
11     bestFeat = chooseBestFeatureToSplit(dataSet)
12     bestFeatLabel = labels[bestFeat]
13     myTree = {bestFeatLabel:{}}
14     del(labels[bestFeat])
15     featValues = [example[bestFeat] for example in dataSet]
16     uniqueVals = set(featValues)
17     for value in uniqueVals:
18         subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
19         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
20     return myTree                            
if __name__=='__main__':
    # Smoke test: exercise each stage of the ID3 pipeline on the toy data.
    # BUG FIX: the original referenced an undefined name `dataSet` (the
    # variable was bound as `dataset`), raising NameError at runtime.
    dataset ,labels = createDataSet()
    s = calcShannonEnt(dataset)
    retdataset = splitDataSet(dataset, 1, 1)
    bestFeature = chooseBestFeatureToSplit(dataset)
    myTree = createTree(dataset,labels)
    print(myTree)

 结果:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

二、生成树可视化

 1 import matplotlib.pyplot as plt
 2 
# Matplotlib annotation styles shared by the tree-plotting functions below.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # box style for internal (decision) nodes
leafNode = dict(boxstyle="round4", fc="0.8")        # box style for leaf nodes
arrow_args = dict(arrowstyle="<-")                  # arrow drawn from child back to parent
 6 
 7 def getNumLeafs(myTree):
 8     numLeafs = 0
 9 #     firstStr = myTree.keys()
10     firstSides = list(myTree.keys())
11     firstStr = firstSides[0]#找到输入的第一个元素
12     secondDict = myTree[firstStr]
13     for key in secondDict.keys():
14         if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
15             numLeafs += getNumLeafs(secondDict[key])
16         else:   numLeafs +=1
17     return numLeafs
18 
def getTreeDepth(myTree):
    """Return the depth (number of split levels) of a nested-dict decision tree."""
    root = next(iter(myTree))  # the single top-level key is the split feature
    depths = [0]  # keeps max() defined even for an empty branch dict
    for child in myTree[root].values():
        # A dict child adds another split level below this one; a leaf adds one.
        if type(child).__name__ == 'dict':
            depths.append(1 + getTreeDepth(child))
        else:
            depths.append(1)
    return max(depths)
31 
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated node box at centerPt, with an arrow from parentPt.

    Relies on createPlot.ax1 (the axes created by createPlot) and the
    module-level style dicts (decisionNode / leafNode / arrow_args).
    """
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args,
    )
36     
def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the edge between parent and child."""
    mid_x = cntrPt[0] + (parentPt[0] - cntrPt[0]) / 2.0
    mid_y = cntrPt[1] + (parentPt[1] - cntrPt[1]) / 2.0
    createPlot.ax1.text(mid_x, mid_y, txtString, va="center", ha="center", rotation=30)
41 
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw subtree `myTree`, attached to parentPt with edge label nodeTxt.

    Relies on function attributes initialised by createPlot():
      plotTree.totalW / plotTree.totalD  -- total leaf count / depth (coordinate scaling)
      plotTree.xOff / plotTree.yOff     -- running cursor: next leaf x slot, current level y
    """
    numLeafs = getNumLeafs(myTree)  # determines the x-width this subtree occupies
    depth = getTreeDepth(myTree)    # NOTE(review): computed but unused here
    lt = list(myTree.keys())
    firstStr = lt[0]  # the split feature name -- the text label for this node
    # Centre this node horizontally over the leaves it will produce.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  # descend one level before drawing children
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':  # a dict child is an internal node
            plotTree(secondDict[key],cntrPt,str(key))        # recurse into the subtree
        else:   # leaf: advance the x cursor one slot and draw the leaf box
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  # restore y for the caller
61 #if you do get a dictonary you know it's a tree, and the first element will be another dict
62 
def createPlot(inTree):
    """Render the nested-dict decision tree `inTree` in a matplotlib figure."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # The axes are stashed on the function object so plotNode/plotMidText can draw on it.
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    # no ticks
    # Layout state consumed by plotTree() via function attributes.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start the x cursor half a slot to the left so leaves end up centred in their slots.
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
74 
75 #def createPlot():
76 #    fig = plt.figure(1, facecolor='white')
77 #    fig.clf()
78 #    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
79 #    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
80 #    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
81 #    plt.show()
82 
def retrieveTree(i):
    """Return the i-th canned decision tree, for testing the plotting code."""
    canned = (
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    )
    return canned[i]
89 #createPlot(thisTree)
if __name__=='__main__':
    # Smoke test: run the full pipeline on the toy data, then draw the tree.
    toy_data, toy_labels = createDataSet()
    entropy = calcShannonEnt(toy_data)
    subset = splitDataSet(toy_data, 1, 1)
    best = chooseBestFeatureToSplit(toy_data)
    tree = createTree(toy_data, toy_labels)
    createPlot(tree)

结果:

参考图书:机器学习实战【美】Peter Harrington 著  李锐 李鹏 曲亚东 王斌 译

posted on 2019-08-17 23:08  LiErRui  阅读(417)  评论(0)    收藏  举报

导航