决策树系列(三)——ID3

预备知识:决策树 

初识ID3

      回顾决策树的基本知识,其构建过程主要有下述三个重要的问题:

     (1)数据是怎么分裂的

     (2)如何选择分类的属性

     (3)什么时候停止分裂

     从上述三个问题出发,以实际的例子对ID3算法进行阐述。

例:通过当天的天气、温度、湿度和季节预测明天的天气

                                  表1 原始数据

当天天气

温度

湿度

季节

明天天气

25

50

春天

21

48

春天

18

70

春天

28

41

夏天

8

65

冬天

18

43

夏天

24

56

秋天

18

76

秋天

31

61

夏天

6

43

冬天

15

55

秋天

4

58

冬天

 1.数据分割

      对于离散型数据,直接按照离散数据的取值进行分裂,每一个取值对应一个子节点,以“当前天气”为例对数据进行分割,如图1所示。

 

      对于连续型数据,ID3原本是没有处理能力的,只有通过离散化将连续性数据转化成离散型数据再进行处理。

      连续数据离散化是另外一个课题,本文不深入阐述,这里直接采用等距离数据划分的李算话方法。该方法先对数据进行排序,然后将连续型数据划分为多个区间,并使每一个区间的数据量基本相同,以温度为例对数据进行分割,如图2所示。

 

 2. 选择最优分裂属性

      ID3采用信息增益作为选择最优的分裂属性的方法,选择熵作为衡量节点纯度的标准,信息增益的计算公式如下:

                                               

      其中, 表示父节点的熵; 表示节点i的熵,熵越大,节点的信息量越多,越不纯; 表示子节点i的数据量与父节点数据量之比。 越大,表示分裂后的熵越小,子节点变得越纯,分类的效果越好,因此选择 最大的属性作为分裂属性。

      对上述的例子的跟节点进行分裂,分别计算每一个属性的信息增益,选择信息增益最大的属性进行分裂。

      天气属性:(数据分割如上图1所示) 

  

      温度:(数据分割如上图2所示)

     

      湿度:

 

      

      季节:

 

      

     由于最大,所以选择属性“季节”作为根节点的分裂属性。

 

3.停止分裂的条件

     停止分裂的条件已经在决策树中阐述,这里不再进行阐述。

     (1)最小节点数

  当节点的数据量小于一个指定的数量时,不继续分裂。两个原因:一是数据量较少时,再做分裂容易强化噪声数据的作用;二是降低树生长的复杂性。提前结束分裂一定程度上有利于降低过拟合的影响。

  (2)熵或者基尼值小于阀值。

     由上述可知,熵和基尼值的大小表示数据的复杂程度,当熵或者基尼值过小时,表示数据的纯度比较大,如果熵或者基尼值小于一定程度时,节点停止分裂。

  (3)决策树的深度达到指定的条件

   节点的深度可以理解为节点与决策树跟节点的距离,如根节点的子节点的深度为1,因为这些节点与跟节点的距离为1,子节点的深度要比父节点的深度大1。决策树的深度是所有叶子节点的最大深度,当深度到达指定的上限大小时,停止分裂。

  (4)所有特征已经使用完毕,不能继续进行分裂。

     被动式停止分裂的条件,当已经没有可分的属性时,直接将当前节点设置为叶子节点。

 

