XGBoost回归算法
刚开始探索理解机器学习算法时,我会被所有数学内容弄得不知所措。我发现,如果没有完全理解算法背后的直觉,就很难理解其中的数学原理。所以,我会倾向于那些将算法分解成更简单、更易理解的步骤。这就是我今天尝试做的事情,以一种即使10岁小孩也能理解的方式来解释XGBoost算法。开始吧!
让我们从训练数据集开始,这个数据集包含了5个样本。每个样本记录了他们的年龄(AGE)、是否有硕士学位(MASTER'S DEGREE),以及他们的工资(SALARY)(以千为单位),目标是使用XGBoost模型预测工资。
| AGE | MASTER'S DEGREE | SALARY |
|---|---|---|
| 25 | 是 | 80 |
| 32 | 否 | 75 |
| 23 | 是 | 60 |
| 26 | 否 | 60 |
| 28 | 是 | 75 |
步骤1:做一个初步预测并计算残差
这个预测可以是任何值。但假设初步预测是我们想要预测的变量平均值。
平均工资为:
(80 + 75 + 60 + 60 + 75) / 5 = 70
预测值 = 70(全部)
残差 = 观测值 - 预测值:
| 观测值 (SALARY) | 初始预测 | 残差 |
|---|---|---|
| 80 | 70 | 10 |
| 75 | 70 | 5 |
| 60 | 70 | -10 |
| 60 | 70 | -10 |
| 75 | 70 | 5 |
步骤2:构建XGBoost决策树
根节点相似度计算:
相似度评分(Similarity Score):
S = (Sum of Residuals)^2 / (Number of Residuals + λ)
= (0)^2 / (5 + 1)
= 0
尝试按"是否有硕士学位"进行拆分:
| 拆分 | 样本索引 | 残差 | 残差和 | 个数 | 相似度 |
|---|---|---|---|---|---|
| 有硕士学位(左) | 1,3,5 | 10,-10,5 | 5 | 3 | 25 / (3+1)=6.25 |
| 无硕士学位(右) | 2,4 | 5,-10 | -5 | 2 | 25 / (2+1)=8.33 |
总增益 = 左相似度 + 右相似度 - 根节点相似度 = 6.25 + 8.33 - 0 = 14.58
现在我们尝试年龄的拆分,先计算平均值:
年龄排序:23, 25, 26, 28, 32
平均值:
(23+25)/2 = 24
(25+26)/2 = 25.5
(26+28)/2 = 27
(28+32)/2 = 30
我们尝试年龄 < 23.5:
左节点:只有23,残差 = -10
右节点:其余残差 = 10, 5, -10, 5
增益计算类似,依次计算所有拆分:
最终发现:
“是否有硕士学位” 拆分具有最大增益,因此选择这个作为初始拆分。
接下来继续对左节点(有硕士学位)做拆分:
考虑年龄 < 25:
样本为 1,3,5:
| 年龄 | 残差 |
|---|---|
| 25 | 10 |
| 23 | -10 |
| 28 | 5 |
拆分为:
- 左:23(-10)→ 相似度 = 100/(1+1)=50
- 右:25,28 → 残差和=15 → 15²/(2+1)=75
总增益 = 50 + 75 - 6.25 = 118.75
右节点(无硕士)考虑年龄 < 24.5,样本为:
| 年龄 | 残差 |
|---|---|
| 32 | 5 |
| 26 | -10 |
增益 = (25 + 25) - 8.33 = 41.67
步骤3:树修剪
设 γ = 50
各拆分增益如下:
- 是否有硕士:14.58
- 年龄<25(左节点):118.75
- 年龄<24.5(右节点):41.67
因此保留前两个拆分,移除右节点的拆分。
最终树结构如下:
┌──────────────┐
│ 是否有硕士 │
└──────┬───────┘
│
┌─────────────┴──────────────┐
是 否
│ │
年龄 < 25 -
┌─────┴────┐
<25 >=25
步骤4:计算叶子节点值
计算公式:
叶节点输出 = Sum of Residuals / (Number of Residuals + λ)
- 叶子A(23岁): -10 / (1+1) = -5
- 叶子B(25,28岁):15 / (2+1) = 5
- 叶子C(无硕士):-5 / (2+1) = -1.67
步骤5:模型预测
预测公式:
新预测 = 初始预测 + 学习率 * 叶子输出值
假设学习率 η = 0.3:
| 样本 | 路径 | 输出值 | 新预测 |
|---|---|---|---|
| 1 | 是, >=25 | 5 | 70 + 0.3*5 = 71.5 |
| 2 | 否 | -1.67 | 70 - 0.3*1.67 ≈ 69.5 |
| 3 | 是, <25 | -5 | 70 - 0.3*5 = 68.5 |
| 4 | 否 | -1.67 | ≈ 69.5 |
| 5 | 是, >=25 | 5 | 71.5 |
步骤6:计算新预测值的残差
| 观测值 | 新预测 | 残差 |
|---|---|---|
| 80 | 71.5 | 8.5 |
| 75 | 69.5 | 5.5 |
| 60 | 68.5 | -8.5 |
| 60 | 69.5 | -9.5 |
| 75 | 71.5 | 3.5 |
步骤7:重复步骤2-6
我们重复构建新树,不断更新预测值和残差,直到残差极小或达到最大迭代次数。
预测总公式:
Pred = 初始预测 + η * Tree1(x) + η * Tree2(x) + … + η * Treen(x)
就是这样了!谢谢阅读,祝你在接下来的算法之旅中一切顺利!

浙公网安备 33010602011771号