决策树

1. 简介

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。

缺点:可能会产生过度匹配问题

适用:数值型和标称型

一般流程:

  1. 收集数据:任何方式
  2. 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化
  3. 分析数据:任何方法
  4. 训练算法:构造树的数据结构
  5. 测试算法:使用经验树计算错误率
  6. 使用算法:适用于任何监督学习算法

1. 信息增益

熵:信息的期望值

n:分类的数目

p(xi):选择该分类的概率

python中实现

建立trees.py文件,创建calcShannonEnt()函数,计算香农熵

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCount = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCount.keys():
            labelCount[currentLabel] = 0
        labelCount[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCount:
        prob = float(labelCount[key]) / numEntries
        shannonEnt -= prob * log(prob,2) #同公式 prob为p(xi) 使用for循环累计计算
    return shannonEnt
  1. 计算数据集中实例总数 ---可以用pandas的count_values直接计算
  2. 计算分类的概率
  3. 计算香农熵

建立模拟数据集:

1 def createDataSet():
2     dataSet = [[1,1,'yes'],
3                [1,1,'yes'],
4                [1,0,'no'],
5                 [0,1,'no'],
6                 [0,1,'no'],]
7     labels = ['no surfacing','flippers']
8     return dataSet,labels

在python命令提示符下输入命令:

import trees

myDat,labels = trees.createDataSet()

myDat
Out[351]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

trees.calcShannonEnt(myDat)
Out[352]: 0.9709505944546686

myDat[0][-1] = 'maybe'

myDat
Out[354]: [[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

trees.calcShannonEnt(myDat)
Out[355]: 1.3709505944546687

得到熵之后,就可以按照获取最大信息增益的方法划分数据集

2. 划分数据集

对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集市最好的划分方式

 在trees.py文件中添加splitDataSet函数

1 def splitDataSet(dataSet,axis,value):
2     retDataSet = []
3     for featVec in dataSet:
4         if featVec[axis] == value:
5             reducedFeatVec = featVec[:axis]
6             reducedFeatVec.extend(featVec[axis+1:])
7             retDataSet.append(reducedFeatVec)
8     return retDataSet

python技巧:

append与extend方法:

#append方法
a = [1,2,3]

b = [4,5,6]

a.append(b)

a
Out[359]: [1, 2, 3, [4, 5, 6]]

#extend方法

a.extend(b)

a
Out[365]: [1, 2, 3, 4, 5, 6]

在python命令提示符内输入下述命令:

 1 reload(trees)
 2 Out[366]: <module 'trees' from 'trees.pyc'>
 3 
 4 myDat,label = trees.createDataSet()
 5 
 6 myDat
 7 Out[368]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
 8 
 9 trees.splitDataSet(myDat,0,1)
10 Out[369]: [[1, 'yes'], [1, 'yes'], [0, 'no']]
11 
12 trees.splitDataSet(myDat,0,0)
13 Out[370]: [[1, 'no'], [1, 'no']]

遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的特征划分方式

 1 def chooseBestFeatureToSplit(dataSet):
 2     numFeatures = len(dataSet[0]) - 1
 3     baseEntropy = calcShannonEnt(dataSet)
 4     bestInfoGain = 0.0
 5     bestFeature = -1
 6     for i in range(numFeatures):
 7         featList = [example[i] for example in dataSet]
 8         uniqueVals = set(featList)
 9         newEntropy = 0.0
10         for value in uniqueVals:
11             subDataSet = splitDataSet(dataSet,i,value)
12             prob = len(subDataSet)/float(len(dataSet))
13             newEntropy += prob * calcShannonEnt(subDataSet)
14         infoGain = baseEntropy - newEntropy
15         if (infoGain > bestInfoGain):
16             bestInfoGain = infoGain
17             bestFeature = i
18     return bestFeature
  1.  计算dataset的整体香农熵 baseEntropy
  2. 遍历所有特征的香农熵与baseEntropy做比较,选出最好的特征

在python命令提示符下运行:

1 myDat,labels = trees.createDataSet()
2 
3 trees.chooseBestFeatureToSplit(myDat)
4 Out[541]: 0
5 
6 myDat
7 Out[542]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

结果显示第0个特征是最好的用于划分数据集的特征

3. 递归构建决策树

 

 在trees.py文件中店家majorCnt函数、createTree()函数

1 def majorityCnt(classList):
2     classCount = {}
3     for vote in classList:
4         if vote not in classCount.keys(): 
5             classCount[vote] = 0
6         classCount[vote] += 1
7     sortedClassCount = sorted(classCount.items(),
8                               key=operator.itemgetter(1),reverse = True)
9     return sortedClassCount[0][0]

 

 1 def createTree(dataSet,labels):
 2     classList = [example[-1] for example in dataSet]
 3     if classList.count(classList[0] == len(classList)):
 4         return classList[0]
 5     if len(dataSet[0]) == 1:
 6         return majorityCnt(classList)
 7     bestFeat = chooseBestFeatureToSplit(dataSet)
 8     bestFeatLabel = labels[bestFeat]
 9     myTree = {bestFeatLabel:{}}
10     del(labels[bestFeat])
11     featValues = [example[bestFeat] for example in dataSet]
12     uniqueVals = set(featValues)
13     for value in uniqueVals:
14         subLabels = labels[:]
15         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
16     return myTree

 

posted @ 2017-02-21 16:50  rockchen  阅读(759)  评论(0)    收藏  举报