程序设计及源代码(C#版本)

(1)数据处理

       用二维数组存储原始的数据,每一行表示一条记录,前n-1列表示数据的属性,第n列表示分类的标签。

   static double[][] allData;

   为了方便后面的处理,对离散属性进行数字化处理,将离散值表示成数字,并用一个链表数组进行存储,数组的第一个元素表示属性1的离散值。

   static List<String>[] featureValues;

       那么经过处理后的表1数据可以转化为如表2所示的数据:

                                                                                表2 初始化后的数据

当天天气

温度

湿度

季节

明天天气

1

25

50

1

1

2

21

48

1

2

2

18

70

1

3

1

28

41

2

1

3

8

65

3

2

1

18

43

2

1

2

24

56

4

1

3

18

76

4

2

3

31

61

2

1

2

6

43

3

3

1

15

55

4

2

3

4

58

3

3

      其中,对于当天天气属性,数字{1,2,3}分别表示{晴,阴,雨};对于季节属性{1,2,3,4}分别表示{春天、夏天、冬天、秋天};对于明天天气{1,2,3}分别表示{晴、阴、雨}。

(2)两个类:节点类和分裂信息

  a)节点类Node

      该类存储了节点的信息,包括节点的数据量、节点选择的分裂属性、节点输出类、子节点的个数、子节点的分类误差等。

 1     class Node
 2     {
 3         /// <summary>
 4         /// 各个子节点的取值
 5         /// </summary>
 6         public List<String> features { get; set; }
 7         /// <summary>
 8         /// 分裂属性的类型
 9         /// </summary>
10         public String feature_Type { get; set; }
11         /// <summary>
12         /// 分裂的属性
13         /// </summary>
14         public String SplitFeature { get; set; }
15         /// <summary>
16         /// 节点对应各个分类的数目
17         /// </summary>
18         public double[] ClassCount { get; set; }
19         /// <summary>
20         /// 各个孩子节点
21         /// </summary>
22         public List<Node> childNodes { get; set; }
23         /// <summary>
24         /// 父亲节点(未用到)
25         /// </summary>
26         public Node Parent { get; set; }
27         /// <summary>
28         /// 占比最大的类别
29         /// </summary>
30         public String finalResult { get; set; }
31         /// <summary>
32         /// 数的深度
33         /// </summary>
34         public int deep { get; set; }
35         /// <summary>
36         /// 该节点占比最大的类标号
37         /// </summary>
38         public int result { get; set; }
39         /// <summary>
40         /// 节点的数量
41         /// </summary>
42         public int rowCount{ get; set; }
43 
44         
45         public void setClassCount(double[] count)
46         {
47             this.ClassCount = count;
48             double max = ClassCount[0];
49             int result = 0;
50             for (int i = 1; i < ClassCount.Length; i++)
51             {
52                 if (max < ClassCount[i])
53                 {
54                     max = ClassCount[i];
55                     result = i;
56                 }
57             }
58             //wrong = Convert.ToInt32(nums.Count - ClassCount[result]);
59             this.result = result;
60         }
61     }
View Code

  b)分裂信息类SplitInfo

      该类存储节点进行分裂的信息,包括各个子节点的行坐标、子节点各个类的数目、该节点分裂的属性、属性的类型等。

 1     class SplitInfo
 2     {
 3         /// <summary>
 4         /// 分裂的列下标
 5         /// </summary>
 6         public int splitIndex { get; set; }
 7         /// <summary>
 8         /// 数据类型
 9         /// </summary>
10         public int type { get; set; }
11         /// <summary>
12         /// 分裂属性的取值
13         /// </summary>
14         public List<String> features { get; set; }
15         /// <summary>
16         /// 各个节点的行坐标链表
17         /// </summary>
18         public List<int>[] temp { get; set; }
19         /// <summary>
20         /// 每个节点各类的数目
21         /// </summary>
22         public double[][] class_Count { get; set; }
23     }
View Code

(3)节点分裂方法findBestSplit(Node node,List<int> nums,int[] isUsed),该方法对节点进行分裂,返回值Node

其中:

    node表示即将进行分裂的节点;

    nums表示节点数据对应的行坐标列表;

    isUsed表示到该节点位置所有属性的使用情况(1:表示该属性不能再次使用,0:表示该属性可以使用);

findBestSplit主要有以下几个组成部分:

1)节点分裂停止的判定

