pytorch one hot

print(torch.nn.functional.one_hot(t, num_classes=7))

有个坑,使用的时候必须转换为 torch.int64 类型,不然会报错

t = t.to(torch.int64)
posted @ 2021-01-20 16:48  consolexinhun  阅读(310)  评论(0编辑  收藏  举报