Implementing the C4.5 Decision Tree Algorithm in Python

 Why improve ID3 into C4.5

  • Principle

  C4.5 is an improvement on ID3. The biggest difference between the two lies in feature selection: C4.5 selects features by information gain ratio, while ID3 uses information gain.

  The reason for this change is that information gain is biased toward features with many distinct values (the more values a feature has, the smaller the conditional entropy, i.e. the entropy of the class variable after splitting on that feature, and hence the larger the information gain). C4.5 therefore puts a denominator under the information gain: the entropy of the candidate feature itself. Note that this is the feature's entropy, not the entropy of the class variable.

  This forms a new feature-selection criterion, called the information gain ratio. Why does adding such a denominator eliminate ID3's tendency to select features with many values?

  Because the more values a feature takes, the larger its entropy and hence the larger the denominator, so the gain ratio shrinks instead of growing the way raw information gain does; to some extent this removes the influence of a feature's value range on the algorithm.
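
  Written out as formulas (standard definitions, added here for reference; the post states them only in words), for a dataset D that a feature A splits into subsets D_v:

\begin{aligned}
H(D) &= -\sum_k p_k \log_2 p_k && \text{(entropy of the class variable)}\\
H(D \mid A) &= \sum_v \frac{|D_v|}{|D|}\, H(D_v) && \text{(conditional entropy after splitting on } A\text{)}\\
g(D, A) &= H(D) - H(D \mid A) && \text{(information gain, the ID3 criterion)}\\
\mathrm{IV}(A) &= -\sum_v \frac{|D_v|}{|D|} \log_2 \frac{|D_v|}{|D|} && \text{(entropy of feature } A \text{ itself)}\\
g_R(D, A) &= \frac{g(D, A)}{\mathrm{IV}(A)} && \text{(information gain ratio, the C4.5 criterion)}
\end{aligned}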

  • Implementation

  In terms of implementation, C4.5 only modifies the entropy calculation function calcShannonEntOfFeature and the best-feature selection function chooseBestFeatureToSplit.

  calcShannonEntOfFeature adds a parameter feat to ID3's calcShannonEnt. In ID3 that function only computes the entropy of the class variable; calcShannonEntOfFeature can compute the entropy of either a given feature or the class variable.

  After computing the information gain, chooseBestFeatureToSplit also computes the entropy IV of the current feature and divides the two to obtain the information gain ratio; the feature with the largest gain ratio is chosen as the best feature.

  When splitting the data it can happen that a feature takes the same value on every row; then the feature's entropy is 0 and the information gain is also 0 (the class variable is unchanged before and after the split, since the feature has only one value). 0/0 is undefined, so such a feature can simply be skipped.
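
  Before the full listing, here is a minimal, self-contained sketch of the gain-ratio computation, including the iv == 0 case described above (the toy dataset, the entropy helper, and all names in it are illustrative, not part of the original script):

from math import log

def entropy(column):
    #Shannon entropy of a list of discrete values
    counts = {}
    for v in column:
        counts[v] = counts.get(v, 0) + 1
    total = float(len(column))
    return -sum(c / total * log(c / total, 2) for c in counts.values())

#toy dataset: (feature_a, feature_b, label); feature_b takes a single value
rows = [('x', 'same', 'yes'), ('x', 'same', 'no'),
        ('y', 'same', 'yes'), ('z', 'same', 'yes')]

base = entropy([r[-1] for r in rows])     #entropy of the class column
for i, name in enumerate(['feature_a', 'feature_b']):
    iv = entropy([r[i] for r in rows])    #entropy of the feature itself
    if iv == 0:    #single-valued feature: the gain is 0 as well, skip it
        print(name, 'skipped (iv == 0)')
        continue
    cond = 0.0
    for v in set(r[i] for r in rows):
        subset = [r[-1] for r in rows if r[i] == v]
        cond += len(subset) / float(len(rows)) * entropy(subset)
    print(name, 'gain ratio =', (base - cond) / iv)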

#coding=utf-8
import sys
import time
from math import log

def createDataSet(trainDataFile):
    print(trainDataFile)
    dataSet = []
    try:
        with open(trainDataFile) as fin:
            for line in fin:
                line = line.strip()
                cols = line.split('\t')
                #first column is the class label; reorder so the label comes last
                row = [cols[1], cols[2], cols[3], cols[4], cols[5], cols[6], cols[7], cols[8], cols[9], cols[10], cols[0]]
                dataSet.append(row)
                #print(row)
    except (IOError, IndexError):
        print('Usage: xxx.py trainDataFilePath')
        sys.exit()
    labels = ['cip1', 'cip2', 'cip3', 'cip4', 'sip1', 'sip2', 'sip3', 'sip4', 'sport', 'domain']
    print('dataSetlen', len(dataSet))
    return dataSet, labels

#calc shannon entropy of label or feature
def calcShannonEntOfFeature(dataSet, feat):
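    #feat == -1 selects the last column (the class label); feat == i selects feature i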
    numEntries = len(dataSet)
    labelCounts = {}
    for feaVec in dataSet:
        currentLabel = feaVec[feat]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

def splitDataSet(dataSet, axis, value):
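    #keep the rows where feature `axis` equals `value`, dropping that feature column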
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
    
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1    #last col is label
    baseEntropy = calcShannonEntOfFeature(dataSet, -1)
    bestInfoGainRate = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEntOfFeature(subDataSet, -1)    #conditional entropy of the class given feature i
        infoGain = baseEntropy - newEntropy
        iv = calcShannonEntOfFeature(dataSet, i)
        if iv == 0:    #the feature takes only one value, so infoGain and iv are both 0; skip it
            continue
        infoGainRate = infoGain / iv
        if infoGainRate > bestInfoGainRate:
            bestInfoGainRate = infoGainRate
            bestFeature = i
    return bestFeature
            
#features are exhausted; return the majority label
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    return max(classCount, key=classCount.get)    #most frequent label, not the largest key
    
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):    #all rows carry the same label
        return classList[0]
    if len(dataSet[0]) == 1:    #all features are exhausted
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat == -1:    #every feature takes a single value while the labels differ, i.e. the label is unrelated to the features; fall back to the first label
        return classList[0]
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree
    
def main():
    if len(sys.argv) < 3:
        print('Usage: xxx.py trainSet outputTreeFile')
        sys.exit()
    data, label = createDataSet(sys.argv[1])
    t1 = time.perf_counter()
    myTree = createTree(data, label)
    t2 = time.perf_counter()
    with open(sys.argv[2], 'w') as fout:
        fout.write(str(myTree))
    print('execute for', t2 - t1)

if __name__=='__main__':
    main()
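
For reference, the script expects a tab-separated training file with eleven columns per line: the first column is the class label and the remaining ten are the features named in createDataSet (cip1 ... domain). A hypothetical invocation (file names are placeholders, not from the original post):

python c45.py train.txt tree.txt

The tree is written out as a nested dict of the form {featureName: {featureValue: subtree_or_leafLabel, ...}}.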

 

posted on 2017-04-21 20:01 by WOTGL