实用指南:PyTorch 损失函数与激活函数的正确组合

前言

你是否遇到过这些困惑?

  • 为什么用了 CrossEntropyLoss 手动加 Softmax,结果 loss 不降反升?
  • BCELossBCEWithLogitsLoss 到底有什么区别?
  • 多分类和多标签分类,损失函数该怎么选?
  • target 到底该传索引还是 one-hot?

如果你也被这些问题困扰过,那这篇文章就是为你准备的。

本文将从原理到实践,用数据流图 + 数学公式 + 代码示例三位一体的方式,彻底讲清楚 PyTorch 中损失函数与激活函数的搭配规则。读完这篇文章,你将能够:

  1. 根据任务类型,快速选择正确的损失函数
  2. 理解每种损失函数的内部计算过程
  3. 避开常见的踩坑点
  4. 拿到可直接复用的代码模板

组合推荐

任务类型推荐组合predict 形状target 形状/类型
多分类(N选1)LinearCrossEntropyLoss[N, C][N] LongTensor
二分类LinearBCEWithLogitsLoss[N, 1][N, 1] FloatTensor
多标签(N选M)LinearBCEWithLogitsLoss[N, L][N, L] FloatTensor
回归LinearMSELoss[N, 1][N, 1] FloatTensor

注意:

  1. CrossEntropyLoss 内置 Softmax,不要手动加
  2. BCEWithLogitsLoss 内置 Sigmoid,不要手动加
  3. CrossEntropyLoss 的 target 是类别索引,不是 one-hot

一、快速查表

在深入原理之前,先给大家一个速查表。遇到具体场景时,可以直接来这里查。

按场景查组合

我要做什么?用什么损失函数?
图像分类(ImageNet 1000类)CrossEntropyLoss
文本情感分析(正面/负面)BCEWithLogitsLoss 或 CrossEntropyLoss
垃圾邮件检测(是/否)BCEWithLogitsLoss
文章打标签(可多选:科技/体育/娱乐)BCEWithLogitsLoss
预测房价(连续值)MSELoss
目标检测框回归SmoothL1Loss
模型蒸馏KLDivLoss
正负样本极度不平衡(1:1000)FocalLoss
人脸识别/图像检索TripletMarginLoss
语音识别/OCRCTCLoss

按损失函数查用法

损失函数适用场景是否内置激活predicttarget
CrossEntropyLoss多分类(互斥)Softmax[N,C] logits[N] 索引
BCEWithLogitsLoss二分类/多标签Sigmoid[N,1][N,L]同形状,float
BCELoss二分类(不推荐)无,需手动[N,1] 概率同形状,float
MSELoss回归[N,1][N,1]
L1Loss回归(抗异常值)[N,1][N,1]
SmoothL1Loss目标检测bbox[N,1][N,1]

二、多分类任务:CrossEntropyLoss

这是最常用的分类损失函数,适用于互斥的多分类场景(每个样本只属于一个类别)。

2.1 正确用法

import torch
import torch.nn as nn
class MultiClassModel(nn.Module):
def __init__(self, input_size, num_classes):
super().__init__()        # 线性层直接输出,不加激活函数!  
self.fc = nn.Linear(input_size, num_classes)
self.loss = nn.CrossEntropyLoss()
def forward(self, x, target=None):
predict = self.fc(x)  # [N, C] 原始分数(logits)  
if target is not None:
return self.loss(predict, target)  # target: [N]        
return predict
# 推理时获取预测类别  
probs = torch.softmax(predict, dim=1)
pred_class = torch.argmax(probs, dim=1)

2.2 数据格式

predict: [N, C] → 每个样本对应 C 个类别的原始分数(logits)
target: [N] → 每个样本的真实类别索引 (0 ~ C-1)
注意:target 是索引,不是 one-hot 编码!

2.3 内部计算过程详解

很多初学者不理解 CrossEntropyLoss 内部到底做了什么。我们用一个具体例子来拆解:

输入数据:
predict = [[2.0, 1.0, 0.1]] # 1个样本,3分类的 logits target = [0] # 真实类别是第0类

