机器学习-决策树
(1) 说明: 通过给定的因素构建一个决策树, 后续可根据构建出的决策树进行判断。
(2) demo:一个简单的数据文件
$ cat WeatherTraining.csv
outlook,temperature,humidity,windy,play
overcast,hot,high,FALSE,yes
overcast,cool,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
rainy,mild,normal,FALSE,yes
rainy,mild,high,TRUE,no
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
(3)代码:
<?php
/**
 * Builds a decision tree from a CSV training file and answers queries against it.
 *
 * The first CSV row names the columns; the last column is the result (label) and
 * every other column is a factor. Factors are ranked once by information gain
 * computed over the whole training set, and that single ranking fixes which
 * factor is tested at each depth of the tree (depth 0 = highest gain).
 *
 * A built instance is serialized next to this file, keyed by the CRC32 of the
 * training file's contents, so an unchanged training file is not re-processed.
 */
class buildDecisionTree {
    // Path prefix of the serialized-tree cache; the CRC32 of the training file is appended.
    const DECISION_TREE_CACHE = __DIR__ . '/decision_tree_serialize.';

    /** @var string[] Factor (column) names in CSV order, without the result column. */
    private $element;
    /** @var string[] Factor names ordered by descending information gain. */
    private $sortElement;
    /** @var string Name of the result column (the last CSV column). */
    private $resultKey;
    /** @var stdClass|string|int Root node, or a bare result value when there is nothing to split on. */
    private $decisionTree;
    /** @var string Absolute path of the training file. */
    private $tariningDataPath;

    /**
     * Builds the tree for a training file located next to this script, or loads
     * it from the cache when one exists for the same file contents.
     *
     * Terminates the script via exit() when the file is missing, empty, or malformed.
     *
     * @param string $tariningData Training file name, relative to this script's directory.
     * @return buildDecisionTree The built (or cached) instance, ready for isOut().
     */
    public function build($tariningData) {
        $this->tariningDataPath = __DIR__ . '/' . $tariningData;
        if (!file_exists($this->tariningDataPath)) {
            exit('tarining data file not exists');
        }

        // Read the file once; the same bytes feed both the cache key and the parser.
        $content = (string) file_get_contents($this->tariningDataPath);

        // Reuse a previously built tree when the training file has not changed.
        $cachePath = self::DECISION_TREE_CACHE . crc32($content);
        if (file_exists($cachePath)) {
            // NOTE(review): unserialize() must only ever see files written by this
            // class; the cache directory has to be trusted.
            $raw = file_get_contents($cachePath);
            $cacheObj = $raw === false ? false : unserialize($raw);
            if ($cacheObj instanceof self) {
                return $cacheObj;
            }
            // Unreadable or corrupt cache: fall through and rebuild.
        }

        $trainingData = $this->formateData($content);
        if (!$trainingData) {
            exit('training data is empty');
        }

        // Entropy of the result column over the whole training set.
        $resArr = array_column($trainingData, $this->resultKey);
        $resEntropy = $this->calEntropy($resArr);

        // Information gain of every factor.
        $informationGain = $this->celInformationGain($trainingData, $resEntropy);

        // Descending order: the larger the gain, the closer the factor sits to the root.
        arsort($informationGain);
        $this->sortElement = array_keys($informationGain);
        $this->decisionTree = $this->buildTree($this->sortElement, $trainingData);

        // Best-effort cache write: a failure only costs a rebuild on the next run.
        file_put_contents($cachePath, serialize($this), LOCK_EX);
        return $this;
    }

    /**
     * Walks the tree with the factor values given on the command line and
     * returns the decided result.
     *
     * Terminates the script via exit() when the number of values does not match
     * the number of factors, or when a value has no branch in the tree.
     *
     * @param array $argv Script arguments: element 0 is dropped, the rest are the
     *                    factor values in the same order as the CSV columns.
     * @return string|int|null The result stored in the leaf that was reached.
     */
    public function isOut($argv) {
        array_shift($argv);
        if (count($argv) !== count($this->element)) {
            exit('param error');
        }

        $argv = array_combine($this->element, $argv);
        $node = $this->decisionTree;

        // Every inner node records the factor it tests, so follow that instead
        // of relying on the depth matching the position in sortElement.
        while (is_object($node)) {
            $value = strtolower($argv[$node->element]);
            if (!isset($node->nodes[$value])) {
                exit ('make decision error:can\'t find node ' . $value);
            }
            $node = $node->nodes[$value];
        }

        // Anything that is not an object is a leaf, i.e. the decided result.
        return $node;
    }

    /**
     * Recursively builds the subtree for the given rows.
     *
     * @param string[] $element      Factors still available, highest gain first.
     * @param array[]  $trainingData Rows (column name => value) that reach this node.
     * @return stdClass|string|int   A node, or a bare result value when no factor
     *                               is left to split on.
     */
    private function buildTree($element, $trainingData) {
        if (!$element) {
            // Rows that agree on every factor but disagree on the result cannot
            // be separated any further; without this guard the recursion never
            // ends. Settle on the most frequent result.
            return $this->majorityResult($trainingData);
        }

        // Factor tested at this node.
        $currentElement = array_shift($element);
        $tree = $this->demoFactory($currentElement);

        foreach ($this->classify($trainingData, $currentElement) as $condition => $result) {
            $sameRes = $this->isAllSame($result);
            if ($sameRes !== false) {
                // All rows on this branch agree: store the result as a leaf.
                $tree->nodes[$condition] = $sameRes;
                continue;
            }
            // Mixed results: narrow the rows to this branch and keep splitting.
            $nextTrainingData = $this->filteTrainingData($trainingData, $currentElement, $condition);
            $tree->nodes[$condition] = $this->buildTree($element, $nextTrainingData);
        }
        return $tree;
    }