判断节点是否需要继续分裂,分裂判断条件如上文所述。源代码如下:

  1         public static Object[] ifEnd(Node node, double entropy,int[] isUsed)
  2         {
  3             try
  4             {
  5                 double[] count = node.ClassCount;
  6                 int rowCount = node.rowCount;
  7                 int maxResult = 0;
  8                 double maxRate = 0;
  9                 #region 数达到某一深度
 10                 int deep = node.deep;
 11                 if (deep >= maxDeep)
 12                 {
 13                     maxResult = node.result + 1;
 14                     node.feature_Type=("result");
 15                     node.features=(new List<String>() { maxResult + "" });
 16                     return new Object[] { true, node };
 17                 }
 18                 #endregion
 19                 #region 纯度(其实跟后面的有点重了,记得要修改)
 20                 //maxResult = 1;
 21                 //for (int i = 1; i < count.Length; i++)
 22                 //{
 23                 //    if (count[i] / rowCount >= 0.95)
 24                 //    {
 25                 //        node.setFeatureType("result");
 26                 //        node.setFeatures(new List<String> { "" + (i + 1) });
 27                 //        return new Object[] { true, node };
 28                 //    }
 29                 //}
 30                 //node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1]));
 31                 #endregion
 32                 #region 熵为0
 33                 if (entropy == 0)
 34                 {
 35                     maxRate = count[0] / rowCount;
 36                     maxResult = 1;
 37                     for (int i = 1; i < count.Length; i++)
 38                     {
 39                         if (count[i] / rowCount >= maxRate)
 40                         {
 41                             maxRate = count[i] / rowCount;
 42                             maxResult = i + 1;
 43                         }
 44                     }
 45                     node.feature_Type=("result");
 46                     node.features=(new List<String> { maxResult + "" });
 47                     return new Object[] { true, node };
 48                 }
 49                 #endregion
 50                 #region 属性已经分完
 51                 //int[] isUsed = node.;
 52                 bool flag = true;
 53                 for (int i = 0; i < isUsed.Length - 1; i++)
 54                 {
 55                     if (isUsed[i] == 0)
 56                     {
 57                         flag = false;
 58                         break;
 59                     }
 60                 }
 61                 if (flag)
 62                 {
 63                     maxRate = count[0] / rowCount;
 64                     maxResult = 1;
 65                     for (int i = 1; i < count.Length; i++)
 66                     {
 67                         if (count[i] / rowCount >= maxRate)
 68                         {
 69                             maxRate = count[i] / rowCount;
 70                             maxResult = i + 1;
 71                         }
 72                     }
 73                     node.feature_Type=("result");
 74                     node.features=(new List<String> { "" + (maxResult) });
 75                     return new Object[] { true, node };
 76                 }
 77                 #endregion
 78                 #region 数据量少于100
 79                 if (rowCount < Limit_Node)
 80                 {
 81                     maxRate = count[0] / rowCount;
 82                     maxResult = 1;
 83                     for (int i = 1; i < count.Length; i++)
 84                     {
 85                         if (count[i] / rowCount >= maxRate)
 86                         {
 87                             maxRate = count[i] / rowCount;
 88                             maxResult = i + 1;
 89                         }
 90                     }
 91                     node.feature_Type=("result");
 92                     node.features=(new List<String> { "" + (maxResult) });
 93                     return new Object[] { true, node };
 94                 }
 95                 #endregion
 96                 return new Object[] { false, node };
 97             }
 98             catch (Exception e)
 99             {
100                 return new Object[] { false, node };
101             }
102         }
View Code

2)寻找最优的分裂属性

寻找最优的分裂属性需要计算每一个分裂属性分裂后的信息增益,计算公式上文已给出,其中熵的计算代码如下:

 1         public static double CalEntropy(double[] counts, int countAll)
 2         {
 3             try
 4             {
 5                 double allShang = 0;
 6                 for (int i = 0; i < counts.Length; i++)
 7                 {
 8                     if (counts[i] == 0)
 9                     {
10                         continue;
11                     }
12                     double rate = counts[i] / countAll;
13                     allShang = allShang + rate * Math.Log(rate, 2);
14                 }
15                 return -allShang;
16             }
17             catch (Exception e)
18             {
19                 return 0;
20             }
21         }
View Code

3)进行分裂,同时子节点也执行相同的分类步骤

其实就是递归的过程,对每一个子节点执行findBestSplit方法进行分裂。

