pytorch nll_loss batch shape

使用nll_loss时,如果想计算batch的loss,假设loss函数输入x的shape为  (N, d, C),其中N是batch_size,d是句子长度,C是vocab_size,标签target y的shape为(N, d)。

nll_loss函数要求输入为 (N, C, d),target为(N, d),则计算时,需要将x的后两维做转置:

loss = torch.nn.functional.nll_loss(x.transpose(1, 2), y)

 


posted @ 2021-03-14 15:45  AliceYing  阅读(169)  评论(0)    收藏  举报