Linux下用matplotlib画决策树

1、trees = {'no surfacing': { 0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

2、从我的文件trees.txt里读的决策树,也是一个递归字典表示

#coding=utf-8
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt  # 载入 pyplot API
import os, sys
import time

decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 注(a)
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")  # 箭头样式

def plotNode(Nodename, centerPt, parentPt, nodeType):  #  centerPt节点中心坐标  parentPt 起点坐标
    creatPlot.ax1.annotate(Nodename, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args) # 注(b)

def getNumleafs(mytree): # 获得叶节点数目,输入为我们前面得到的树(字典)
    Numleafs = 0 # 初始化
    firstStr = list(mytree.keys())[0] # 注(a) 获得第一个key值(根节点) 'no surfacing'
    secondDict = mytree[firstStr]  # 获得value值 {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
    for key in secondDict.keys(): #  键值:0 和 1
        if type(secondDict[key]).__name__=='dict': # 判断如果里面的一个value是否还是dict
            Numleafs += getNumleafs(secondDict[key]) # 递归调用
        else:
            Numleafs += 1
    return Numleafs

def getTreeDepth(mytree):
    maxDepth = 0
    firstStr = list(mytree.keys())[0]
    secondDict = mytree[firstStr]
    for key in secondDict.keys(): #  键值:0 和 1
        thisDepth = 0
        if type(secondDict[key]).__name__=='dict': # 判断如果里面的一个value是否还是dict
            thisDepth = 1 + getTreeDepth(secondDict[key]) # 递归调用
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotMidText(cntrPt, parentPt, txtString):   #  在两个节点之间的线上写上字
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    creatPlot.ax1.text(xMid, yMid, txtString)  # text() 的使用

def plotTree(myTree, parentPt, nodeName):  # 画树
    numleafs = getNumleafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]
    cntrPt = (plotTree.xOff+(0.5/plotTree.totalw+float(numleafs)/2.0/plotTree.totalw), plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeName) 
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 减少y的值,将树的总深度平分,每次减少移动一点(向下,因为树是自顶向下画的)
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
        plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalw
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def creatPlot(inTree):  # 使用的主函数
    fig = plt.figure(figsize=(200,200), facecolor='white')
    fig.clf()  # 清空绘图区
    axprops = dict(xticks=[], yticks=[]) # 创建字典 存储=====有疑问???=====
    creatPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #  ===参数的意义?===
    plotTree.totalw = float(getNumleafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))  # 创建两个全局变量存储树的宽度和深度
    print 'tree width =', plotTree.totalw 
    print 'tree height =', plotTree.totalD 
    plotTree.xOff = -0.5/plotTree.totalw # 追踪已经绘制的节点位置 初始值为 将总宽度平分 在取第一个的一半 
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5,1.0), '')  # 调用函数,并指出根节点源坐标 
    plt.savefig('images/tree2.png', format='png',  dpi=100)

trees = []
try:
        fin = open(sys.argv[1])
        line = fin.readline()
        trees = eval(line)
        #print trees
except:
        print 'load tree error'
        raise
if(len(sys.argv) == 1):
    trees = {'no surfacing': { 0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
t1 = time.clock()
creatPlot(trees)
t2 = time.clock()
print t2 - t1

 

ps:参考博客[http://blog.csdn.net/ifruoxi/article/details/53150129]

posted on 2017-04-18 22:00  WOTGL  阅读(456)  评论(0)    收藏  举报

导航