全部源代码:

  1         #region ID3核心算法
  2         /// <summary>
  3         /// 测试
  4         /// </summary>
  5         /// <param name="node"></param>
  6         /// <param name="data"></param>
  7         public static String findResult(Node node, String[] data)
  8         {
  9             List<String> featrues = node.features;
 10             String type = node.feature_Type;
 11             if (type == "result")
 12             {
 13                 return featrues[0];
 14             }
 15             int split = Convert.ToInt32(node.SplitFeature);
 16             List<Node> childNodes = node.childNodes;
 17             double[] resultCount = node.ClassCount;
 18             if (type == "连续")
 19             {
 20                 
 21 
 22                 for (int i = 0; i < featrues.Count; i++)
 23                 {
 24                     double value = Convert.ToDouble(featrues[i]);
 25                     if (Convert.ToDouble(data[split]) <= value)
 26                     {
 27                         return findResult(childNodes[i], data);
 28                     }
 29                 }
 30                 return findResult(childNodes[featrues.Count], data);
 31             }
 32             else
 33             {
 34                 for (int i = 0; i < featrues.Count; i++)
 35                 {
 36                     if (data[split] == featrues[i])
 37                     {
 38                         return findResult(childNodes[i], data);
 39                     }
 40                     if (i == featrues.Count - 1)
 41                     {
 42                         double count = resultCount[0];
 43                         int maxInt = 0;
 44                         for (int j = 1; j < resultCount.Length; j++)
 45                         {
 46                             if (count < resultCount[j])
 47                             {
 48                                 count = resultCount[j];
 49                                 maxInt = j;
 50                             }
 51                         }
 52                         return findResult(childNodes[0], data);
 53                     }
 54                 }
 55             }
 56             return null;
 57         }
 58         /// <summary>
 59         /// 判断是否还需要分裂
 60         /// </summary>
 61         /// <param name="node"></param>
 62         /// <returns></returns>
 63         public static Object[] ifEnd(Node node, double entropy,int[] isUsed)
 64         {
 65             try
 66             {
 67                 double[] count = node.ClassCount;
 68                 int rowCount = node.rowCount;
 69                 int maxResult = 0;
 70                 double maxRate = 0;
 71                 #region 数达到某一深度
 72                 int deep = node.deep;
 73                 if (deep >= maxDeep)
 74                 {
 75                     maxResult = node.result + 1;
 76                     node.feature_Type=("result");
 77                     node.features=(new List<String>() { maxResult + "" });
 78                     return new Object[] { true, node };
 79                 }
 80                 #endregion
 81                 #region 纯度(其实跟后面的有点重了,记得要修改)
 82                 //maxResult = 1;
 83                 //for (int i = 1; i < count.Length; i++)
 84                 //{
 85                 //    if (count[i] / rowCount >= 0.95)
 86                 //    {
 87                 //        node.setFeatureType("result");
 88                 //        node.setFeatures(new List<String> { "" + (i + 1) });
 89                 //        return new Object[] { true, node };
 90                 //    }
 91                 //}
 92                 //node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1]));
 93                 #endregion
 94                 #region 熵为0
 95                 if (entropy == 0)
 96                 {
 97                     maxRate = count[0] / rowCount;
 98                     maxResult = 1;
 99                     for (int i = 1; i < count.Length; i++)
100                     {
101                         if (count[i] / rowCount >= maxRate)
102                         {
103                             maxRate = count[i] / rowCount;
104                             maxResult = i + 1;
105                         }
106                     }
107                     node.feature_Type=("result");
108                     node.features=(new List<String> { maxResult + "" });
109                     return new Object[] { true, node };
110                 }
111                 #endregion
112                 #region 属性已经分完
113                 //int[] isUsed = node.;
114                 bool flag = true;
115                 for (int i = 0; i < isUsed.Length - 1; i++)
116                 {
117                     if (isUsed[i] == 0)
118                     {
119                         flag = false;
120                         break;
121                     }
122                 }
123                 if (flag)
124                 {
125                     maxRate = count[0] / rowCount;
126                     maxResult = 1;
127                     for (int i = 1; i < count.Length; i++)
128                     {
129                         if (count[i] / rowCount >= maxRate)
130                         {
131                             maxRate = count[i] / rowCount;
132                             maxResult = i + 1;
133                         }
134                     }
135                     node.feature_Type=("result");
136                     node.features=(new List<String> { "" + (maxResult) });
137                     return new Object[] { true, node };
138                 }
139                 #endregion
140                 #region 数据量少于100
141                 if (rowCount < Limit_Node)
142                 {
143                     maxRate = count[0] / rowCount;
144                     maxResult = 1;
145                     for (int i = 1; i < count.Length; i++)
146                     {
147                         if (count[i] / rowCount >= maxRate)
148                         {
149                             maxRate = count[i] / rowCount;
150                             maxResult = i + 1;
151                         }
152                     }
153                     node.feature_Type=("result");
154                     node.features=(new List<String> { "" + (maxResult) });
155                     return new Object[] { true, node };
156                 }
157                 #endregion
158                 return new Object[] { false, node };
159             }
160             catch (Exception e)
161             {
162                 return new Object[] { false, node };
163             }
164         }
165         #region 排序算法
166         public static void InsertSort(double[] values, List<int> arr, int StartIndex, int endIndex)
167         {
168             for (int i = StartIndex + 1; i <= endIndex; i++)
169             {
170                 int key = arr[i];
171                 double init = values[i];
172                 int j = i - 1;
173                 while (j >= StartIndex && values[j] > init)
174                 {
175                     arr[j + 1] = arr[j];
176                     values[j + 1] = values[j];
177                     j--;
178                 }
179                 arr[j + 1] = key;
180                 values[j + 1] = init;
181             }
182         }
183         static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high)
184         {
185             int mid = low + ((high - low) >> 1);//计算数组中间的元素的下标  
186 
187             //使用三数取中法选择枢轴  
188             if (values[mid] > values[high])//目标: arr[mid] <= arr[high]  
189             {
190                 swap(values, arr, mid, high);
191             }
192             if (values[low] > values[high])//目标: arr[low] <= arr[high]  
193             {
194                 swap(values, arr, low, high);
195             }
196             if (values[mid] > values[low]) //目标: arr[low] >= arr[mid]  
197             {
198                 swap(values, arr, mid, low);
199             }
200             //此时,arr[mid] <= arr[low] <= arr[high]  
201             return low;
202             //low的位置上保存这三个位置中间的值  
203             //分割时可以直接使用low位置的元素作为枢轴,而不用改变分割函数了  
204         }
205         static void swap(double[] values, List<int> arr, int t1, int t2)
206         {
207             double temp = values[t1];
208             values[t1] = values[t2];
209             values[t2] = temp;
210             int key = arr[t1];
211             arr[t1] = arr[t2];
212             arr[t2] = key;
213         }
214         static void QSort(double[] values, List<int> arr, int low, int high)
215         {
216             int first = low;
217             int last = high;
218 
219             int left = low;
220             int right = high;
221 
222             int leftLen = 0;
223             int rightLen = 0;
224 
225             if (high - low + 1 < 10)
226             {
227                 InsertSort(values, arr, low, high);
228                 return;
229             }
230 
231             //一次分割 
232             int key = SelectPivotMedianOfThree(values, arr, low, high);//使用三数取中法选择枢轴 
233             double inti = values[key];
234             int currentKey = arr[key];
235 
236             while (low < high)
237             {
238                 while (high > low && values[high] >= inti)
239                 {
240                     if (values[high] == inti)//处理相等元素  
241                     {
242                         swap(values, arr, right, high);
243                         right--;
244                         rightLen++;
245                     }
246                     high--;
247                 }
248                 arr[low] = arr[high];
249                 values[low] = values[high];
250                 while (high > low && values[low] <= inti)
251                 {
252                     if (values[low] == inti)
253                     {
254                         swap(values, arr, left, low);
255                         left++;
256                         leftLen++;
257                     }
258                     low++;
259                 }
260                 arr[high] = arr[low];
261                 values[high] = values[low];
262             }
263             arr[low] = currentKey;
264             values[low] = values[key];
265             //一次快排结束  
266             //把与枢轴key相同的元素移到枢轴最终位置周围  
267             int i = low - 1;
268             int j = first;
269             while (j < left && values[i] != inti)
270             {
271                 swap(values, arr, i, j);
272                 i--;
273                 j++;
274             }
275             i = low + 1;
276             j = last;
277             while (j > right && values[i] != inti)
278             {
279                 swap(values, arr, i, j);
280                 i++;
281                 j--;
282             }
283             QSort(values, arr, first, low - 1 - leftLen);
284             QSort(values, arr, low + 1 + rightLen, last);
285         }
286         #endregion
287         /// <summary>
288         /// 寻找最佳的分裂点
289         /// </summary>
290         /// <param name="num"></param>
291         /// <param name="node"></param>
292         public static Node findBestSplit(Node node, int lastCol,List<int> nums,int[] isUsed)
293         {
294             try
295             {
296                 //判断是否继续分裂
297                 double totalShang = CalEntropy(node.ClassCount, nums.Count);
298                 Object[] check = ifEnd(node, totalShang, isUsed);
299                 if ((bool)check[0])
300                 {
301                     node = (Node)check[1];
302                     return node;
303                 }
304                 #region 变量声明
305                 SplitInfo info = new SplitInfo();
306                 //int[] isUsed = node.getUsed();              //连续变量or离散变量
307                 //List<int> nums = node.getNum();             //样本的标号
308                 int RowCount = nums.Count;                  //样本总数
309                 double jubuMax = 0;                         //局部最大熵
310                 #endregion
311                 for (int i = 0; i < isUsed.Length - 1; i++)
312                 {
313                     if (isUsed[i] == 1)
314                     {
315                         continue;
316                     }
317                     #region 离散变量
318                     if (type[i] == 0)
319                     {
320                         int[] allFeatureCount = new int[0];         //所有类别的数量
321                         double[][] allCount = new double[allNum[i]][];
322                         for (int j = 0; j < allCount.Length; j++)
323                         {
324                             allCount[j] = new double[classCount];
325                         }
326                         int[] countAllFeature = new int[allNum[i]];
327                         List<int>[] temp = new List<int>[allNum[i]];
328                         for (int j = 0; j < temp.Length; j++)
329                         {
330                             temp[j] = new List<int>();
331                         }
332                         for (int j = 0; j < nums.Count; j++)
333                         {
334                             int index = Convert.ToInt32(allData[nums[j]][i]);
335                             temp[index - 1].Add(nums[j]);
336                             countAllFeature[index - 1]++;
337                             allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
338                         }
339                         double allShang = 0;
340                         for (int j = 0; j < allCount.Length; j++)
341                         {
342                             allShang = allShang + CalEntropy(allCount[j], countAllFeature[j]) * countAllFeature[j] / RowCount;
343                         }
344                         allShang = (totalShang - allShang);
345                         if (allShang > jubuMax)
346                         {
347                             info.features=new List<String>();
348                             info.type=0;
349                             info.temp=(temp);
350                             info.splitIndex=(i);
351                             info.class_Count=(allCount);
352                             jubuMax = allShang;
353                             allFeatureCount = countAllFeature;
354                         }
355                     }
356                     #endregion
357                     #region 连续变量
358                     else
359                     {
360                         double[] leftCount = new double[classCount];          //做节点各个类别的数量
361                         double[] rightCount = new double[classCount];        //右节点各个类别的数量
362                         double[] values = new double[nums.Count];
363                         List<String> List_Feature = new List<string>();
364                         for (int j = 0; j < values.Length; j++)
365                         {
366                             values[j] = allData[nums[j]][i];
367                         }
368                         QSort(values, nums, 0, nums.Count - 1);
369                         int eachNum = nums.Count / 5;
370                         double lianxuMax = 0;                                   //连续型属性的最大熵
371                         int index = 1;
372                         double[][] counts = new double[5][];
373                         List<int>[] temp = new List<int>[5];
374                         for (int j = 0; j < 5; j++)
375                         {
376                             counts[j] = new double[classCount];
377                             temp[j] = new List<int>();
378                         }
379                         for (int j = 0; j < nums.Count - 1; j++)
380                         {
381                             if (j >= index * eachNum&&index<5)
382                             {
383                                 List_Feature.Add(allData[nums[j]][i]+"");
384                                 lianxuMax += eachNum*CalEntropy(counts[index - 1], eachNum)/RowCount;
385                                 index++;
386                             }
387                             temp[index-1].Add(nums[j]);
388                             counts[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1])-1]++;
389                         }
390                         lianxuMax += ((eachNum + nums.Count % 5)*CalEntropy(counts[index - 1], eachNum + nums.Count % 5) / RowCount);
391                         lianxuMax = totalShang - lianxuMax;
392                         if (lianxuMax > jubuMax)
393                         {
394                             info.splitIndex=(i);
395                             info.features=(List_Feature);
396                             info.type=(1);
397                             jubuMax = lianxuMax;
398                             info.temp=(temp);
399                             info.class_Count=(counts);
400                         }
401                     }
402                     #endregion
403                 }
404                 #region 如何找不到最佳的分裂属性,则设为叶节点
405                 if (info.splitIndex == -1)
406                 {
407                     double[] finalCount = node.ClassCount;
408                     double max = finalCount[0];
409                     int result = 1;
410                     for (int i = 1; i < finalCount.Length; i++)
411                     {
412                         if (finalCount[i] > max)
413                         {
414                             max = finalCount[i];
415                             result = (i + 1);
416                         }
417                     }
418                     node.feature_Type=("result");
419                     node.features=(new List<String> { "" + result });
420                     return node;
421                 }
422                 #endregion
423                 int deep = node.deep;
424                 #region 分裂
425                 node.SplitFeature=("" + info.splitIndex);
426                 
427                 List<Node> childNode = new List<Node>();
428                 int[] used = new int[isUsed.Length];
429                 for (int i = 0; i < used.Length; i++)
430                 {
431                     used[i] = isUsed[i];
432                 }
433                 if (info.type == 0)
434                 {
435                     used[info.splitIndex] = 1;
436                     node.feature_Type=("离散");
437                 }
438                 else
439                 {
440                     used[info.splitIndex] = 0;
441                     node.feature_Type=("连续");
442                 }
443                 int sumLeaf = 0;
444                 int sumWrong = 0;
445                 List<int>[] rowIndex = info.temp;
446                 List<String> features = info.features;
447                 for (int j = 0; j < rowIndex.Length; j++)
448                 {
449                     if (rowIndex[j].Count == 0)
450                     {
451                         continue;
452                     }
453                     if (info.type == 0)
454                         features.Add(""+(j+1));
455                     Node node1 = new Node();
456                     //node1.setNum(info.getTemp()[j]);
457                     node1.setClassCount(info.class_Count[j]);
458                     //node1.setUsed(used);
459                     node1.deep=(deep + 1);
460                     node1.rowCount = info.temp[j].Count;
461                     node1 = findBestSplit(node1, info.splitIndex,info.temp[j], used);
462                     childNode.Add(node1);
463                 }
464                 node.features=(features);
465                 node.childNodes=(childNode);
466                 
467                 #endregion
468                 return node;
469             }
470             catch (Exception e)
471             {
472                 Console.WriteLine(e.StackTrace);
473                 return node;
474             }
475         }
476         /// <summary>
477         /// 计算熵
478         /// </summary>
479         /// <param name="counts"></param>
480         /// <param name="countAll"></param>
481         /// <returns></returns>
482         public static double CalEntropy(double[] counts, int countAll)
483         {
484             try
485             {
486                 double allShang = 0;
487                 for (int i = 0; i < counts.Length; i++)
488                 {
489                     if (counts[i] == 0)
490                     {
491                         continue;
492                     }
493                     double rate = counts[i] / countAll;
494                     allShang = allShang + rate * Math.Log(rate, 2);
495                 }
496                 return -allShang;
497             }
498             catch (Exception e)
499             {
500                 return 0;
501             }
502         }
503         #endregion
View Code

 (注:上述代码只是ID3的核心代码,数据预处理的代码并没有给出,只要将预处理后的数据输入到主方法findBestSplit中,就可以得到最终的结果)

总结

     ID3是基本的决策树构建算法,作为决策树经典的构建算法,其具有结构简单、清晰易懂的特点。虽然ID3比较灵活方便,但是有以下几个缺点:

 (1)采用信息增益进行分裂,分裂的精确度可能没有采用信息增益率进行分裂高

   (2)不能处理连续型数据,只能通过离散化将连续性数据转化为离散型数据

   (3)不能处理缺省值

   (4)没有对决策树进行剪枝处理,很可能会出现过拟合的问题

posted on 2016-01-03 14:38  学会分享~  阅读(29921)  评论(3编辑  收藏  举报

导航