Python实现ID3算法

  自己用Python实现了数据挖掘中的ID3算法，现在觉得Python是实现算法的最好工具。

  先贴出ID3算法的介绍地址：http://wenku.baidu.com/view/cddddaed0975f46527d3e14f.html

  自己写的ID3算法

  1 from __future__ import division
  2 import math
  3 
  # Column position of every attribute inside a data row.
  attrIndex = {
      'age': 0,
      'income': 1,
      'student': 2,
      'credit': 3,
      'buy computer': 4,
  }

  # Attributes that may be used to split the data ('buy computer' is excluded).
  attrList = ['age', 'income', 'student', 'credit']

  # Set of possible values for each attribute.
  table = {
      'age': {'young', 'middle', 'old'},
      'income': {'high', 'middle', 'low'},
      'student': {'yes', 'no'},
      'credit': {'good', 'superior'},
      'buy computer': {'yes', 'no'},
  }

  # Training rows, one per line, with columns ordered as in attrIndex.
  allDataSet = [
      ['young', 'high', 'no', 'good', 'no'],
      ['young', 'high', 'no', 'superior', 'no'],
      ['middle', 'high', 'no', 'superior', 'yes'],
      ['old', 'middle', 'no', 'good', 'yes'],
      ['young', 'middle', 'no', 'good', 'no'],
      ['young', 'low', 'yes', 'good', 'yes'],
      ['middle', 'high', 'yes', 'good', 'yes'],
      ['old', 'middle', 'no', 'superior', 'no'],
      ['young', 'high', 'yes', 'good', 'yes'],
      ['middle', 'middle', 'no', 'good', 'no'],
  ]
 15 
 16 #求熵
 17 def entropy(attr, dataSet):
 18     valueCount = {v: {'yes': 0, 'no': 0, 'count': 0} for v in table[attr]}
 19     for row in dataSet:
 20         vName = row[attrIndex[attr]]
 21         decAttrVal = row[attrIndex['buy computer']] # 'yes' or 'no'
 22         valueCount[vName]['count'] = valueCount[vName]['count'] + 1
 23         valueCount[vName][decAttrVal] = valueCount[vName][decAttrVal] + 1
 24     infoMap = {v: 0 for v in table[attr]}
 25     for v in valueCount:
 26         if valueCount[v]['count'] == 0:
 27             infoMap[v] = 0
 28         else:
 29             p1 = valueCount[v]['yes'] / valueCount[v]['count']
 30             p2 = valueCount[v]['no'] / valueCount[v]['count']
 31             infoMap[v] = - ((0 if p1 == 0 else p1 * math.log(p1, 2)) + (0 if p2 == 0 else p2 * math.log(p2, 2)))
 32     s = 0
 33     for v in valueCount:
 34         s = s + valueCount[v]['count']
 35     propMap = {v: (valueCount[v]['count'] / s) for v in valueCount}
 36     i = 0
 37     for v in valueCount:
 38         i = i + infoMap[v] * propMap[v]
 39     return i
 40 
 41 #定义节点的数据结构
 42 class Node(object):
 43     def __init__(self, attrName):
 44         if attrName != '':
 45             self.attr = attrName
 46             self.childNodes = {v:Node('') for v in table[attrName]}
 47 
 48 #数据筛选
 49 def filtrate(dataSet, condition):
 50     result = []
 51     for row in dataSet:
 52         if row[attrIndex[condition['attr']]] == condition['val']:
 53             result.append(row)
 54     return result
 55 #求最大信息熵
 56 def maxEntropy(dataSet, attrList):
 57     if len(attrList) == 1:
 58         return attrList[0]
 59     else:
 60         attr = attrList[0]
 61         maxE = entropy(attr, dataSet)
 62         for a in attrList:
 63             if maxE < entropy(a, dataSet):
 64                 attr = a
 65         return attr
 66 #判断构建是否结束,当所有的决策属性都相等的时候,就不用在构建决策树了
 67 def endBuild(dataSet):
 68     if len(dataSet) == 1:
 69         return True
 70     buy = dataSet[0][attrIndex['buy computer']]
 71     for row in dataSet:
 72         if buy != row[attrIndex['buy computer']]:
 73             return False
 74 #构建决策树
 75 def buildDecisionTree(dataSet, root, attrList):
 76     if len(attrList) == 0 or endBuild(dataSet):
 77         root.attr = 'buy computer'
 78         root.result = dataSet[0][attrIndex['buy computer']]
 79         root.childNodes = {}
 80         return
 81     attr = root.attr
 82     for v in root.childNodes:
 83         childDataSet = filtrate(dataSet, {"attr":attr, "val":v})
 84         if len(childDataSet) == 0:
 85             root.childNodes[v] = Node('buy computer')
 86             root.childNodes[v].result = 'no'
 87             root.childNodes[v].childNodes = {}
 88             continue
 89         else:
 90             childAttrList = [a for a in attrList]
 91             childAttrList.remove(attr)
 92             if len(childAttrList) == 0:
 93                 root.childNodes[v] = Node('buy computer')
 94                 root.childNodes[v].result = childDataSet[0][attrIndex['buy computer']]
 95                 root.childNodes[v].childNodes = {}
 96             else:
 97                 childAttr = maxEntropy(childDataSet, childAttrList)
 98                 root.childNodes[v] = Node(childAttr)
 99                 buildDecisionTree(childDataSet, root.childNodes[v], childAttrList)
100 #预测结果
101 def predict(root, row):
102     if root.attr == 'buy computer':
103         return root.result
104     root = root.childNodes[row[attrIndex[root.attr]]]
105     return predict(root, row)
106 
107 rootAttr = maxEntropy(allDataSet, attrList)
108 rootNode = Node(rootAttr)
109 print rootNode.attr
110 buildDecisionTree(allDataSet, rootNode, attrList)
111 print predict(rootNode, ['old', 'low', 'yes', 'good'])

         欢迎大家提出建议

posted on 2013-11-03 17:16  Arts&Crafts  阅读(1663)  评论(0)  编辑  收藏  举报

导航