XGBoost回归算法

刚开始探索理解机器学习算法时,我会被所有数学内容弄得不知所措。我发现,如果没有完全理解算法背后的直觉,就很难理解其中的数学原理。所以,我会倾向于那些将算法分解成更简单、更易理解的步骤。这就是我今天尝试做的事情,以一种即使10岁小孩也能理解的方式来解释XGBoost算法。开始吧!

让我们从训练数据集开始,这个数据集包含了5个样本。每个样本记录了他们的年龄(AGE)、是否有硕士学位(MASTER'S DEGREE),以及他们的工资(SALARY)(以千为单位),目标是使用XGBoost模型预测工资。

AGE MASTER'S DEGREE SALARY
25 80
32 75
23 60
26 60
28 75

步骤1:做一个初步预测并计算残差

这个预测可以是任何值。但假设初步预测是我们想要预测的变量平均值。

平均工资为:

(80 + 75 + 60 + 60 + 75) / 5 = 70

预测值 = 70(全部)

残差 = 观测值 - 预测值:

观测值 (SALARY) 初始预测 残差
80 70 10
75 70 5
60 70 -10
60 70 -10
75 70 5

步骤2:构建XGBoost决策树

根节点相似度计算:

相似度评分(Similarity Score):

S = (Sum of Residuals)^2 / (Number of Residuals + λ)
  = (0)^2 / (5 + 1)
  = 0

尝试按"是否有硕士学位"进行拆分:

拆分 样本索引 残差 残差和 个数 相似度
有硕士学位(左) 1,3,5 10,-10,5 5 3 25 / (3+1)=6.25
无硕士学位(右) 2,4 5,-10 -5 2 25 / (2+1)=8.33

总增益 = 左相似度 + 右相似度 - 根节点相似度 = 6.25 + 8.33 - 0 = 14.58

现在我们尝试年龄的拆分,先计算平均值:

年龄排序:23, 25, 26, 28, 32

平均值:
(23+25)/2 = 24
(25+26)/2 = 25.5
(26+28)/2 = 27
(28+32)/2 = 30

我们尝试年龄 < 23.5:

左节点:只有23,残差 = -10
右节点:其余残差 = 10, 5, -10, 5

增益计算类似,依次计算所有拆分:

最终发现:
“是否有硕士学位” 拆分具有最大增益,因此选择这个作为初始拆分。

接下来继续对左节点(有硕士学位)做拆分:

考虑年龄 < 25:

样本为 1,3,5:

年龄 残差
25 10
23 -10
28 5

拆分为:

  • 左:23(-10)→ 相似度 = 100/(1+1)=50
  • 右:25,28 → 残差和=15 → 15²/(2+1)=75

总增益 = 50 + 75 - 6.25 = 118.75

右节点(无硕士)考虑年龄 < 24.5,样本为:

年龄 残差
32 5
26 -10

增益 = (25 + 25) - 8.33 = 41.67

步骤3:树修剪

设 γ = 50

各拆分增益如下:

  • 是否有硕士:14.58
  • 年龄<25(左节点):118.75
  • 年龄<24.5(右节点):41.67

因此保留前两个拆分,移除右节点的拆分。

最终树结构如下:

                   ┌──────────────┐
                   │  是否有硕士  │
                   └──────┬───────┘
                          │
            ┌─────────────┴──────────────┐
         是                             否
        │                                │
  年龄 < 25                             -
   ┌─────┴────┐
 <25        >=25

步骤4:计算叶子节点值

计算公式:

叶节点输出 = Sum of Residuals / (Number of Residuals + λ)
  • 叶子A(23岁): -10 / (1+1) = -5
  • 叶子B(25,28岁):15 / (2+1) = 5
  • 叶子C(无硕士):-5 / (2+1) = -1.67

步骤5:模型预测

预测公式:

新预测 = 初始预测 + 学习率 * 叶子输出值

假设学习率 η = 0.3:

样本 路径 输出值 新预测
1 是, >=25 5 70 + 0.3*5 = 71.5
2 -1.67 70 - 0.3*1.67 ≈ 69.5
3 是, <25 -5 70 - 0.3*5 = 68.5
4 -1.67 ≈ 69.5
5 是, >=25 5 71.5

步骤6:计算新预测值的残差

观测值 新预测 残差
80 71.5 8.5
75 69.5 5.5
60 68.5 -8.5
60 69.5 -9.5
75 71.5 3.5

步骤7:重复步骤2-6

我们重复构建新树,不断更新预测值和残差,直到残差极小或达到最大迭代次数。

预测总公式:

Pred = 初始预测 + η * Tree1(x) + η * Tree2(x) + … + η * Treen(x)

就是这样了!谢谢阅读,祝你在接下来的算法之旅中一切顺利!

posted @ 2025-06-24 20:46  55open  阅读(36)  评论(0)    收藏  举报