    /**
     * Returns the most frequent result among the given rows.
     *
     * @param array[] $trainingData Rows (column name => value).
     * @return string|int|null The winning result (ties are broken by arsort()'s
     *                         ordering), or null when there are no rows.
     */
    private function majorityResult($trainingData) {
        $counts = array_count_values(array_column($trainingData, $this->resultKey));
        arsort($counts);
        return key($counts);
    }

    /**
     * Keeps only the rows whose value for one factor equals the given condition.
     *
     * @param array[]    $trainingData Rows (column name => value).
     * @param string     $element      Factor to test.
     * @param string|int $condition    Required value (an array key produced by classify()).
     * @return array[] Matching rows, re-indexed from 0.
     */
    private function filteTrainingData($trainingData, $element, $condition) {
        $ret = [];
        foreach ($trainingData as $info) {
            // Compare as strings: array keys turn "1" into int 1, and a loose ==
            // would also treat numeric look-alikes such as "1e3" and "1000" as equal.
            if ((string) $info[$element] === (string) $condition) {
                $ret[] = $info;
            }
        }
        return $ret;
    }

    /**
     * Tells whether all results are the same (exactly one distinct value).
     *
     * @param array $result Result values.
     * @return string|false The shared value, or false when the values differ or
     *                      the list is empty.
     */
    private function isAllSame($result) {
        $distinct = array_unique($result);
        if (count($distinct) === 1) {
            return reset($distinct);
        }
        return false;
    }

    /**
     * Creates an empty tree node for one factor.
     *
     * @param string $element Factor tested at the node.
     * @return stdClass Object with `element` (the factor) and `nodes`
     *                  (factor value => child node or result).
     */
    private function demoFactory($element) {
        $tree = new stdClass();
        $tree->nodes = [];
        $tree->element = $element;
        return $tree;
    }

    /**
     * Computes the information gain of every factor:
     * gain = entropy(results) - sum(|subset| / |all rows| * entropy(subset)),
     * where each subset holds the results of the rows sharing one factor value.
     *
     * @param array[] $trainingData Rows (column name => value); must not be empty.
     * @param float   $resEntropy   Entropy of the result column over all rows.
     * @return float[] Factor name => information gain.
     */
    private function celInformationGain($trainingData, $resEntropy) {
        $total = count($trainingData);
        $informationGain = [];
        foreach ($this->element as $element) {
            // Weighted sum of the entropies of the per-value result groups.
            $conditionalEntropy = 0;
            foreach ($this->classify($trainingData, $element) as $results) {
                $conditionalEntropy += count($results) / $total * $this->calEntropy($results);
            }
            $informationGain[$element] = $resEntropy - $conditionalEntropy;
        }
        return $informationGain;
    }

    /**
     * Groups the result values of the rows by the value of one column.
     *
     * @param array[] $data Rows (column name => value).
     * @param string  $key  Column to group by.
     * @return array[] Column value => list of results of the rows having that value.
     */
    private function classify($data, $key) {
        $classify = [];
        foreach ($data as $info) {
            $classify[$info[$key]][] = $info[$this->resultKey];
        }
        return $classify;
    }

    /**
     * Parses the training CSV into rows and records the column names.
     *
     * All text is lower-cased. The first non-blank line supplies the column
     * names; the last column becomes the result key and is removed from the
     * factor list. Terminates the script via exit() on a row whose column count
     * differs from the header's.
     *
     * @param string|null $dataStr Raw file contents; read from the training file when null.
     * @return array[] Rows as column name => value, including the result column.
     */
    private function formateData($dataStr = null) {
        if ($dataStr === null) {
            // The data file is small, so reading it in one go is fine.
            $dataStr = (string) file_get_contents($this->tariningDataPath);
        }

        // Start from a clean slate so that a reused instance does not mistake
        // the header row for data.
        $this->element = null;
        $ret = [];

        foreach (explode("\n", $dataStr) as $lineNo => $line) {
            // trim() also strips the "\r" left behind by Windows line endings.
            $line = trim($line);
            if ($line === '') {
                continue;
            }
            $info = array_map('trim', explode(',', strtolower($line)));
            if ($this->element === null) {
                // First row holds the column names; every other row is data.
                $this->element = $info;
                continue;
            }
            if (count($info) !== count($this->element)) {
                exit('training data format error at line ' . ($lineNo + 1));
            }
            $ret[] = array_combine($this->element, $info);
        }

        if ($this->element === null) {
            $this->element = [];
        }

        // The last column is the result value.
        $this->resultKey = array_pop($this->element);
        return $ret;
    }

    /**
     * Computes the Shannon entropy (base 2) of a list of values.
     *
     * @param array $res Values (strings or integers).
     * @return float|int Entropy in bits; 0 for an empty or uniform list.
     */
    private function calEntropy($res) {
        $count = count($res);
        $statistics = [];
        foreach ($res as $v) {
            if (!isset($statistics[$v])) {
                $statistics[$v] = 0;
            }
            $statistics[$v]++;
        }

        // sum(p * log2 p) is never positive, so its absolute value is the entropy.
        $entropy = 0;
        foreach ($statistics as $num) {
            $quotient = $num / $count;
            $entropy += $quotient * log($quotient, 2);
        }
        return abs($entropy);
    }
}

// Run: build (or load) the tree for the sample data and print the decision for
// the factor values passed on the command line.
echo (new buildDecisionTree())->build('WeatherTraining.csv')->isOut($argv);