┌─────────────────────────────────────────────────────────────┐
│  Step 1: Softmax 将 logits 转为概率分布                       │├─────────────────────────────────────────────────────────────┤
│                                                             │
│   e^2.0 = 7.39                                              │
│   e^1.0 = 2.72                                              │
│   e^0.1 = 1.11                                              │
│   ──────────                                                │
│   sum   = 11.22                                             │
│                                                             │
│   p[0] = 7.39 / 11.22 = 0.66                                │
│   p[1] = 2.72 / 11.22 = 0.24                                │
│   p[2] = 1.11 / 11.22 = 0.10                                │
│                                                             │
│   概率分布 p = [0.66, 0.24, 0.10]                            │└─────────────────────────────────────────────────────────────┘
                              ↓┌─────────────────────────────────────────────────────────────┐
│  Step 2: 取真实类别的概率,计算负对数                          │├─────────────────────────────────────────────────────────────┤
│                                                             │
│   target = 0,所以取 p[0] = 0.66                             ││                                                             │
│   Loss = -log(0.66) = 0.42                                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

公式总结:

CrossEntropyLoss=−log⁡(eztarget∑jezj)\text{CrossEntropyLoss} = -\log\left(\frac{e^{z_{target}}}{\sum_{j} e^{z_j}}\right)CrossEntropyLoss=log(jezjeztarget)

2.4 为什么这样设计有效?

让我们看看不同预测情况下的 loss 值:

预测情况真实类别的概率Loss = -log§模型更新幅度
预测正确,置信度高0.950.05几乎不更新
预测正确,置信度低0.600.51适度更新
预测错误0.102.30大幅更新

核心思想:损失函数的梯度会推动模型,让正确类别的分数不断变大!

2.6 为什么 target 是索引而不是 one-hot?

原始交叉熵公式是这样的:

L=−∑i=0C−1yi⋅log⁡(pi)L = -\sum_{i=0}^{C-1} y_i \cdot \log(p_i)L=i=0C1yilog(pi)

其中 yyy 是 one-hot 编码,如 [0, 1, 0]

展开后:
L=−(0×log⁡(p0)+1×log⁡(p1)+0×log⁡(p2))L = -(0 \times \log(p_0) + 1 \times \log(p_1) + 0 \times \log(p_2))L=(0×log(p0)+1×log(p1)+0×log(p2))

发现问题没?只有一项是有效的! 因为 one-hot 中只有一个 1,其他都是 0。

所以可以简化为:
L=−log⁡(ptarget)L = -\log(p_{target})L=log(ptarget)

结论:只需要知道真实类别的索引即可,不需要完整的 one-hot,更省内存也更高效!


三、二分类任务:BCEWithLogitsLoss

二分类是多分类的特例,但有更高效的实现方式。

3.1 方式一:当作2类多分类

self.fc = nn.Linear(hidden_size, 2)  # 输出2个分数  
self.loss = nn.CrossEntropyLoss()
# predict: [N, 2], target: [N] (值为0或1)  

这种方式可行,但输出维度多了一个,不够高效。

3.2 方式二:BCEWithLogitsLoss(推荐)

class BinaryClassModel(nn.Module):
def __init__(self, input_size):
super().__init__()
self.fc = nn.Linear(input_size, 1)  # 只输出1个分数  
self.loss = nn.BCEWithLogitsLoss()
def forward(self, x, target=None):
predict = self.fc(x)  # [N, 1] 原始分数(logits)  
if target is not None:
return self.loss(predict, target.float())  # target 必须是 float        
return torch.sigmoid(predict)  # 推理时转概率  
# 推理时  
probs = torch.sigmoid(predict)
pred_class = (probs > 0.5).long()

3.3 数据流图示

predict: [N, 1]           target: [N, 1]
     ↓                         ↓
┌─────────┐               ┌─────────┐
│   2.0   │  样本1        │   1.0   │  正样本
├─────────┤               ├─────────┤
│  -1.5   │  样本2         │   0.0   │  负样本
├─────────┤               ├─────────┤
│   0.5   │  样本3         │   1.0   │  正样本
└─────────┘                └─────────┘
     ↓  Sigmoid(内置,不需要手动加!)
     ↓
