GlenTt

导航

用 Focal Loss 应对类别不平衡

用 Focal Loss 应对类别不平衡

当面临类别极度不平衡的数据时,标准的交叉熵损失会因大量易分样本的主导而失效。Focal Loss 通过引入动态调制因子,强制模型聚焦于训练过程中的“硬核”样本,是解决此类问题的关键技术。

1. 核心公式回顾

Focal Loss 在标准交叉熵的基础上,增加了权重因子 \(\\alpha\) 和调制因子 \((1-p\_t)^\gamma\)。其统一形式简洁而强大:

\[L_{\text{focal}} = -\alpha_t (1-p_t)^\gamma \log(p_t) \]

其中:

  • \(p\_t\) 定义为:当真实标签 \(y=1\) 时,\(p\_t=p\);当 \(y=0\) 时,\(p\_t=1-p\)。这里的 \(p\) 是模型预测为正类的概率。
  • \(\\alpha\_t\) 是类别平衡参数,\(\\gamma\) 是聚焦参数,用于抑制易分样本的损失贡献。

2. 代码实现

我们将上述思想封装成一个健壮且易于使用的 Python 类。该实现直接接收模型输出的概率(经过 Sigmoid 激活),并通过 epsilon 裁剪保证数值计算的稳定性。

import numpy as np
from typing import Literal

class Focal_Loss:
    """
    一个简洁且高效的 Focal Loss 实现,用于处理类别不平衡问题。
    
    该实现接收模型输出的概率 (经过 Sigmoid 激活后) 作为输入。
    """
    def __init__(self, alpha: float = 0.25, gamma: float = 2.0, 
                 reduction: Literal['mean', 'sum', 'none'] = 'mean', eps: float = 1e-9):
        """
        初始化 Focal Loss。

        参数:
            alpha (float): 平衡参数,用于调节正负样本的权重。
            gamma (float): 聚焦参数,用于动态调整样本权重,使模型聚焦于难分样本。
            reduction (str): 指定损失的聚合方式 ('mean', 'sum', 'none')。
            eps (float): 一个极小的数值,用于避免 log(0) 导致的计算溢出。
        """
        if reduction not in ['mean', 'sum', 'none']:
            raise ValueError("reduction 必须是 'mean', 'sum', 或 'none'")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.eps = eps
    
    def __call__(self, y_pred_prob: np.ndarray, y_true: np.ndarray) -> np.ndarray | float:
        """
        计算 Focal Loss。

        参数:
            y_pred_prob (np.ndarray): 模型预测的概率值,shape 为 (N, ...),值在 [0, 1] 之间。
            y_true (np.ndarray): 真实标签,shape 与 y_pred_prob 相同,值为 0 或 1。

        返回:
            计算出的损失值。
        """
        # 1. 裁剪概率值以保证数值稳定性
        p = np.clip(y_pred_prob, self.eps, 1.0 - self.eps)

        # 2. 计算正样本和负样本的损失项
        # 正样本 (y=1) 的损失
        loss_pos = -self.alpha * np.power(1 - p, self.gamma) * np.log(p)
        
        # 负样本 (y=0) 的损失
        loss_neg = -(1 - self.alpha) * np.power(p, self.gamma) * np.log(1 - p)

        # 3. 根据真实标签选择对应的损失
        # 利用 y_true (0或1) 作为掩码,优雅地合并两个损失项
        per_sample_loss = y_true * loss_pos + (1 - y_true) * loss_neg

        # 4. 根据指定的 reduction 策略聚合损失
        if self.reduction == "mean":
            return np.mean(per_sample_loss)
        elif self.reduction == "sum":
            return np.sum(per_sample_loss)
        else: # self.reduction == 'none'
            return per_sample_loss

posted on 2025-09-14 11:14  GRITJW  阅读(85)  评论(0)    收藏  举报