机器学习002——决策树

决策树

1 决策树概念

A decision tree is a decision support tool that uses a tree-like graph or model of decisions and their possible consequences, including chance event outcomes, resource costs, and utility. It is one way to display an algorithm.
Decision trees are commonly used in operations research, specifically in decision analysis, to help identify a strategy most likely to reach a goal, but are also a popular tool in machine learning.

决策树是使用树状图或决策模型的一个决策支持工具,其可能的产生的效果,包括机会事件结果,资源成本和效用。 它是显示算法的一种方式。


1.1 邮件处理系统的效率可以由决策树来表示:

这里写图片描述
上图显示的是:根据一些特征来分类,看是不是迫切需要处理的还是需要处理的。

猜题游戏:参与游戏的一方可以确定一个答案,另一个人提问20个问题来确定答案,问题的答案只能用对错来回答,也可以使用决策树。

平常许多情况都需要用到决策树,决策树是最经常使用的数据挖掘算法。决策时不一定非要二叉树,多叉树也是可以的,每个结点显示的信息也不一定只能是一个,可以多条信息显示在一起。

决策树将一些事情很直观地显示出来,使其处理问题更加地简单。


1.2 决策树特点

  • Pros:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不想管特征数据
  • Cons:可能会产生过度匹配问题
  • Works with:数值型和标称型

2 决策树构造

决策树分支构造步骤如下:

  1. 检测数据集中得每个子项是否属于同一分类,“是”则返回节点并结束,“否”则2
  2. 寻找划分数据集的最好特征,划分数据集,创建分支节点
  3. 对每个划分的子集,进行1操作

创建分支的伪代码函数createBranch():

检测数据集中得每个子项是否属于同一类
if so return 类标签
ELSE
     寻找划分数据集的最好特征
     划分数据集
     创建分支节点
	       for 每个划分的子集
	       调用函数createBranch()并增加返回结果到分支节点中
	 return 分支节点

决策树算法可以采用二分法、ID3算法划分数据集

3 信息熵、信息增益以及基尼指数

3.1 信息熵

  • 集合信息的度量方式称为香农熵或者熵
  • 熵是对信息不确定的度量
  • 熵定义为信息的期望值
  • 一个系统越有序,则信息熵越低,相反一个系统越是混乱,则它的信息熵越高。

如果待分类的事物可能划分在多个分类之中,则符号xi的信息定义为:这里写图片描述
计算熵时我们需要计算所有类别所有可能值包含的信息期望值:
这里写图片描述

n为分类的数目


3.2 信息增益

  • 划分数据集的大原则是:将无序的数据变得更加地有序。
  • 在划分数据集之前之后信息发生的变化称为信息增益

以天气预报的例子来详细说明信息增益的含义
这里写图片描述
学习目标是play或者not play
一共有14个样例,9个正例和5个负例,当前信息的熵计算如下:

Entropy(S) = - 9/14 * log2(9/14) - 5/14 * log2(5/14)

在决策树分类问题中,信息增益就是决策树在进行属性划分前后信息的差值。假设利用属性Outlook来分类,那么如下图
这里写图片描述

划分后,数据被分为三个部分,各个分支的信息熵计算如下:
这里写图片描述

划分后的信息熵为:
这里写图片描述
这里写图片描述

信息增益的计算公式:
这里写图片描述
本例的信息增益:
这里写图片描述

在决策树的每一个非叶子节点划分之前,先计算每一个属性所带来的信息增益,选择最大信息增益的属性来划分,因为信息增益越大,区分样本能力就越强,越具有代表性,这是一种自顶向下的贪心策略。这也是ID3算法的核心思想。

3.3 基尼指数

  • CART算法主要使用Gini指数。
  • 在CART算法中,基尼不纯度表示一个随机选中的样本在子集中被分错的可能性。
  • 基尼不纯度为这个样本被选中的概率乘以它被分错的概率
  • 假设y的可能取值为{1,2,......,m},令fi是样本被赋予i的概率,则基尼指数可以通过如下计算:

