高效实现 BCE Loss:从理论到数值稳定的代码
高效实现 BCE Loss:从理论到数值稳定的代码
在任何二分类任务中,二元交叉熵 (Binary Cross-Entropy, BCE) 损失函数都是基石。然而,一个看似简单的公式背后,却隐藏着数值计算的陷阱。今天,我们直击要点,讲解如何从理论公式演进到工业级的稳定代码实现。
1. 基础 BCE 公式
BCE 的标准形式是基于概率定义的。假设 \(y\) 为真实标签 (0 或 1),\(p\) 为模型预测为 1 的概率(即 Sigmoid 的输出),则损失 \(L\) 为:
\[L = -[y \cdot \log(p) + (1-y) \cdot \log(1-p)]
\]
这个公式非常直观,但存在一个致命缺陷:当 \(p\) 趋近于 0 或 1 时,\(\\log(p)\) 或 \(\\log(1-p)\) 会导致 log(0),产生负无穷大的结果,这在计算上是灾难性的。
2. 数值稳定的 BCE 公式
为了解决这个问题,我们不应该使用模型输出的概率 \(p\),而应直接使用未经 Sigmoid 激活的原始输出——Logits,我们记为 \(x\)。
通过将 \(p = \\text{sigmoid}(x)\) 代入原始公式并进行数学推导与化简,我们可以得到一个等价且高度数值稳定的形式:
\[L = \max(x, 0) - x \cdot y + \log(1 + e^{-|x|})
\]
这个公式的核心优势在于,指数项中的 \(-|x|\) 永远是非正数,这使得 \(e^{-|x|}\) 的结果被限制在 \((0, 1]\) 区间内,从根本上避免了浮点数上溢的风险。它不依赖任何 epsilon 裁剪之类的技巧,是数值计算上的最优解。
3. 代码实现
现在,我们将这个强大的公式封装成一个简洁、高效的 Python 类。
import numpy as np
from typing import Literal
class BCE_Logits_Loss:
"""
一个数值稳定的、基于 Logits 的二元交叉熵损失函数实现。
"""
def __init__(self, reduction: Literal['mean', 'sum', 'none'] = 'mean'):
if reduction not in ['mean', 'sum', 'none']:
raise ValueError("reduction 必须是 'mean', 'sum', 或 'none'")
self.reduction = reduction
def __call__(self, y_pred_logits: np.ndarray, y_true: np.ndarray) -> np.ndarray | float:
x = y_pred_logits
y = y_true
# 直接套用数值稳定的BCE公式
per_sample_loss = np.maximum(x, 0) - x * y + np.log(1 + np.exp(-np.abs(x)))
if self.reduction == "mean":
return np.mean(per_sample_loss)
elif self.reduction == "sum":
return np.sum(per_sample_loss)
else: # self.reduction == 'none'
return per_sample_loss
浙公网安备 33010602011771号