#coding:utf-8
from numpy import *
from math import *
import operator
def file2matrix(filename):
    """Load a whitespace-separated numeric data file into a float matrix.

    Every non-blank line of ``filename`` must contain exactly 7 numeric
    fields. Blank lines are skipped and any run of whitespace separates
    fields.

    Returns a tuple ``(rematrix, label)``:
      * rematrix -- numpy float array of shape (number_of_rows, 7)
      * label    -- a fresh list with the six feature names

    Raises ValueError when a line does not hold exactly 7 fields or a field
    cannot be parsed as a float.
    """
    width = 7
    # "with" guarantees the file handle is closed even if parsing fails.
    with open(filename) as fr:
        rows = [line.split() for line in fr if line.strip()]
    rematrix = zeros((len(rows), width))
    label = ["seze", "gendi", "qiaoshen", "wenli", "qibu", "chugan"]  # watermelon feature set
    for index, fields in enumerate(rows):
        if len(fields) != width:
            # Without this check a one-field line would be silently
            # broadcast across the whole row by numpy.
            raise ValueError("line %d: expected %d fields, got %d"
                             % (index + 1, width, len(fields)))
        # Assign exactly one row (the original sliced "index:" and rewrote
        # every remaining row on each iteration).
        rematrix[index, :] = [float(tok) for tok in fields]
    return rematrix, label
def singlesplit(data, axis, value):
    """Select the rows whose column ``axis`` equals ``value``.

    Each selected row is returned as a new list with column ``axis``
    removed and every other column (including the trailing class column)
    kept in its original order. ``data`` itself is not modified.

    Parameters:
      data  -- sequence of rows (lists or numpy rows)
      axis  -- index of the column to test and drop; negative indices
               count from the end of the row
      value -- the value column ``axis`` must equal for a row to be kept

    Returns a list of reduced rows (empty when nothing matches).
    """
    subset = []
    for feat in data:
        if feat[axis] == value:
            # Normalise a negative index so the two slices below really
            # remove that column instead of duplicating the row.
            col = axis if axis >= 0 else axis + len(feat)
            # Keep everything except the split column; the original kept
            # only [feat[axis], feat[-1]] and lost all remaining features.
            reduced = list(feat[:col])
            reduced.extend(feat[col + 1:])
            subset.append(reduced)
    return subset
def allsplit(data):
    """Return the index of the feature column with the largest information gain.

    The last column of every row is treated as the class; all other columns
    are candidate features. For each candidate the entropy of the subsets
    produced by singlesplit (weighted by subset size) is subtracted from
    the entropy of the whole data set.

    Returns -1 when no candidate has a gain strictly greater than 0.
    """
    total = float(len(data))
    base_entropy = calcshannon(data)
    best_gain = 0.0
    best_index = -1
    for col in range(len(data[0]) - 1):
        # Weighted entropy after splitting on every distinct value of col.
        conditional = 0.0
        for val in set(row[col] for row in data):
            part = singlesplit(data, col, val)
            conditional += (len(part) / total) * calcshannon(part)
        gain = base_entropy - conditional
        if gain > best_gain:
            best_gain = gain
            best_index = col
    return best_index
def calcshannon(data):
    """Return the base-2 Shannon entropy of the class column of ``data``.

    The class of a row is its last element. An empty ``data`` yields 0.0.
    """
    total = float(len(data))
    # Count how often each class value occurs.
    counts = {}
    for row in data:
        cls = row[-1]
        counts[cls] = counts.get(cls, 0) + 1
    entropy = 0.0
    for cls in counts:
        p = counts[cls] / total
        entropy -= p * log(p, 2)
    return entropy
def selectbigger(label):
    """Return the most frequent item of ``label`` (majority vote).

    When several items share the highest count, the one that appears first
    in the counting dict's iteration order is returned.

    Raises ValueError when ``label`` is empty.
    """
    calcdict = {}
    for item in label:
        # The original did "calcdict += 1", which raises TypeError; the
        # counter of the current item is what has to be incremented.
        calcdict[item] = calcdict.get(item, 0) + 1
    if not calcdict:
        raise ValueError("selectbigger() needs at least one label")
    # items() exists on Python 2 and 3 (iteritems() is Python 2 only).
    return max(calcdict.items(), key=operator.itemgetter(1))[0]
def createTree(data, label):
    """Recursively build a decision tree as nested dicts.

    Parameters:
      data  -- sequence of rows; the last element of each row is the class
      label -- list of feature names, one per feature column of ``data``.
               The list is not modified.

    Returns either a class value (leaf) or a dict of the form
    ``{feature_name: {feature_value: subtree_or_class, ...}}``.
    """
    labellist = [row[-1] for row in data]
    # All samples belong to the same class: that class is the leaf.
    if labellist.count(labellist[0]) == len(labellist):
        return labellist[0]
    # No feature columns left: fall back to the majority class.
    if len(data[0]) == 1:
        return selectbigger(labellist)
    bestfuture = allsplit(data)
    if bestfuture < 0:
        # allsplit found no feature with positive information gain. Using
        # -1 as an index would silently pick the last label and split on
        # the class column, so take the majority class instead.
        return selectbigger(labellist)
    bestlabel = label[bestfuture]
    tree = {bestlabel: {}}  # the tree is built recursively as nested dicts
    # Work on a copy: deleting from ``label`` itself would empty the
    # caller's list and let sibling branches see each other's deletions.
    sublabel = list(label)
    del sublabel[bestfuture]
    # NOTE(review): sublabel stays aligned with the subset columns only if
    # singlesplit removes the split column from each row -- confirm.
    for value in set(row[bestfuture] for row in data):
        subset = singlesplit(data, bestfuture, value)
        # Every branch gets its own label list.
        tree[bestlabel][value] = createTree(subset, list(sublabel))
    return tree
def classifier(inputree, featurelabel, clsdata):
    """Classify one sample by walking a nested-dict decision tree.

    Parameters:
      inputree     -- tree node of the form {feature_name: {value: child}}
      featurelabel -- list of feature names; the position of a name is the
                      index of that feature inside ``clsdata``
      clsdata      -- sequence of feature values of the sample

    Returns the class stored at the leaf that is reached, or '' when the
    sample's value for a tested feature has no branch in the tree.

    Raises ValueError when a feature name in the tree is missing from
    ``featurelabel`` (from list.index).
    """
    # next(iter(...)) works on Python 2 and 3; keys()[0] fails on Python 3
    # because dict views cannot be indexed.
    firststr = next(iter(inputree))
    secondict = inputree[firststr]
    featindex = featurelabel.index(firststr)
    # Compare with == rather than a dict lookup so the matching rule of
    # the original (plain equality against every branch key) is kept.
    for key in secondict:
        if clsdata[featindex] == key:
            branch = secondict[key]
            # A dict is an inner node: keep descending. Anything else is
            # the class value of a leaf.
            if isinstance(branch, dict):
                return classifier(branch, featurelabel, clsdata)
            return branch
    return ''
# Build a tree from out.txt and classify one sample.
dataset, label = file2matrix("out.txt")
# classifier needs the complete feature-name list, so createTree is given
# a copy. This keeps ``label`` intact whatever createTree does with its
# argument and avoids reading the file a second time just to restore it.
mytree = createTree(dataset, list(label))
# Parenthesised single-argument print behaves the same on Python 2 and 3.
print(classifier(mytree, label, [3, 1, 1, 3, 3, 1]))