GlenTt

导航

回归任务的基石:MSE 损失函数理解与实现

回归任务的基石:MSE 损失函数详解与实现

在回归问题领域,均方误差 (Mean Squared Error, MSE) 是最常用、最直观的损失函数。它衡量的是模型预测值与真实值之间的差距。今天,我们就来深入探讨 MSE 的核心思想,并给出一个简洁的 NumPy 实现。

1. 核心公式

MSE 的定义非常清晰:它是预测值与真实值之差的平方的平均值。假设 \(y\_i\) 是第 \(i\) 个样本的真实值,$ \hat{y}_i $ 是模型对该样本的预测值,对于一个包含 \(n\) 个样本的批次,MSE 的计算公式为:

\[\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 \]

2. 关键特性

为什么 MSE 如此流行?

  1. 惩罚大误差:由于误差是平方计算的,所以较大的误差(例如预测值与真实值相差 10)会比小的误差(相差 2)在损失中占据更大的权重 (\(10^2=100\) vs \(2^2=4\))。这使得模型在训练过程中会更倾向于修正那些离谱的预测点。
  2. 凸函数特性:MSE 损失函数是一个凸函数,这意味着它只有一个全局最小值,没有局部最小值。这极大地简化了优化过程,保证了梯度下降等算法能够稳定地收敛到最优解。
  3. 对离群点敏感:同样因为平方的特性,MSE 对数据中的离群点(Outliers)非常敏感。一个极端离群点的巨大误差在平方后会被不成比例地放大,可能会主导整个损失函数的梯度,导致模型性能下降。

3. 代码实现

import numpy as np
from typing import Literal

class MSE_Loss:
    """
    一个简洁的均方误差 (MSE) 损失函数实现。
    """
    def __init__(self, reduction: Literal['mean', 'sum', 'none'] = 'mean'):
        if reduction not in ['mean', 'sum', 'none']:
            raise ValueError("reduction 必须是 'mean', 'sum', 或 'none'")
        self.reduction = reduction
    
    def __call__(self, y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray | float:
        # 逐元素计算平方误差
        squared_error = (y_pred - y_true) ** 2

        # 根据指定策略聚合损失
        if self.reduction == "mean":
            return np.mean(squared_error)
        elif self.reduction == "sum":
            return np.sum(squared_error)
        else: # self.reduction == 'none'
            return squared_error

posted on 2025-09-14 10:50  GRITJW  阅读(155)  评论(0)    收藏  举报