┌─────────┐
│  0.88   │  样本1 → L = -log(0.88) = 0.13
├─────────┤
│  0.18   │  样本2 → L = -log(1-0.18) = -log(0.82) = 0.20
├─────────┤
│  0.62   │  样本3 → L = -log(0.62) = 0.48
└─────────┘
     ↓  Loss = (0.13 + 0.20 + 0.48) / 3 = 0.27```

3.4 数学原理

BCEWithLogitsLoss = Sigmoid + BCELoss

步骤1:Sigmoid 将分数转为概率
p=σ(z)=11+e−zp = \sigma(z) = \frac{1}{1 + e^{-z}}p=σ(z)=1+ez1

步骤2:二元交叉熵
L=−[y⋅log⁡(p)+(1−y)⋅log⁡(1−p)]L = -[y \cdot \log(p) + (1-y) \cdot \log(1-p)]L=[ylog(p)+(1y)log(1p)]

公式展开理解:

  • y = 1(正样本)时:L=−log⁡(p)L = -\log(p)L=log(p) → 希望 p 接近 1
  • y = 0(负样本)时:L=−log⁡(1−p)L = -\log(1-p)L=log(1p) → 希望 p 接近 0

3.5 为什么用 BCEWithLogitsLoss 而不是 Sigmoid + BCELoss?

很多人会问:我手动加 Sigmoid 再用 BCELoss 不行吗?

可以,但不推荐。 原因有三:

  1. 数值稳定性:当 sigmoid(x) 接近 0 或 1 时,log(sigmoid(x)) 会产生数值问题。BCEWithLogitsLoss 内部用 log-sum-exp 技巧避免了这个问题。

  2. 计算效率:一次前向传播完成,不需要中间存储 sigmoid 结果。

  3. 梯度更稳定:避免 sigmoid 饱和区的梯度消失问题。


四、多标签分类:BCEWithLogitsLoss

4.1 什么是多标签分类?

多标签 ≠ 多分类!

多分类多标签
定义N选1N选M
例子这张图是猫还是狗?这张图里有猫、有狗、有人?
激活函数Softmax(概率和=1)Sigmoid(每个标签独立)
损失函数CrossEntropyLossBCEWithLogitsLoss

4.2 正确用法

class MultiLabelModel(nn.Module):
def __init__(self, input_size, num_labels):
super().__init__()
self.fc = nn.Linear(input_size, num_labels)
self.loss = nn.BCEWithLogitsLoss()
def forward(self, x, target=None):
predict = self.fc(x)  # [N, L] 每个标签的分数  
if target is not None:
return self.loss(predict, target.float())  # target: [N, L] 多热编码  
return torch.sigmoid(predict)
# 推理时  
probs = torch.sigmoid(predict)
pred_labels = (probs > 0.5).long()  # 每个位置独立判断  

4.3 数据格式

predict: [N, L]  →  L 个标签的原始分数
target:  [N, L]  →  多热编码(multi-hot),如 [1, 0, 1, 0, 1]

4.4 数据流图示

predict: [N, L]              target: [N, L]
     ↓                            ↓
┌────────────────┐          ┌────────────────┐
│ 2.0  -1.0  0.5 │ 样本1     │  1    0    1   │ 标签0和2为正
└────────────────┘          └────────────────┘
     ↓  Sigmoid(对每个位置独立计算)
     ↓
┌────────────────┐
│ 0.88 0.27 0.62 │
└────────────────┘
     ↓
 每个位置独立计算BCE:
  位置0: target=1, p=0.88 → -log(0.88) = 0.13
  位置1: target=0, p=0.27 → -log(1-0.27) = 0.31
  位置2: target=1, p=0.62 → -log(0.62) = 0.48
     ↓
 Loss = (0.13 + 0.31 + 0.48) / 3 = 0.31

五、回归任务

回归任务预测的是连续值,不需要激活函数,线性层直接输出即可。

5.1 MSELoss(均方误差)

最常用的回归损失函数。

class RegressionModel(nn.Module):
def __init__(self, input_size):
super().__init__()
self.fc = nn.Linear(input_size, 1)
self.loss = nn.MSELoss()
def forward(self, x, target=None):
predict = self.fc(x)  # [N, 1] 直接输出数值  
if target is not None:
return self.loss(predict, target)
return predict

公式:
L=1N∑i=1N(yi−y^i)2L = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2L=N1i=1N(yiy^i)2

特点:

  • ✅ 处处可导,优化稳定
  • ✅ 大误差时梯度大,收敛快
  • ❌ 对异常值敏感(平方会放大误差)

5.2 L1Loss(平均绝对误差)

对异常值更鲁棒。

公式:
L=1N∑i=1N∣yi−y^i∣L = \frac{1}{N} \sum_{i=1}^{N} |y_i - \hat{y}_i|L=N1i=1Nyiy^i

特点:

  • ✅ 对异常值鲁棒(不会放大大误差)
  • ❌ 在 0 点不可导
  • ❌ 梯度恒定,收敛可能慢

5.3 SmoothL1Loss(Huber Loss)

MSE 和 L1 的结合体,目标检测中的标配。

L={0.5(y−y^)2if ∣y−y^∣<1∣y−y^∣−0.5otherwiseL = \begin{cases} 0.5(y - \hat{y})^2 & \text{if } |y - \hat{y}| < 1 \\ |y - \hat{y}| - 0.5 & \text{otherwise} \end{cases}L={0.5(yy^)2yy^0.5if yy^<1otherwise

5.4 回归损失对比

损失函数对异常值收敛速度适用场景
MSELoss敏感误差正态分布
L1Loss鲁棒有异常值
SmoothL1Loss鲁棒中等目标检测 bbox

六、特殊任务

6.1 知识蒸馏:KLDivLoss

用大模型(Teacher)的软标签指导小模型(Student)训练。

T = 4.0  # 温度参数  
loss = nn.KLDivLoss(reduction='batchmean')(
F.log_softmax(student_output / T, dim=1),  # Student: log概率  
F.softmax(teacher_output / T, dim=1)        # Teacher: 概率  

为什么要用温度 T?

原始 logits = [3.0, 1.0, 0.1]
T=1: softmax = [0.84, 0.11, 0.05]  # 很尖锐,信息少
T=4: softmax = [0.45, 0.30, 0.25]  # 更平滑,保留类别间关系

6.2 不平衡分类:Focal Loss

正负样本极度不平衡时(如目标检测中背景:前景 = 1000:1),普通交叉熵会被大量简单负样本主导。

FL(pt)=−αt(1−pt)γlog⁡(pt)FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)FL(pt)=αt(1pt)γlog(pt)

class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2.0):
super().__init__()        self.alpha = alpha
self.gamma = gamma
def forward(self, predict, target):
ce_loss = F.cross_entropy(predict, target, reduction='none')
pt = torch.exp(-ce_loss)
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
return focal_loss.mean()

核心思想:降低易分样本的权重,让模型专注于难分样本。

样本类型pt(1−pt)2(1-p_t)^2(1pt)2效果
易分样本0.950.0025权重极小
难分样本0.100.81权重大

6.3 对比学习:TripletMarginLoss

学习样本间的相对距离,常用于人脸识别、图像检索。

loss = nn.TripletMarginLoss(margin=1.0)(anchor, positive, negative)

L=max⁡(0,d(a,p)−d(a,n)+margin)L = \max(0, d(a, p) - d(a, n) + margin)L=max(0,d(a,p)d(a,n)+margin)
目标:让同类样本靠近,不同类样本远离。

6.4 序列对齐:CTCLoss

输入输出长度不对齐的场景,如语音识别、OCR。

loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)

核心思想:允许多个输入对应同一输出,通过 blank 符号处理对齐。


七、常见错误及修正

7.1 ❌ CrossEntropyLoss + Softmax(重复激活)

# ❌ 错误写法  
predict = F.softmax(self.fc(x), dim=1)
loss = F.cross_entropy(predict, target)  # 内部又做了一次 softmax!  
# ✅ 正确写法  
predict = self.fc(x)  # 直接用 logitsloss = F.cross_entropy(predict, target)  

7.2 ❌ BCELoss 忘记 Sigmoid

# ❌ 错误写法  
predict = self.fc(x)  # logits,可能是负数!  
loss = F.binary_cross_entropy(predict, target)  # BCELoss 期望输入是概率  
# ✅ 正确写法1:手动加 Sigmoidpredict = torch.sigmoid(self.fc(x))  
loss = F.binary_cross_entropy(predict, target)
# ✅ 正确写法2:用 BCEWithLogitsLoss(推荐)  
predict = self.fc(x)
loss = F.binary_cross_entropy_with_logits(predict, target)

7.3 ❌ target 类型错误

# CrossEntropyLoss 需要 LongTensortarget = torch.LongTensor([0, 1, 2])      # ✅  
target = torch.FloatTensor([0.0, 1.0])    # ❌  
# BCEWithLogitsLoss 需要 FloatTensortarget = torch.FloatTensor([0.0, 1.0])    # ✅  
target = torch.LongTensor([0, 1])         # ❌  

7.4 ❌ 维度不匹配

# CrossEntropyLoss  
predict: [N, C]  # ✅ 二维  
target:  [N]     # ✅ 一维(不是 [N, 1]!)  
# BCEWithLogitsLoss - 形状必须一致  
predict: [N, 1]  # ✅  
target:  [N, 1]  # ✅  

7.5 ❌ 多标签用了 CrossEntropyLoss

# ❌ 错误:CrossEntropyLoss 假设类别互斥  
target = torch.FloatTensor([[1, 1, 0]])  # 同时属于类别0和1  
loss = F.cross_entropy(predict, target)   # 错!  
# ✅ 正确:多标签用 BCEWithLogitsLossloss = F.binary_cross_entropy_with_logits(predict, target)  

八、快速选择决策树

你的任务是什么?
│
├── 分类任务
│   │
│   ├── 类别是否互斥?
│   │   │
│   │   ├── 是(N选1)
│   │   │   └── 几个类别?
│   │   │       ├── 2类 → BCEWithLogitsLoss(更高效)
│   │   │       │         或 CrossEntropyLoss│   │   │       └── >2类 → CrossEntropyLoss│   │   │
│   │   └── 否(N选M,多标签)
│   │       └── BCEWithLogitsLoss
│   │
│   └── 样本是否极度不平衡?
│       └── 是 → FocalLoss│
├── 回归任务
│   │
│   ├── 是否有异常值?
│   │   ├── 有 → L1Loss 或 SmoothL1Loss│   │   └── 无 → MSELoss│   │
│   └── 是目标检测 bbox?
│       └── 是 → SmoothL1Loss│
└── 特殊任务
    ├── 知识蒸馏 → KLDivLoss    ├── 相似度学习 → TripletMarginLoss    └── 序列对齐 → CTCLoss

九、核心要点总结

要点说明
CrossEntropyLoss 内置 Softmax输入是 logits,不要手动加 Softmax
BCEWithLogitsLoss 内置 Sigmoid输入是 logits,不要手动加 Sigmoid
CrossEntropyLoss 的 target 是索引形状 [N]不是 one-hot
BCEWithLogitsLoss 的 target 是浮点数需要 .float() 转换
多标签用 BCE,不是 CrossEntropy每个标签独立,不互斥
回归任务不需要激活函数线性层直接输出数值

十、完整代码模板

多分类

class MultiClassModel(nn.Module):
def __init__(self, input_size, num_classes):
super().__init__()
self.fc = nn.Linear(input_size, num_classes)
self.loss = nn.CrossEntropyLoss()
def forward(self, x, target=None):
logits = self.fc(x)
if target is not None:
return self.loss(logits, target)
return logits
# 推理  
probs = F.softmax(logits, dim=1)
pred_class = torch.argmax(probs, dim=1)

二分类

class BinaryClassModel(nn.Module):
def __init__(self, input_size):
super().__init__()
self.fc = nn.Linear(input_size, 1)
self.loss = nn.BCEWithLogitsLoss()
def forward(self, x, target=None):
logits = self.fc(x)
if target is not None:
return self.loss(logits, target.float())
return torch.sigmoid(logits)
# 推理  
probs = torch.sigmoid(logits)
pred_class = (probs > 0.5).long()

多标签分类

class MultiLabelModel(nn.Module):
def __init__(self, input_size, num_labels):
super().__init__()
self.fc = nn.Linear(input_size, num_labels)
self.loss = nn.BCEWithLogitsLoss()
def forward(self, x, target=None):
logits = self.fc(x)
if target is not None:
return self.loss(logits, target.float())
return torch.sigmoid(logits)
# 推理  
probs = torch.sigmoid(logits)
pred_labels = (probs > 0.5).long()

回归

class RegressionModel(nn.Module):
def __init__(self, input_size):
super().__init__()
self.fc = nn.Linear(input_size, 1)
self.loss = nn.MSELoss()
def forward(self, x, target=None):
predict = self.fc(x)
if target is not None:
return self.loss(predict, target)
return predict

posted @ 2026-01-21 12:04  yangykaifa  阅读(2)  评论(0)    收藏  举报