CrossEntropyLoss

import torch
import numpy as np

x = torch.randn((64,224,224))

y = torch.rand((64,224,224))

y = (y > 0.5).float()
y = torch.tensor(y)
fun = torch.nn.CrossEntropyLoss()
print(fun(x,y))   //输出:tensor(661.5663)
这种情况,y中的值是表示的是概率
import torch
import numpy as np

x = torch.randn((64,2,224,224))

y = torch.rand((64,224,224))

y = (y > 0.5).float()
y = torch.tensor(y).to(torch.long)
fun = torch.nn.CrossEntropyLoss()
print(fun(x,y))  //输出:tensor(0.9030)
这种情况,x中的2表示分类数,限制y中的值只能是0或1,若是3,则y中的元素值只能是0或1或2。
posted @ 2023-10-10 23:07  oneDonkey  阅读(28)  评论(0)    收藏  举报