ML in Action: Decision Trees

Project Address:

https://github.com/TheOneAC/ML.git

Dataset: `ML/ML_ation/tree`

Decision trees

  • Pros: low computational complexity, insensitive to missing intermediate values, can handle data with irrelevant features
  • Cons: prone to overfitting (over-classification)
  • Applicable data types: numeric and nominal

Decision tree pseudocode: createBranch

Check whether every item in the dataset belongs to the same class
    if so, return the class tag
    else
        find the best feature to split the dataset on
        split the dataset
        create a branch node
        for each subset, call createBranch recursively
        return the branch node

Recursion stops when all features have been used up, or when every record in the dataset belongs to the same class.

Shannon entropy
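For a dataset D whose records fall into K classes with proportions p_1, …, p_K, the Shannon entropy is H(D) = -Σ p_k · log2(p_k). The more mixed the class labels, the higher the entropy. calcShannonEnt below computes this from the class label stored in the last column of each record.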

from math import log

def calcShannonEnt(dataSet):
	# Count how often each class label (last column) appears
	numEntries = len(dataSet)
	labelCounts = {}
	for featVec in dataSet:
		currentLabel = featVec[-1]
		if currentLabel not in labelCounts:
			labelCounts[currentLabel] = 0
		labelCounts[currentLabel] += 1
	# H = -sum(p * log2(p)) over all class labels
	shannonEnt = 0.0
	for key in labelCounts:
		prob = float(labelCounts[key]) / numEntries
		shannonEnt -= prob * log(prob, 2)
	return shannonEnt
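
As a quick sanity check, a minimal interactive session — assuming trees.py also contains the book's createDataSet helper (not shown in this post), which returns the five-record toy dataset used later:

>>> import trees
>>> myDat, labels = trees.createDataSet()
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> round(trees.calcShannonEnt(myDat), 3)   # 2 'yes' vs 3 'no': -0.4*log2(0.4) - 0.6*log2(0.6)
0.971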

Splitting the dataset and choosing the best split feature (maximum information gain, i.e. minimum weighted entropy)

def splitDataSet(dataSet, axis, value):
	# Return the records whose feature `axis` equals `value`, with that feature column removed
	retDataSet = []
	for featVec in dataSet:
		if featVec[axis] == value:
			reducedFeatVec = featVec[:axis]
			reducedFeatVec.extend(featVec[axis + 1:])
			retDataSet.append(reducedFeatVec)
	return retDataSet
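
Continuing the session above with the same (assumed) toy dataset, splitting on feature 0 behaves like this:

>>> trees.splitDataSet(myDat, 0, 1)   # records with feature 0 == 1, that column dropped
[[1, 'yes'], [1, 'yes'], [0, 'no']]
>>> trees.splitDataSet(myDat, 0, 0)
[[1, 'no'], [1, 'no']]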


def chooseBestFeatureToSplit(dataSet):
	numFeatures = len(dataSet[0]) - 1        # last column is the class label
	baseEntropy = calcShannonEnt(dataSet)
	bestInfoGain = 0.0
	bestFeature = -1
	for i in range(numFeatures):
		featList = [example[i] for example in dataSet]
		uniqueVals = set(featList)
		# Weighted entropy of the subsets produced by splitting on feature i
		newEntropy = 0.0
		for value in uniqueVals:
			subDataSet = splitDataSet(dataSet, i, value)
			prob = len(subDataSet) / float(len(dataSet))
			newEntropy += prob * calcShannonEnt(subDataSet)
		infoGain = baseEntropy - newEntropy
		if infoGain > bestInfoGain:
			bestInfoGain = infoGain
			bestFeature = i
	return bestFeature
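
On the toy dataset assumed above, feature 0 ('no surfacing') yields the larger information gain (roughly 0.420 versus 0.171 for 'flippers'), so it is chosen first:

>>> trees.chooseBestFeatureToSplit(myDat)
0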

When all features have been used and the class label still cannot be determined, the leaf's class is decided by majority vote:


import operator

def majorityCnt(classList):
	# Return the class label that occurs most often in classList
	classCount = {}
	for vote in classList:
		if vote not in classCount: classCount[vote] = 0
		classCount[vote] += 1
	sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
	return sortedClassCount[0][0]
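
A quick check of the majority vote, for illustration:

>>> trees.majorityCnt(['yes', 'no', 'no'])
'no'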

Building the tree


def createTree(dataSet, labels):
	classList = [example[-1] for example in dataSet]
	# Stop if all records share the same class
	if classList.count(classList[0]) == len(classList):
		return classList[0]
	# Stop if no features are left (only the class label column remains): majority vote
	if len(dataSet[0]) == 1:
		return majorityCnt(classList)

	bestFeat = chooseBestFeatureToSplit(dataSet)
	bestFeatureLabel = labels[bestFeat]
	myTree = {bestFeatureLabel:{}}
	del(labels[bestFeat])                 # note: mutates the caller's labels list
	featValues = [example[bestFeat] for example in dataSet]
	uniqueVals = set(featValues)
	for value in uniqueVals:
		subLabels = labels[:]             # copy so recursive calls do not interfere
		myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
	return myTree
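
Run on the assumed toy dataset, this produces the same nested dict that retrieveTree(0) returns below (dict key order may differ). Since createTree deletes entries from labels in place, pass a copy if you still need the full label list afterwards:

>>> myDat, labels = trees.createDataSet()
>>> trees.createTree(myDat, labels[:])
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}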

Classification

def classify(inputTree, featLabels, testVec):
	firstStr = list(inputTree.keys())[0]      # feature name at the root of this (sub)tree
	secondDict = inputTree[firstStr]

	featIndex = featLabels.index(firstStr)    # position of that feature in the test vector
	for key in secondDict.keys():
		if testVec[featIndex] == key:
			if type(secondDict[key]).__name__ == 'dict':
				classLabel = classify(secondDict[key], featLabels, testVec)
			else:
				classLabel = secondDict[key]
	return classLabel
>>> import trees
>>> myDat,labels=trees.createDataSet()
>>> labels
['no surfacing', 'flippers']
>>> import treePlotter
>>> myTree = treePlotter.retrieveTree(0)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> trees.classify(myTree,labels,[1,0])
'no'
>>> trees.classify(myTree,labels,[1,1])
'yes'

Storing and reloading the tree

def storeTree(inputTree, filename):
	# Serialize the tree to disk so it does not have to be rebuilt every time
	import pickle
	fw = open(filename, 'wb')
	pickle.dump(inputTree, fw)
	fw.close()

def grabTree(filename):
	import pickle
	fr = open(filename, 'rb')
	return pickle.load(fr)

Test

#!/usr/bin/python
import trees

myDat,labels = trees.createDataSet()

myTree = trees.createTree(myDat, labels)

trees.storeTree(myTree,'classifierStorage.txt')

print(trees.grabTree('classifierStorage.txt'))

Plotting the tree structure with Matplotlib

#!/usr/bin/python

import matplotlib.pyplot as plt 

# Node box styles and arrow style used by Matplotlib annotate
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
	# Draw a text box at centerPt with an arrow coming from parentPt
	createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = "axes fraction",
	xytext = centerPt, textcoords = "axes fraction",
	va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)

Plotting a single node

def createPlot():
	fig = plt.figure(1, facecolor = "white")
	fig.clf()
	createPlot.ax1 = plt.subplot(111, frameon = False)
	plotNode("a decision node",(0.5, 0.1), (0.1, 0.5), decisionNode)
	plotNode("a leaf node",(0.8, 0.1), (0.3, 0.8), leafNode)
	plt.show()

Run it from the Python prompt like this:

import treePlotter
treePlotter.createPlot()

  • The result is a plot with one sample decision node and one leaf node, each connected to its parent point by an arrow (figure omitted).
def getNumLeafs(myTree):
	# Count leaf nodes so the plot can be spaced horizontally
	numLeafs = 0
	firstStr = list(myTree.keys())[0]
	secondDict = myTree[firstStr]
	for key in secondDict.keys():
		if type(secondDict[key]).__name__ == 'dict':
			numLeafs += getNumLeafs(secondDict[key])
		else: numLeafs += 1
	return numLeafs

def getTreeDepth(myTree):
	# Depth = number of decision nodes on the longest path from the root to a leaf
	maxDepth = 0
	firstStr = list(myTree.keys())[0]
	secondDict = myTree[firstStr]
	for key in secondDict.keys():
		if type(secondDict[key]).__name__ == 'dict':
			thisDepth = 1 + getTreeDepth(secondDict[key])
		else:
			thisDepth = 1
		if thisDepth > maxDepth: maxDepth = thisDepth
	return maxDepth


def retrieveTree(i):
	# Pre-built trees for testing the plotting functions without re-running createTree
	listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': \
		{0: 'no', 1: 'yes'}}}},
		{'no surfacing': {0: 'no', 1: {'flippers': \
		{0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
		]
	return listOfTrees[i]
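
For the first stored tree, there are three leaves and two levels of decision nodes:

>>> import treePlotter
>>> myTree = treePlotter.retrieveTree(0)
>>> treePlotter.getNumLeafs(myTree)
3
>>> treePlotter.getTreeDepth(myTree)
2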


def plotMidText(cntrPt, parentPt, txtString):
	# Label the branch: put the feature value halfway between parent and child
	xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
	yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
	createPlot.ax1.text(xMid, yMid, txtString)
	


def plotTree(myTree, parentPt, nodeTxt):
	# Width (leaf count) and depth determine how much space this subtree occupies
	numLeafs = getNumLeafs(myTree)
	depth = getTreeDepth(myTree)

	firstStr = list(myTree.keys())[0]
	cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,\
			plotTree.yOff)
	plotMidText(cntrPt, parentPt, nodeTxt)
	plotNode(firstStr, cntrPt, parentPt, decisionNode)

	secondDict = myTree[firstStr]
	# Move one level down before drawing the children
	plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD

	for key in secondDict.keys():
		if type(secondDict[key]).__name__ == 'dict':
			plotTree(secondDict[key], cntrPt, str(key))   # recurse into subtree
		else:
			plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
			plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
			cntrPt, leafNode)
			plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
	# Move back up after this subtree is drawn
	plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
	
def createPlot(inTree):
	fig = plt.figure(1, facecolor='white')
	fig.clf()
	axprops = dict(xticks=[], yticks=[])
	createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
	# Global layout state: total width/depth and the "cursor" position used by plotTree
	plotTree.totalW = float(getNumLeafs(inTree))
	plotTree.totalD = float(getTreeDepth(inTree))
	plotTree.xOff = -0.5/plotTree.totalW
	plotTree.yOff = 1.0
	plotTree(inTree, (0.5, 1.0), '')
	plt.show()
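
This version of createPlot replaces the earlier no-argument one and draws an arbitrary tree, for example:

>>> import treePlotter
>>> treePlotter.createPlot(treePlotter.retrieveTree(0))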

(figure: the plotted decision tree)

Extended test: lens.py

Project Address: `https://github.com/TheOneAC/ML.git`
Dataset: `lens.txt` in `ML/ML_ation/tree`
#!/usr/bin/python

import trees
import treePlotter

fr = open("lenses.txt")
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age', 'prescript', 'astigmatic', 'tearRate']

lensesTree = trees.createTree(lenses,lensesLabels)
print(lensesTree)

treePlotter.createPlot(lensesTree)

(figure: the lenses decision tree)
