分类问题预测的准确率如何计算
代码片段:
_, train_pred = torch.max(outputs, 1)
train_acc += (train_pred.detach() == labels.detach()).sum().item()
torch.max(outputs, 1): 这个函数用于返回每一行(即每个样本)中最大值及其索引。参数 1 表示在第二个维度(列)上进行操作
返回值是一个元组 (values, indices):values: 每个样本的最大得分(我们用 _ 忽略它,因为我们不需要)。indices: 对应于最大得分的类别索引,这就是我们需要的预测类别
例子:
# 从模型输出中获取每个样本的预测类别,并计算在当前批次中正确预测的样本数量,从而用于计算训练准确率。
import torch
outputs = torch.tensor([[0.1, 0.9, 0.0], # 第一个样本
[0.8, 0.1, 0.1], # 第二个样本
[0.2, 0.2, 0.6]]) # 第三个样本
_, train_pred = torch.max(outputs, 1)
print(_)
print(train_pred)
# tensor([0.9000, 0.8000, 0.6000])
# tensor([1, 0, 2])
labels = torch.tensor([1, 0, 2]) # 真实标签
print(train_pred.detach() == labels.detach()) # 比较
# 结果: tensor([True, True, True]) => [1, 1, 1]
print((train_pred.detach() == labels.detach()).sum()) # 结果为 tensor(3)
# 经过sum()后,True 会被视为 1,False 被视为 0
准确率计算:
train_acc/len(train_set)

浙公网安备 33010602011771号