1 from math import log
2 import operator
3
def createDataSet():
    """Return a small sample data set and the names of its feature columns.

    Each row holds two 0/1 feature values followed by a class label
    ("yes" or "no"). The second return value names the two feature columns
    in column order. New lists are built on every call, so callers may
    mutate the result freely.
    """
    table = (
        (1, 1, "yes"),
        (1, 1, "yes"),
        (1, 0, "no"),
        (0, 1, "no"),
        (0, 1, "no"),
    )
    featureNames = ["no surfacing", "flippers"]
    return [list(record) for record in table], featureNames
def calcShannonEnt(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in dataSet.

    The class label of a row is its last element. Labels are counted, turned
    into relative frequencies p, and combined as -sum(p * log2(p)).
    An empty dataSet yields 0.0.
    """
    total = len(dataSet)
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    entropy = 0.0
    for count in tally.values():
        # float() keeps this a true division on Python 2 as well.
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy
def splitdataSet(dataSet, axis, value):
    """Return the rows whose column `axis` equals `value`, with that column removed.

    Each returned row is a new list; the input rows are not modified.
    Rows are returned in their original order.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataSet
        if row[axis] == value
    ]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the highest information gain.

    Every column except the last (the class label) is a candidate. For each
    one, the rows are partitioned by that column's distinct values, and the
    size-weighted entropy of the partitions is subtracted from the entropy of
    the whole set. The first column with the strictly largest positive gain
    wins. Returns -1 when no column yields a positive gain.
    Raises IndexError if dataSet is empty.
    """
    total = float(len(dataSet))
    baseEntropy = calcShannonEnt(dataSet)
    bestFeature, bestInfoGain = -1, 0.0
    for axis in range(len(dataSet[0]) - 1):
        # Size-weighted entropy of the partitions induced by this column.
        conditionalEntropy = 0.0
        for value in set(row[axis] for row in dataSet):
            part = splitdataSet(dataSet, axis, value)
            conditionalEntropy += len(part) / total * calcShannonEnt(part)
        gain = baseEntropy - conditionalEntropy
        # Strict '>' keeps the earliest column on ties.
        if gain > bestInfoGain:
            bestFeature, bestInfoGain = axis, gain
    return bestFeature
def majorityCnt(classList):
    """Return the most frequent value in classList (majority vote).

    Ties go to whichever tied value sorts first. sorted() is stable even
    with reverse=True, so on Pythons with insertion-ordered dicts that is
    the tied value encountered first in classList.
    Raises IndexError if classList is empty.
    """
    classCount = {}
    for vote in classList:
        # dict.get avoids the .keys() scan (a list copy on Python 2).
        classCount[vote] = classCount.get(vote, 0) + 1
    # items() works on Python 2 and 3; iteritems() does not exist on Python 3.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: non-empty list of rows; each row is its feature values
            followed by a class label in the last position.
        labels: feature names, one per feature column, in column order.
            The list is NOT modified. Earlier versions deleted the chosen
            label from the caller's list.

    Returns:
        Either a class label (leaf) or a nested dict of the form
        {featureName: {featureValue: subtree, ...}}.

    Raises:
        IndexError if dataSet is empty.
    """
    classList = [example[-1] for example in dataSet]
    # All rows share the same class: pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left: nothing to split on, so vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat < 0:
        # chooseBestFeatureToSplit returns -1 when no feature has a positive
        # gain. Indexing with -1 would pick the wrong label and split on the
        # class column, so fall back to a majority-vote leaf instead.
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Build a reduced copy rather than `del labels[bestFeat]`, so the caller's
    # list is left intact. Recursive calls never mutate it, so one copy is
    # safe to share between branches.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in set(example[bestFeat] for example in dataSet):
        myTree[bestFeatLabel][value] = createTree(
            splitdataSet(dataSet, bestFeat, value), subLabels)
    return myTree
if __name__ == "__main__":
    # Build the sample tree and print it.
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels)
    # print(...) with a single argument prints the same on Python 2 and 3;
    # the bare `print myTree` statement is a syntax error on Python 3.
    print(myTree)