机器学习CART算法与树剪枝

 

数据集下载地址:https://files-cdn.cnblogs.com/files/blogs/771845/Ch09-RegressionTrees.zip?t=1665989317

一、前言

本篇文章将会讲解CART算法的实现和树的剪枝方法,通过测试不同的数据集,学习CART算法和树剪枝技术。

二、将CART(Classification And Regression Trees)算法用于回归

在之前的文章,我们学习了决策树的原理和代码实现,使用使用决策树进行分类。决策树不断将数据切分成小数据集,直到所有目标标量完全相同,或者数据不能再切分为止。决策树是一种贪心算法,它要在给定时间内做出最佳选择,但不关心能否达到全局最优。

1、ID3算法的弊端

回忆一下,决策树的树构建算法是ID3。ID3的做法是每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来切分。也就是说,如果一个特征有4种取值,那么数据将被切分成4份。一旦按某特征切分后,该特征在之后的算法执行过程中将不会再起作用,所以有观点认为这种切分方式过于迅速。

除了切分过于迅速外,ID3算法还存在另一个问题,它不能直接处理连续型特征。只有事先将连续型特征离散化,才能在ID3算法中使用。但这种转换过程会破坏连续型变量的内在特性。

2、CART算法

与ID3算法相反,CART算法正好适用于连续型特征。CART算法使用二元切分法来处理连续型变量。而使用二元切分法则易于对树构建过程进行调整以处理连续型特征。具体的处理方法是:如果特征值大于给定值就走左子树,否则就走右子树。

CART算法有两步:

  • 决策树生成:递归地构建二叉决策树的过程,基于训练数据集生成决策树,生成的决策树要尽量大;自上而下从根开始建立节点,在每个节点处要选择一个最好的属性来分裂,使得子节点中的训练集尽量的纯。不同的算法使用不同的指标来定义"最好":
  • 决策树剪枝:用验证数据集对已生成的树进行剪枝并选择最优子树,这时损失函数最小作为剪枝的标准。

决策树剪枝我们先不管,我们看下决策树生成。

在决策树的文章中,我们先根据信息熵的计算找到最佳特征切分数据集构建决策树。CART算法的决策树生成也是如此,实现过程如下:

  • 使用CART算法选择特征
  • 根据特征切分数据集合
  • 构建树

3、根据特征切分数据集合

我们先找软柿子捏,CART算法这里涉及到算法,实现起来复杂些,我们先挑个简单的,即根据特征切分数据集合。编写代码如下:

 

 
#-*- coding:utf-8 -*-
import numpy as np

def binSplitDataSet(dataSet, feature, value):
    """
    函数说明:根据特征切分数据集合
    Parameters:
        dataSet - 数据集合
        feature - 带切分的特征
        value - 该特征的值
    Returns:
        mat0 - 切分的数据集合0
        mat1 - 切分的数据集合1
    """
    mat0 = dataSet[np.nonzero(dataSet[:,feature] > value)[0],:]
    mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:]
    return mat0, mat1

if __name__ == '__main__':
    testMat = np.mat(np.eye(4))
    mat0, mat1 = binSplitDataSet(testMat, 1, 0.5)
    print('原始集合:\n', testMat)
    print('mat0:\n', mat0)
    print('mat1:\n', mat1)

 

 运行结果如下图所示:

 

 

我们先创建一个单位矩阵,然后根据切分规则,对数据矩阵进行切分。可以看到binSplitDataSet函数根据特定规则,对数据矩阵进行切分。

现在OK了,我们已经可以根据特征和特征值对数据进行切分了,mat0存放的是大于指定特征值的矩阵,mat1存放的是小于指定特征值的矩阵。接下来,我们就看看如何使用CART算法选择最佳分类特征。

4、CART算法

假设X与Y分别为输入和输出变量,并且Y是连续变量,给定训练数据集:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

其中,D表示整个数据集合,n为特征数。

一个回归树对应着输入空间(即特征空间)的一个划分以及在划分的单元上的输出值。假设已将输入空间划分为M个单元R1,R2,...Rm,并且在每个单元Rm上有一个固定的输出值Cm,于是回归树模型可表示为:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

这样就可以计算模型输出值与实际值的误差:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

我们希望每个单元上的Cm,可以是的这个平方误差最小化。易知,当Cm为相应单元的所有实际值的均值时,可以到最优:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

那么如何生成这些单元划分?

假设,我们选择变量 xj 为切分变量,它的取值 s 为切分点,那么就会得到两个区域:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

当j和s固定时,我们要找到两个区域的代表值c1,c2使各自区间上的平方差最小:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

前面已经知道c1,c2为区间上的平均:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

那么对固定的 j 只需要找到最优的s,然后通过遍历所有的变量,我们可以找到最优的j,这样我们就可以得到最优对(j,s),并得到两个区间。

这样的回归树通常称为最小二乘回归树(least squares regression tree)。

上述过程表示的算法步骤为:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

除此之外,我们再定义两个参数,tolS和tolN,分别用于控制误差变化限制和切分特征最少样本数。这两个参数的意义是什么呢?就是防止过拟合,提前设置终止条件,实际上是在进行一种所谓的预剪枝(prepruning)操作,在下一小节会进行进一步讲解。

先看下数据的分布情况,编写代码如下:

#-*- coding:utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np

def loadDataSet(fileName):
    """
    函数说明:加载数据
    Parameters:
        fileName - 文件名
    Returns:
        dataMat - 数据矩阵
    """
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))                    #转化为float类型
        dataMat.append(fltLine)
    return dataMat

def plotDataSet(filename):
    """
    函数说明:绘制数据集
    Parameters:
        filename - 文件名
    Returns:
        无
    """
    dataMat = loadDataSet(filename)                                        #加载数据集
    n = len(dataMat)                                                    #数据个数
    xcord = []; ycord = []                                                #样本点
    for i in range(n):                                                    
        xcord.append(dataMat[i][0]); ycord.append(dataMat[i][1])        #样本点
    fig = plt.figure()
    ax = fig.add_subplot(111)                                            #添加subplot
    ax.scatter(xcord, ycord, s = 20, c = 'blue',alpha = .5)                #绘制样本点
    plt.title('DataSet')                                                #绘制title
    plt.xlabel('X')
    plt.show()

if __name__ == '__main__':
    filename = 'ex00.txt'
    plotDataSet(filename)

运行结果如下图所示:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

可以看到,这是一个很简单的数据集,我们先利用这个数据集测试我们的CART算法。

现在,编写代码如下:

posted @ 2022-10-17 14:31  宁宁真厉害  阅读(308)  评论(0)    收藏  举报