机器学习-决策树

  (1) 说明: 根据训练数据中给定的各个因素(特征)及其对应的结果, 通过计算信息熵和信息增益构建一棵决策树; 之后即可根据构建出的决策树, 对新的输入做出判断。

  (2) 示例: 一个简单的训练数据文件(第一行为各列的名称, 最后一列为结果, 其余各列为影响结果的因素)

 1 $ cat WeatherTraining.csv
 2 outlook,temperature,humidity,windy,play
 3 overcast,hot,high,FALSE,yes
 4 overcast,cool,normal,TRUE,yes
 5 overcast,mild,high,TRUE,yes
 6 overcast,hot,normal,FALSE,yes
 7 rainy,mild,high,FALSE,yes
 8 rainy,cool,normal,FALSE,yes
 9 rainy,cool,normal,TRUE,no
10 rainy,mild,normal,FALSE,yes
11 rainy,mild,high,TRUE,no
12 sunny,hot,high,FALSE,no
13 sunny,hot,high,TRUE,no
14 sunny,mild,high,FALSE,no
15 sunny,cool,normal,FALSE,yes
16 sunny,mild,normal,TRUE,yes

  (3)代码:

  1 <?php
  /**
   * Builds a decision tree from a CSV training file and uses it to classify new input.
   *
   * The first CSV line names the columns; the last column is the result, every other
   * column is a factor. Factors are ranked once by information gain over the whole
   * training set, and that single ranking is used as the split order on every branch
   * of the tree (the gain is not recomputed per branch).
   *
   * Errors detected by this class terminate the script via exit() with a short message.
   */
  class buildDecisionTree {
      // Path prefix of the serialized-tree cache; the CRC32 of the training file is appended.
      const DECISION_TREE_CACHE = __DIR__ . '/decision_tree_serialize.';

      // Factor column names in CSV order (the result column is removed).
      private $element;
      // Factor column names ordered by information gain, largest first.
      private $sortElement;
      // Name of the result column (last column of the CSV header).
      private $resultKey;
      // Root of the built tree: a stdClass node, or a bare label when there are no factors.
      private $decisionTree;
      // Full path of the training file (property name kept so existing caches still load).
      private $tariningDataPath;

      /**
       * Builds the tree for a training file, or loads it from the on-disk cache.
       *
       * @param string $tariningData Training file name, resolved relative to this file's directory.
       * @return buildDecisionTree   This instance after building, or the cached instance on a cache hit.
       */
      public function build($tariningData) {
          $this->tariningDataPath = __DIR__ . '/' . $tariningData;
          if (!file_exists($this->tariningDataPath)) {
              exit('tarining data file not exists');
          }

          // The cache key is the checksum of the file contents, so edited data gets a new cache.
          $crc32 = crc32(file_get_contents($this->tariningDataPath));
          $cachePath = self::DECISION_TREE_CACHE . $crc32;
          if (file_exists($cachePath)) {
              // NOTE(review): unserialize() instantiates whatever the file describes; this is only
              // safe while the cache directory is writable by trusted code alone.
              $cacheObj = unserialize(file_get_contents($cachePath));
              if ($cacheObj instanceof self) {
                  return $cacheObj;
              }
              // Unreadable or foreign cache content: fall through and rebuild it.
          }

          $trainingData = $this->formateData();

          // Entropy of the result column over the whole training set.
          $resArr = array_column($trainingData, $this->resultKey);
          $resEntropy = $this->calEntropy($resArr);

          // Larger gain means more influence on the result, so that factor sits closer to the root.
          $informationGain = $this->celInformationGain($trainingData, $resEntropy);
          arsort($informationGain);
          $this->sortElement = array_keys($informationGain);
          $this->decisionTree = $this->buildTree($this->sortElement, $trainingData);

          // Best-effort cache write; file_put_contents opens, locks, writes and closes in one call.
          file_put_contents($cachePath, serialize($this), LOCK_EX);
          return $this;
      }

      /**
       * Looks up the decision for one set of factor values.
       *
       * @param array $argv Script-style argument list: element 0 is discarded, the remaining
       *                    values are the factors in CSV column order.
       * @return mixed      The label stored at the matching leaf, or null if the walk ends
       *                    on an inner node.
       */
      public function isOut($argv) {
          array_shift($argv);
          if (count($argv) !== count($this->element)) {
              exit('param error');
          }

          // Walk the factors in sortElement order, the same order the tree was built in.
          $argv = array_combine($this->element, $argv);
          $treeTmp = $this->decisionTree;
          foreach ($this->sortElement as $e) {
              if (!is_object($treeTmp)) {
                  // A leaf has been reached; the remaining factors do not matter.
                  break;
              }
              $value = strtolower($argv[$e]);
              if (!isset($treeTmp->nodes[$value])) {
                  exit ('make decision error:can\'t find node ' . $value);
              }
              $treeTmp = $treeTmp->nodes[$value];
          }
          return is_object($treeTmp) ? null : $treeTmp;
      }

      /**
       * Recursively builds the subtree for the given rows.
       *
       * @param array $element      Remaining factor names, most significant first.
       * @param array $trainingData Rows (column name => value) that reach this subtree.
       * @return mixed              A stdClass node with `element` and `nodes`, or a result label
       *                            when no factor is left to split on.
       */
      private function buildTree($element, $trainingData) {
          if (count($element) === 0) {
              // Rows that agree on every factor but not on the result cannot be split further;
              // without this guard the recursion would never terminate.
              return $this->majorityResult($trainingData);
          }

          $currentElement = array_shift($element);
          $tree = $this->demoFactory($currentElement);
          foreach ($this->classify($trainingData, $currentElement) as $condition => $result) {
              $sameRes = $this->isAllSame($result);
              if ($sameRes === false) {
                  // Mixed results: narrow the rows to this value and split on the next factor.
                  $nextTrainingData = $this->filteTrainingData($trainingData, $currentElement, $condition);
                  $tree->nodes[$condition] = $this->buildTree($element, $nextTrainingData);
              } else {
                  // Every row agrees: store the label as a leaf.
                  $tree->nodes[$condition] = $sameRes;
              }
          }
          return $tree;
      }

      /**
       * Returns the most frequent result label among the rows.
       *
       * @param array $trainingData Rows (column name => value).
       * @return mixed              The label that first reaches the highest count, or null for no rows.
       */
      private function majorityResult($trainingData) {
          $counts = array();
          $best = null;
          $bestCount = 0;
          foreach ($trainingData as $info) {
              $label = $info[$this->resultKey];
              $counts[$label] = isset($counts[$label]) ? $counts[$label] + 1 : 1;
              if ($counts[$label] > $bestCount) {
                  $best = $label;
                  $bestCount = $counts[$label];
              }
          }
          return $best;
      }

      /**
       * Keeps only the rows whose factor has the given value.
       *
       * @param array  $trainingData Rows (column name => value).
       * @param string $element      Factor name to test.
       * @param mixed  $condition    Wanted value; may arrive as an int because it was an array key.
       * @return array               Matching rows, re-indexed from 0.
       */
      private function filteTrainingData($trainingData, $element, $condition) {
          $ret = array();
          foreach ($trainingData as $info) {
              // Compare as strings: loose == would treat e.g. "1e3" and "1000" as equal.
              if ((string) $info[$element] === (string) $condition) {
                  $ret[] = $info;
              }
          }
          return $ret;
      }

      /**
       * Tells whether all results are the same label.
       *
       * @param array $result List of result labels.
       * @return mixed        The single label when exactly one distinct label exists, otherwise false.
       */
      private function isAllSame($result) {
          $distinct = array_unique($result);
          if (count($distinct) === 1) {
              return reset($distinct);
          }
          return false;
      }

      /**
       * Creates an empty tree node for a factor.
       *
       * @param string $element Factor name this node splits on.
       * @return stdClass       Node with `element` set and an empty `nodes` map.
       */
      private function demoFactory($element) {
          $tree = new stdClass();
          $tree->nodes = array();
          $tree->element = $element;
          return $tree;
      }

      /**
       * Computes the information gain of every factor.
       *
       * @param array $trainingData Rows (column name => value).
       * @param float $resEntropy   Entropy of the result column over all rows.
       * @return array              Factor name => information gain.
       */
      private function celInformationGain($trainingData, $resEntropy) {
          $informationGain = array();
          $total = count($trainingData);
          foreach ($this->element as $element) {
              // Sum of each group's entropy weighted by the share of rows in that group.
              $weighted = 0;
              foreach ($this->classify($trainingData, $element) as $results) {
                  $weighted += count($results) / $total * $this->calEntropy($results);
              }
              $informationGain[$element] = $resEntropy - $weighted;
          }
          return $informationGain;
      }

      /**
       * Groups the result labels by the value of one column.
       *
       * @param array  $data Rows (column name => value).
       * @param string $key  Column to group by.
       * @return array       Column value => list of result labels.
       */
      private function classify($data, $key) {
          $classify = array();
          foreach ($data as $info) {
              $classify[$info[$key]][] = $info[$this->resultKey];
          }
          return $classify;
      }

      /**
       * Reads the training file into rows and records the column names.
       *
       * Sets $element (factor names) and $resultKey (last column). All text is lower-cased,
       * cells are trimmed, blank lines are skipped and any line-ending style is accepted.
       *
       * @return array Rows as column name => value, including the result column.
       */
      private function formateData() {
          // The file is small, so it is read in one go.
          $dataStr = (string) file_get_contents($this->tariningDataPath);

          // Start from a clean header so a reused instance does not treat this header as data.
          $this->element = array();
          $header = null;
          $ret = array();
          $lineNo = 0;
          foreach (preg_split('/\r\n|\r|\n/', $dataStr) as $line) {
              $lineNo++;
              $line = trim($line);
              if ($line === '') {
                  continue;
              }
              $info = array_map('trim', explode(',', strtolower($line)));
              if ($header === null) {
                  // The first non-blank line labels the columns; everything after it is data.
                  $header = $info;
                  continue;
              }
              if (count($info) !== count($header)) {
                  exit('tarining data format error at line ' . $lineNo);
              }
              $ret[] = array_combine($header, $info);
          }

          if ($header !== null) {
              $this->element = $header;
          }
          // The last column holds the result value.
          $this->resultKey = array_pop($this->element);
          return $ret;
      }

      /**
       * Computes the base-2 entropy of a list of labels.
       *
       * @param array $res List of string or int labels.
       * @return float     Entropy in bits; 0 for an empty list or a single distinct label.
       */
      private function calEntropy($res) {
          $count = count($res);
          if ($count === 0) {
              return 0.0;
          }

          $entropy = 0.0;
          foreach (array_count_values($res) as $num) {
              $quotient = $num / $count;
              // Each term is <= 0, so subtracting yields the non-negative entropy directly.
              $entropy -= $quotient * log($quotient, 2);
          }
          return $entropy;
      }
  }
188 
// Entry point: build the tree for WeatherTraining.csv (or load it from the cache), then
// print the decision for the factor values passed on the command line.
$decisionTree = new buildDecisionTree();
$decisionTree = $decisionTree->build('WeatherTraining.csv');
echo $decisionTree->isOut($argv);

 

posted @ 2020-03-31 01:48  Dahouzi  阅读(152)  评论(0编辑  收藏  举报