Decision Trees in Python: Implementing the C4.5 Algorithm

  • Principle

  The C4.5 algorithm is an improvement on ID3. The biggest difference between the two is the feature-selection criterion: C4.5 uses the information gain ratio, while ID3 uses information gain.

  The change is made because information gain is biased toward features with many distinct values (the more values a feature has, the smaller the conditional entropy, i.e. the entropy of the class variable after splitting on that feature, and hence the larger the information gain). C4.5 therefore places a denominator under the information gain, namely the entropy of the currently selected feature itself. Note: this is the entropy of the feature, not the entropy of the class variable.

  This yields a new feature-selection criterion called the information gain ratio. Why does adding this denominator remove ID3's bias toward features with many values?

  Because the more values a feature takes, the larger its entropy and therefore the larger the denominator, so the information gain ratio shrinks instead of growing the way plain information gain does. This removes, to some extent, the influence of a feature's value range on feature selection.
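
  In symbols (standard notation, which the post itself does not spell out), the criterion C4.5 maximizes is

  $$\mathrm{GainRatio}(D, A) = \frac{\mathrm{Gain}(D, A)}{\mathrm{IV}(A)}, \qquad \mathrm{IV}(A) = -\sum_{v=1}^{V} \frac{|D_v|}{|D|} \log_2 \frac{|D_v|}{|D|}$$

  where $\mathrm{Gain}(D, A) = H(D) - H(D \mid A)$ is the ordinary information gain, $D_v$ is the subset of samples on which feature $A$ takes its $v$-th value, and $\mathrm{IV}(A)$ is the entropy of feature $A$ itself, i.e. the denominator described above.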

  • Implementation

  In terms of implementation, C4.5 only changes two functions relative to ID3: the entropy-calculation function calcShannonEntOfFeature and the best-feature-selection function chooseBestFeatureToSplit.

  calcShannonEntOfFeature is ID3's calcShannonEnt with an extra parameter feat: in ID3 the function only computes the entropy of the class variable, whereas calcShannonEntOfFeature can compute the entropy of a specified feature column or of the class variable (feat = -1 selects the class column).

  After computing the information gain, chooseBestFeatureToSplit additionally computes the entropy IV of the current feature, divides the gain by IV to obtain the information gain ratio, and takes the feature with the largest gain ratio as the best feature.

  When splitting the data, a feature may end up taking only a single value; its entropy is then 0, and the information gain is also 0 (the class distribution is unchanged by the split, since the feature has just one value). 0/0 is meaningless, so that feature can simply be skipped, as the toy check below illustrates.
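
  A minimal, self-contained sanity check of that edge case (the helper entropy and the toy lists are purely illustrative and not part of the script below):

from math import log
from collections import Counter

def entropy(column):
    # Shannon entropy of a list of discrete values
    n = len(column)
    return sum(-c / float(n) * log(c / float(n), 2) for c in Counter(column).values())

feature_a = ['x', 'x', 'x', 'x']        # takes a single value -> IV = 0, gain = 0, skip it
feature_b = ['u', 'u', 'v', 'v']        # two balanced values  -> IV = 1 bit
labels    = ['yes', 'yes', 'no', 'no']

print entropy(feature_a)    # 0.0
print entropy(feature_b)    # 1.0
print entropy(labels)       # 1.0 -- splitting on feature_a leaves this unchanged, hence gain 0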

 

  • Code

#coding=utf-8
import operator
from math import log
import time
import sys

def createDataSet(trainDataFile):
    print trainDataFile
    dataSet = []
    try:
        fin = open(trainDataFile)
        for line in fin:
            line = line.strip()
            cols = line.split('\t')
            #first column is the class label; move it to the end of the row
            row = [cols[1], cols[2], cols[3], cols[4], cols[5], cols[6], cols[7], cols[8], cols[9], cols[10], cols[0]]
            dataSet.append(row)
    except:
        print 'Usage xxx.py trainDataFilePath'
        sys.exit()
    labels = ['cip1', 'cip2', 'cip3', 'cip4', 'sip1', 'sip2', 'sip3', 'sip4', 'sport', 'domain']
    print 'dataSetlen', len(dataSet)
    return dataSet, labels

#calc shannon entropy of label or feature; feat = -1 means the class label column
def calcShannonEntOfFeature(dataSet, feat):
    numEntries = len(dataSet)
    labelCounts = {}
    for feaVec in dataSet:
        currentLabel = feaVec[feat]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1    #last col is label
    baseEntropy = calcShannonEntOfFeature(dataSet, -1)
    bestInfoGainRate = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEntOfFeature(subDataSet, -1)    #conditional entropy
        infoGain = baseEntropy - newEntropy
        iv = calcShannonEntOfFeature(dataSet, i)
        if iv == 0:    #feature takes a single value: infoGain and iv are both 0, skip it
            continue
        infoGainRate = infoGain / iv
        if infoGainRate > bestInfoGainRate:
            bestInfoGainRate = infoGainRate
            bestFeature = i
    return bestFeature

#features are exhausted; return the majority class label
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):    #all data has the same label
        return classList[0]
    if len(dataSet[0]) == 1:    #all features are exhausted
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat == -1:    #features are identical but labels differ, i.e. the label is unrelated to the features; fall back to the first label
        return classList[0]
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def main():
    if len(sys.argv) < 3:
        print 'Usage xxx.py trainSet outputTreeFile'
        sys.exit()
    data, label = createDataSet(sys.argv[1])
    t1 = time.clock()
    myTree = createTree(data, label)
    t2 = time.clock()
    fout = open(sys.argv[2], 'w')
    fout.write(str(myTree))
    fout.close()
    print 'execute for ', t2 - t1

if __name__ == '__main__':
    main()
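
  For reference, a hypothetical invocation (c45.py, train.txt and tree.txt are placeholder names):

python c45.py train.txt tree.txt

  train.txt must be a tab-separated file whose first column is the class label, followed by the ten feature columns named in createDataSet; the learned tree is written to tree.txt as a nested Python dict (the str(myTree) output).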

 

 

This article draws on: 为什么要改进成C4.5算法 ("Why improve to the C4.5 algorithm")

Thanks to the original author.

posted @ 2018-04-07 21:32  寒杰士