Input shapes for the loss function (nn.CrossEntropyLoss)

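```python
import torch
import torch.nn as nn

# Suppose logits has shape [1, 3, 5]: batch size 1, 3 classes, sequence length 5
logits = torch.randn(1, 3, 5)
targets = torch.tensor([[1, 2, 1, 1, 2]])  # shape [1, 5]; each entry is a class index in [0, 3)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, targets)
print(loss)  # prints a scalar loss value (by default, the mean over all 5 positions)
```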
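To see why dimension 1 has to be the class dimension, the sketch below recomputes the same loss by hand: apply log_softmax over dim=1, pick out the log-probability of the target class at each of the 5 positions, and average the negatives. It assumes the default settings (reduction="mean", no class weights, no ignore_index) and uses a fixed seed only so the run is repeatable.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 3, 5)               # [N=1, C=3, seq_len=5]
targets = torch.tensor([[1, 2, 1, 1, 2]])   # [N=1, seq_len=5], class indices

builtin = nn.CrossEntropyLoss()(logits, targets)

# Manual version: normalize over the class dimension (dim=1) ...
log_probs = F.log_softmax(logits, dim=1)    # still [1, 3, 5]
# ... then gather the log-probability of the target class at every position.
picked = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # [1, 5]
manual = -picked.mean()                     # matches the default reduction="mean"

print(builtin.item(), manual.item())
print(torch.allclose(builtin, manual))
```

The gather along dim=1 is exactly where the [N, C, ...] layout matters: the target index selects one entry along the class axis for every (batch, position) pair.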

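| Task type | Input shape (`input`) | Target shape (`target`) |
|---|---|---|
| Image classification | (batch_size, num_classes) | (batch_size,) |
| Sequence generation | (batch_size, num_classes, seq_len) | (batch_size, seq_len) |
| Semantic segmentation | (batch_size, num_classes, H, W) | (batch_size, H, W) |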
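Each row of the table can be checked with random dummy tensors. The sizes below are arbitrary values chosen only for illustration. Using reduction="none" returns the unreduced per-element loss, which makes it visible that the loss takes exactly the target's shape in every case.

```python
import torch
import torch.nn as nn

# Arbitrary illustrative sizes
batch_size, num_classes, seq_len, H, W = 4, 10, 7, 8, 8

cases = {
    "image classification": (
        torch.randn(batch_size, num_classes),
        torch.randint(0, num_classes, (batch_size,)),
    ),
    "sequence generation": (
        torch.randn(batch_size, num_classes, seq_len),
        torch.randint(0, num_classes, (batch_size, seq_len)),
    ),
    "semantic segmentation": (
        torch.randn(batch_size, num_classes, H, W),
        torch.randint(0, num_classes, (batch_size, H, W)),
    ),
}

per_element = nn.CrossEntropyLoss(reduction="none")
loss_shapes = {}
for task, (logits, targets) in cases.items():
    loss = per_element(logits, targets)
    loss_shapes[task] = tuple(loss.shape)
    # With reduction="none" the loss has the same shape as the target
    print(f"{task}: input {tuple(logits.shape)}, "
          f"target {tuple(targets.shape)}, loss {tuple(loss.shape)}")
```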

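logits must have shape [N, C, ...]
targets must have shape [N, ...]

where N is the batch size and C is the number of predicted classes. The trailing "..." dimensions of targets must match the dimensions that follow C in logits (none for image classification, seq_len for sequence generation, H, W for segmentation), and each target element is a class index in the range [0, C). This describes class-index targets; nn.CrossEntropyLoss also accepts float class probabilities, in which case targets have the same shape as logits.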
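A common source of shape errors is that sequence models (for example a linear head on top of a transformer) usually produce logits as [N, seq_len, C], with the class dimension last. A minimal sketch, assuming that layout: either move C to dim 1 with permute, or flatten the batch and sequence dimensions so the problem becomes plain [N*seq_len, C] classification. Both give the same mean loss. The sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, seq_len, C = 2, 5, 3
logits_nlc = torch.randn(N, seq_len, C)          # class dimension last: [N, L, C]
targets = torch.randint(0, C, (N, seq_len))      # [N, L]

loss_fn = nn.CrossEntropyLoss()

# Option 1: move the class dimension to position 1 -> [N, C, L]
loss_permute = loss_fn(logits_nlc.permute(0, 2, 1), targets)

# Option 2: flatten to [N*L, C] logits and [N*L] targets
loss_flat = loss_fn(logits_nlc.reshape(-1, C), targets.reshape(-1))

print(torch.allclose(loss_permute, loss_flat))

# Passing [N, L, C] directly would make PyTorch treat L as the class dimension.
# Here L != C, so the target shape check fails and an error is raised.
raised = False
try:
    loss_fn(logits_nlc, targets)
except (RuntimeError, ValueError) as err:
    raised = True
    print("shape error:", err)
```

Note that if seq_len happened to equal C, the unpermuted call would not fail; it would silently compute a loss over the wrong axis, which is why the permute (or reshape) should be written explicitly.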

posted @ 2025-03-09 11:48  xiezhengcai