用 Focal Loss 应对类别不平衡
用 Focal Loss 应对类别不平衡
当面临类别极度不平衡的数据时,标准的交叉熵损失会因大量易分样本的主导而失效。Focal Loss 通过引入动态调制因子,强制模型聚焦于训练过程中的“硬核”样本,是解决此类问题的关键技术。
1. 核心公式回顾
Focal Loss 在标准交叉熵的基础上,增加了权重因子 \(\\alpha\) 和调制因子 \((1-p\_t)^\gamma\)。其统一形式简洁而强大:
\[L_{\text{focal}} = -\alpha_t (1-p_t)^\gamma \log(p_t)
\]
其中:
- \(p\_t\) 定义为:当真实标签 \(y=1\) 时,\(p\_t=p\);当 \(y=0\) 时,\(p\_t=1-p\)。这里的 \(p\) 是模型预测为正类的概率。
- \(\\alpha\_t\) 是类别平衡参数,\(\\gamma\) 是聚焦参数,用于抑制易分样本的损失贡献。
2. 代码实现
我们将上述思想封装成一个健壮且易于使用的 Python 类。该实现直接接收模型输出的概率(经过 Sigmoid 激活),并通过 epsilon 裁剪保证数值计算的稳定性。
import numpy as np
from typing import Literal
class Focal_Loss:
"""
一个简洁且高效的 Focal Loss 实现,用于处理类别不平衡问题。
该实现接收模型输出的概率 (经过 Sigmoid 激活后) 作为输入。
"""
def __init__(self, alpha: float = 0.25, gamma: float = 2.0,
reduction: Literal['mean', 'sum', 'none'] = 'mean', eps: float = 1e-9):
"""
初始化 Focal Loss。
参数:
alpha (float): 平衡参数,用于调节正负样本的权重。
gamma (float): 聚焦参数,用于动态调整样本权重,使模型聚焦于难分样本。
reduction (str): 指定损失的聚合方式 ('mean', 'sum', 'none')。
eps (float): 一个极小的数值,用于避免 log(0) 导致的计算溢出。
"""
if reduction not in ['mean', 'sum', 'none']:
raise ValueError("reduction 必须是 'mean', 'sum', 或 'none'")
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
self.eps = eps
def __call__(self, y_pred_prob: np.ndarray, y_true: np.ndarray) -> np.ndarray | float:
"""
计算 Focal Loss。
参数:
y_pred_prob (np.ndarray): 模型预测的概率值,shape 为 (N, ...),值在 [0, 1] 之间。
y_true (np.ndarray): 真实标签,shape 与 y_pred_prob 相同,值为 0 或 1。
返回:
计算出的损失值。
"""
# 1. 裁剪概率值以保证数值稳定性
p = np.clip(y_pred_prob, self.eps, 1.0 - self.eps)
# 2. 计算正样本和负样本的损失项
# 正样本 (y=1) 的损失
loss_pos = -self.alpha * np.power(1 - p, self.gamma) * np.log(p)
# 负样本 (y=0) 的损失
loss_neg = -(1 - self.alpha) * np.power(p, self.gamma) * np.log(1 - p)
# 3. 根据真实标签选择对应的损失
# 利用 y_true (0或1) 作为掩码,优雅地合并两个损失项
per_sample_loss = y_true * loss_pos + (1 - y_true) * loss_neg
# 4. 根据指定的 reduction 策略聚合损失
if self.reduction == "mean":
return np.mean(per_sample_loss)
elif self.reduction == "sum":
return np.sum(per_sample_loss)
else: # self.reduction == 'none'
return per_sample_loss
浙公网安备 33010602011771号