这里写图片描述

这里写图片描述

这里写图片描述


4 C4.5算法&&ID3算法&&CART算法

4.1 C4.5算法

ID3算法的思想如上例所示,C4.5算法是机器学习中另一个分类决策树算法,它是基于ID3算法进行改进后的一种重要算法,改进有如下几个要点:

  • 用信息增益率来选择属性。ID3选择属性用的是子树的信息增益,这里可以用很多方法来定义信息,ID3使用的是熵(entropy, 熵是一种不纯度度量准则),也就是熵的变化值,而C4.5用的是信息增益率。
  • 在决策树构造过程中进行剪枝,因为某些具有很少元素的结点可能会使构造的决策树过适应(Overfitting),如果不考虑这些结点可能会更好。
  • 对非离散数据也能处理。
  • 能够对不完整数据进行处理。

上述例子使用C4.5算法:
计算分裂信息度量H(V):

H(Outlook) = - 5/14 * log2(5/14) - 4/14 * log2(4/14) - 5/14 * log2(5/14)

信息增益率:

IGR(Outlook) = Entropy(S|T) / H(Outlook)

4.2 CART算法

  • CART又称分类回归树,CART可以用作分类,也可用作回归。
  • 用作分类的时候使用Gini指数最小化原则,选择特征,递归地构造二叉树。
  • 用作回归树时用平方误差最小化作为选择特征的准则

天气预报的CART算法的具体计算过程如下:

Outlook sunny overcast rain
YES 2 4 3
NO 3 0 2

Gini(Sunny) = 1 - (2/5)^2 - (3/5)^2
Gini(Overcast) =1 - (4/4)^2 - (0/4)^2
Gini(rain) = 1 - (3/5)^2 - (2/5)^2
Gini= 5/14Gini(Sunny) + 4/14Gini(Overcast)+5/14*Gini(rain)


对离散值如{x,y,z},则在该属性上的划分有三种情况
这里写图片描述
空集和全集的划分除外

天气预报的例子的计算情况如下:

Outlook sunny or overcast rain
YES 6 3
NO 3 2

然后再进行计算


5 海洋生物数据处理

海洋中有5个动物,特征有:不浮出水面是否可以生产,是否有脚蹼,是否属于鱼类,我们可以依赖这些特征进行划分数据,但是要依据哪个特征来划分数据需要计算信息增益。
这里写图片描述

划分数据集的数据路径
这里写图片描述

trees.py的代码:

# _*_ coding: UTF-8 -*-

from math import log
import operator


'''
    输入数据集,这是鱼鉴定数据集
'''
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']

    return dataSet, labels

'''
    calcShannonEnt计算香农熵
'''
def calcShannonEnt(dataSet):
    # 计算数据集中实例的总数
    numEntries = len(dataSet)
    labelCounts = {}
    """
            为所有可能分类创建字典,它的键值是最后一列的数值,
            如果当前键值不存在,则扩展字典并将当前键值加入字典
            每个键值都记录了当前类别出现的次数
    """
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    """
            使用所有类标签的发生频率计算类别出现的概率,用这个概率计算香农熵
            统计所有类标签发生的次数
    """
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt


"""
按照给定特征划分数据集
"""
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

"""
    选择最好的数据集划分方式
"""
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column is used for the labels
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0;
    bestFeature = -1
    for i in range(numFeatures):  # iterate over all the features
        featList = [example[i] for example in dataSet]  # create a list of all the examples of this feature
        uniqueVals = set(featList)  # get a set of unique values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # calculate the info gain; ie reduction in entropy
        if (infoGain > bestInfoGain):  # compare this to the best gain so far
            bestInfoGain = infoGain  # if better than current best, set to best
            bestFeature = i
    return bestFeature  # returns an integer


def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

"""
    创建树的函数代码
"""
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del (labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat
    return classLabel


def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()


def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)

