pytorch nll_loss batch shape
使用nll_loss时,如果想计算batch的loss,假设loss函数输入x的shape为 (N, d, C),其中N是batch_size,d是句子长度,C是vocab_size,标签target y的shape为(N, d)。
nll_loss函数要求输入为 (N, C, d),target为(N, d),则计算时,需要将x的后两维做转置:
loss = torch.nn.functional.nll_loss(x.transpose(1, 2), y)

浙公网安备 33010602011771号