CLASS torch.nn.CrossEntropyLoss
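This assumes the default arguments: `weight=None`, `reduction='mean'` and `label_smoothing=0.0`. Take a batch of N rows of raw (unnormalized) logits x_n over C classes. The loss is the negative log of the softmax probability, averaged over the batch. The first line below is for class-index targets y_n. The second is for probability targets y_{n,c}:

```latex
\begin{align*}
\ell(x, y) &= -\frac{1}{N}\sum_{n=1}^{N} \log\frac{\exp(x_{n,y_n})}{\sum_{c=0}^{C-1}\exp(x_{n,c})},
  && y_n \in \{0,\dots,C-1\} \\
\ell(x, y) &= -\frac{1}{N}\sum_{n=1}^{N}\sum_{c=0}^{C-1} y_{n,c}\,\log\frac{\exp(x_{n,c})}{\sum_{i=0}^{C-1}\exp(x_{n,i})},
  && y_{n,c} \ge 0,\ \textstyle\sum_{c} y_{n,c} = 1
\end{align*}
```

The softmax is applied inside the loss. The input should therefore be raw logits, not probabilities.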

import torch
import torch.nn as nn

# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)  # 3 samples, 5 classes (raw logits)
target = torch.empty(3, dtype=torch.long).random_(5)  # one class index in [0, 5) per sample
output = loss(input, target)
output.backward()

# Example of target with class probabilities
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)  # each row is a probability distribution over the 5 classes
output = loss(input, target)
output.backward()
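The sketch below checks what the two examples compute. It rebuilds both losses by hand from `torch.nn.functional.log_softmax` and compares each result with `nn.CrossEntropyLoss` at its defaults. The variable names are illustrative. Probability targets need a PyTorch version that accepts them (1.10 or later), which the example above already assumes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
loss_fn = nn.CrossEntropyLoss()  # defaults: no class weights, reduction='mean'

logits = torch.randn(3, 5)              # 3 samples, 5 classes
log_probs = F.log_softmax(logits, dim=1)

# Class-index targets: take the log-probability of the true class in each row,
# negate it, and average over the batch.
idx_target = torch.tensor([1, 0, 4])
manual_idx = -log_probs[torch.arange(3), idx_target].mean()
builtin_idx = loss_fn(logits, idx_target)

# The same loss expressed as log_softmax followed by the negative log-likelihood loss.
nll_idx = F.nll_loss(log_probs, idx_target)

# Probability targets: cross entropy between each target row and softmax(logits),
# summed over classes and averaged over the batch.
prob_target = torch.randn(3, 5).softmax(dim=1)
manual_prob = -(prob_target * log_probs).sum(dim=1).mean()
builtin_prob = loss_fn(logits, prob_target)

print(torch.allclose(manual_idx, builtin_idx),
      torch.allclose(nll_idx, builtin_idx),
      torch.allclose(manual_prob, builtin_prob))
```

For class-index targets, the match with `F.nll_loss` shows that `CrossEntropyLoss` is `log_softmax` plus `NLLLoss` combined in one call. That is why its input should be raw logits rather than softmax outputs.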

posted on 2022-12-08 22:51  朴素贝叶斯