treePlotter.py的代码

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[
                    key]).__name__ == 'dict':  # test to see if the nodes are dictonaires, if not they are leaf nodes
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[
                    key]).__name__ == 'dict':  # test to see if the nodes are dictonaires, if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):  # if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]  # the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[
                    key]).__name__ == 'dict':  # test to see if the nodes are dictonaires, if not they are leaf nodes
            plotTree(secondDict[key], cntrPt, str(key))  # recursion
        else:  # it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


# if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


# def createPlot():
#    fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()

def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]

    # createPlot(thisTree)

命令行执行语句:

目录: /Users/shasha/PycharmProjects/shang/trees.py

Last login: Thu Jun 22 12:00:46 on ttys000
bogon:~ shasha$ cd PycharmProjects/shang
bogon:shang shasha$ python
Python 2.7.10 (default, Oct 23 2015, 19:19:21) 
[GCC 4.2.1 Compatible Apple LLVM 7.0.0 (clang-700.0.59.5)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import trees
>>> myDat,labels=trees.createDataSet()
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.calcShannonEnt(myDat)
0.9709505944546686
>>> myDat[0][-1]='maybe'
>>> myDat
[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.calcShannonEnt(myDat)
1.3709505944546687
>>> a=[1,2,3]
>>> b=[4,5,6]
>>> a.append(b)
>>> a
[1, 2, 3, [4, 5, 6]]
>>> a=[1,2,3]
>>> a.extend(b)
>>> a
[1, 2, 3, 4, 5, 6]
>>> reload(trees)
<module 'trees' from 'trees.pyc'>
>>> myDat,labels=trees.createDataSet()
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.splitDataSet(myDat,0,1)
[[1, 'yes'], [1, 'yes'], [0, 'no']]
>>> trees.splitDataSet(myDat,0,0)
[[1, 'no'], [1, 'no']]
>>> reload(trees)
<module 'trees' from 'trees.pyc'>
>>> myDat,labels=trees.createDataSet()
>>> trees.chooseBestFeatureToSplit(myDat)
0
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> reload(trees)
<module 'trees' from 'trees.pyc'>
>>> myDat,labels=trees.createDataSet()
>>> myTree = trees.createTree(myDat,labels)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> reload(treePlotter)
<module 'treePlotter' from 'treePlotter.py'>
>>> treePlotter.retrieveTree(0)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> myTree=treePlotter.retrieveTree(0)
>>> treePlotter.getNumLeafs(myTree)
3
>>> treePlotter.getTreeDepth(myTree)
2
>>> reload(treePlotter)
<module 'treePlotter' from 'treePlotter.pyc'>
>>> myTree=treePlotter.retrieveTree(0)
>>> treePlotter.createPlot(myTree)
/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/matplotlib/patches.py:3046: RuntimeWarning: invalid value encountered in double_scalars
  ddx = pad_projected * dx / cp_distance
/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/matplotlib/patches.py:3047: RuntimeWarning: invalid value encountered in double_scalars
  ddy = pad_projected * dy / cp_distance
/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/matplotlib/patches.py:3050: RuntimeWarning: invalid value encountered in double_scalars
  dx = dx / cp_distance * head_dist
/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/matplotlib/patches.py:3051: RuntimeWarning: invalid value encountered in double_scalars
  dy = dy / cp_distance * head_dist
>>> myTree['no surfacing'][3]='maybe'
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
>>> treePlotter.createPlot(myTree)
>>> myDat,labels=trees.createDataSet()
>>> labels
['no surfacing', 'flippers']
>>> myTree=treePlotter.retrieveTree(0)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> trees.classify(myTree,labels,[1,0])
'no'
>>> trees.classify(myTree,labels,[1,1])
'yes'
>>> trees.storeTree(myTree,'classifierStorage.txt')
>>> trees.grabTree('classifierStorage.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> 

这里写图片描述

这里写图片描述

代码以及使用的数据链接:
决策树

posted @ 2017-06-22 15:20  Gssol  阅读(1058)  评论(3编辑  收藏  举报