python中函数式编程的尝试——决策树

  在上次的数据挖掘作业 Apriori算法 中尝试使用了map reduce lambda后第一次感觉到语法糖这个东西果然和 AstralWind说的一样,太特么甜了。结果是在这次的作业中滥用了一下。决策树的算法参考了博客 CodingLabs

  下面是代码:

#!usr/bin/ python
# -*- coding: utf-8 -*-
import math
import copy

class dtree:
    def __init__(self, filename='logs.txt', Algorithm='ID3'):
        origin_records = open(filename).readlines()
        records = map(lambda x:x.split(), origin_records)
        records = filter(lambda x:x!=None and x[0][0]!='#', records)
        self.rcd_list = list()
        for rcd in records:
            self.rcd_list.append({
                'magic' : 'magic',
                'age'   : rcd[0],
                'income': rcd[1],
                'student':rcd[2],
                'cdt_rat':rcd[3],
                'buys_cpt':rcd[4]
                })
        self.column_set = set(self.rcd_list[0].keys())
        self.link = lambda strDir,spl='\t': reduce(lambda s1,s2:s1+spl+s2, [s for s in strDir if s!='magic'])
        self.tree = dict()
        self.ID3(self.rcd_list, self.tree)

    def ID3(self, rcd_list, boot):
        rst = 'buys_cpt'
        x_log2x = lambda x:x*math.log(x,2) if x!=0 else 0

        get_val_list = lambda col:[rcd[col] for rcd in rcd_list]
        get_val_set = lambda col:set(get_val_list(col))
        get_cnt = lambda col,val:get_val_list(col).count(val)

        get_list_by_val = lambda col,val:[rcd for rcd in rcd_list if rcd[col] == val]
        cnt_rst = lambda col,val,rst_val:len([rcd for rcd in get_list_by_val(col,val) if rcd[rst]==rst_val])
        get_prob = lambda col,val:get_cnt(col,val)/float(len(rcd_list))

        foo = lambda col,val,rst_val:cnt_rst(col,val,rst_val)/float(get_cnt(col,val)) if get_cnt(col,val)!=0 else 0
        info_D = lambda col,val="magic":-sum([x_log2x(foo(col,val,rst_val)) for rst_val in get_val_set(rst)])
        #infoL_D = lambda col:sum([get_prob(col,val)*info_D(col,'magic') for val in get_val_set(col)])
        infoL_D = lambda col:sum([get_prob(col,val)*info_D(col,val) for val in get_val_set(col)])

        gain = lambda col:info_D('magic') - infoL_D(col)
        cols = set([ key for key in rcd_list[0] if key!='magic' and key!='buys_cpt'])

        for col in cols:
            print 'info_D  ',col , '\t','%0.2f'%infoL_D(col)
        max_col = reduce(lambda col1,col2:col1 if gain(col1)>gain(col2) else col2, cols)
        print '>>max_col:', max_col
        max_col_val = get_val_set(max_col)

        for val in max_col_val:
            new_list = copy.deepcopy(get_list_by_val(max_col, val))
            is_certain = lambda l:len(set([rcd[rst] for rcd in l])) == 1
            if is_certain(new_list) or len(cols) == 1:
                #如果属性用完了,子集还不是纯净集,会选择多数的结果
                cnt_v = lambda v:cnt_rst(max_col,val,v)
                final_rst = reduce(lambda v1,v2:v1 if cnt_v(v1)>=cnt_v(v2) else v2, get_val_set(rst))
                boot[max_col+":"+val] = "result:"+final_rst
            else:
                for rcd in new_list:
                    del rcd[max_col]
                boot[max_col+":"+val] = dict()
                self.ID3(new_list, boot[max_col+":"+val])

    def print_tree(self):
        def prt(tree,depth):
            if type(tree) == type(dict()):
                for key in tree:
                    print ".\t"*depth,key
                    prt(tree[key], depth+1)
            else:
                print ".\t"*int(depth),tree
        prt(self.tree, 0)

if __name__ == '__main__':
    dtree('logs.txt').print_tree()
    x = raw_input("press Enter to continue")

  使用的测试数据logs.txt:

#age    income    student    cdt_rat    buys_cpt
<=30    high    no    fair    no
<=30    high    no    excellt    no
31~40    high    no    fair    yes
>40    medium    no    fair    yes
>40    low    yes    fair    yes
>40    low    yes    excellt    no
31~40    low    yes    excellt    yes
<=30    medium    no    fair    no
<=30    low    yes    fair    yes
>40    medium    yes    fair    yes
<=30    medium    yes    excellt    yes
31~40    medium    no    excellt    yes
31~40    high    yes    fair    yes
>40    medium    no    excellt    no

  最后结果:

  

  代码基本没可读性,不过我感觉这种算法实现出来很难有可读性,如果函数展开实现的话,代码拿到手上搞清楚调用关系都比较困难,再加上函数不好取名(我只是把公式用python写出来了,实际的公式意义搞的也不是很清楚),所以我注释都没加,中间info_D刚开始写错了,后来同学告诉我答案错了才反应过来。

  python里有些语法确实使程序好写好读了许多,比如:

 1 x, y = y, x    #交换x, y
 2 
 3 l = [rcd[0] for rcd in rcds]    #rcds中每个元素的第一个元素组成的列表
 4 
 5 s=set([1,2,3])    #s可以使任何可以遍历的元素为str的容器
 6 
 7 link = lambda s, spl='\t': reduce(lambda s1,s2:s1+spl+s2, s)
 8 #使用link显示列表或是集合之类的可以省下很多代码量
 9 print 'list:',link(s)
10 #等价于
11 print 'list:',
12 for item in s:
13     print item,'\t',
14 print

而且上面使用for显示的还会多显示个\t,是tab的时候还好,如果是用其它可视的字符(比如'—'或者'.')来连接的话,会比较烦。

准备在sinaapp平台上自己搭建的博客感觉越来越遥远啊, 最近要抓紧时间补充各种知识,可能还有一些实习面试笔试什么的要考虑,先在博客园写着吧,免得以后不想写了

 

posted @ 2013-05-09 17:56  王维维  阅读(281)  评论(0)    收藏  举报