决策树
假设现在创建了一个数据集,代码如下:
1 def createDataSet(): 2 dataSet = [[1,1,'yes'], 3 [1,1,'yes'], 4 [1,0,'no'], 5 [0,1,'no'], 6 [0,1,'no']] 7 labels = ['ct1','ct2'] 8 return dataSet, labels
计算数据集的香农熵代码:
from math import log def calcShannonEnt(dataSet): numEntries = len(dataset) //求出数据有多少个对象,即有多少行,计算实例总数 labelCounts = {} for featVec in dataset: //对数据集逐行求最后一类的数据,并将统计最后一列数据的数目 currentLabel = featVec[-1] if currentLabel not in labelCounts.keys(): //创建一个字典,键值是最后一列的数值 labelCountspcurrentLabel] = 0 //当前键值不存在,则扩展字典并将此键值加入字典 labelCounts[currentLabel] += 1 //每个键值都记录了当前类别出现的次数 shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key])/numEntries //使用统计出的最后一列的数据来计算所有类标签出现的概率 shannonEnt -= prob * log(prob,2) //下面的公式 return shannonEnt
H=-∑p(xi)log(2,p(xi)) (i=1,2,..n)
熵越高,则混合的数据越多,数据的不纯度越大。得到熵,就可以按照获取最大信息增益的方法来划分数据集。
划分数据集的代码
1 def splitDataSet(dataSet, axis, values): 2 retDataSet = [] 3 for featVec in dataSet: 4 if featVec[axis] == value: //判断axis列的值是否为value 5 reducedFeatVec = featVec[:axis] //[:axis]表示前axis行,即若axis为2,就是取featVec的前axis行 6 reducedFeatVec.extend(featVec[axis+1:]) //[axis+1:]表示从跳过axis+1行,取接下来的数据 7 retDataSet.append(reducedFeatVec) 8 return retDataSet
执行完上面的代码,数据就会将符合值判定的行取出来,然后将这些行里用来判定值的列去除,剩下的数据就是划分完的数据集
接下来的代码是选择最好的数据集划分方式
1 def chooseBestFeatureToSplit(dataSet): 2 numFeatures = len(dataSet[0]) - 1 3 baseEntropy = calcShannonEnt(dataSet) 4 bestInfoGain = 0.0 5 bestFeature = -1 6 for i in range(numFeatures): 7 featList = [example[i] for example in dataSet] 8 uniqueVals = set(featList) //创建了一个列表,里面的元素是dataSet所有的元素,但不重复 9 newEntropy = 0.0 10 for value in uniqueVals: 11 subDataSet = splitDataSet(dataSet, i, value) 12 prob = len(subDataSet)/float(len(dataSet)) 13 newEntropy += prob * calcShannonEnt(subDataSet) //计算按每个数据特征来划分的数据集的熵 14 infoGain = baseEntropy - newEntropy 15 if (infoGain > bestInfoGain): 16 bestInfoGain = infoGain 17 bestFeature = i //判断出哪种划分方式得到最大的信息增益,且得出该划分方式所选择的特征 18 return bestFeature
在第一次划分之后,数据将被传递到下一个节点,在此节点上可以再次划分数据,因此可以用递归来处理数据集。递归结束的条件是,程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据都属于叶子节点的分类。下面的代码用来判断递归何时结束。
1 import operator 2 def majorityCnt(classList): //参数classList在下面创建树的代码里,是每一行的最后一个特征 3 classCount = {} 4 for vote in classList: //将特征里的元素添加到新建的字典里作为键值,并统计该键值出现次数 5 if vote not in classCount.keys(): classCount[vote] = 0 6 classCount[vote] += 1 7 sortedClassCount = sorted(classCount.iteritems(), key=operater.itemgetter(1), reverse=True) 8 return sortedClassCount[0][0] //返回出现次数最多的键值
下面是创建树的代码,该代码调用了上面多个模块。
1 def createTree(dataSet, labels): 2 classList = [example[-1] for example in dataSet] 3 if classList.count(classList[0]) == len(classList): 4 return classList[0] 5 if len(dataSet[0]) == 1: 6 return majorityCnt(classList) 7 bestFeat = chooseBestFeatureToSplit(dataSet) 8 bestFeatLabel = labels[bestFeat] 9 myTree = {bestFeatLabel:{}} 10 del(labels[bestFeat]) 11 featValues = [example[bestFeat] for example in dataSet] 12 uniqueVals = set(featValues) 13 for value in uniqueVals: 14 subLabels = labels[:] 15 myTree[bestFeatLabel][value] = creatTree(splitDataSet(dataSet, bestFeat, value), subLabels) 16 return myTree
浙公网安备 33010602011771号