(y_hat.argmax(dim=1) ==lable).sum().cpu().item()

        print(y_hat.argmax(dim=1))
        print(y_hat.argmax(dim=1) ==lable)
        print((y_hat.argmax(dim=1) ==lable).sum())
        print((y_hat.argmax(dim=1) ==lable).sum().cpu())     
        print((y_hat.argmax(dim=1) ==lable).sum().cpu().item())
输出:

tensor([4, 4, 5, 0, 4, 2, 8, 5, 8, 2, 4, 4, 4, 2, 8, 4, 1, 8, 2, 5, 0, 7, 4, 4,
6, 5, 6, 2, 5, 3, 5, 4, 4, 8, 4, 5, 2, 4, 2, 4, 6, 2, 5, 6, 5, 8, 4, 4,
4, 2], device='cuda:0')

第一个输出把每行最大的索引输出
tensor([False, False, False, False, False, False, False, False, True, False,
True, False, False, True, False, False, False, False, True, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, False, False, False, False, False,
False, True, False, False, False, False, False, False, False, False],
device='cuda:0')

第二个输出判断索引和lable是否相等,相等为true否则为false。
tensor(6, device='cuda:0')

第三个输出进行sum求和true算1,flase算0。
tensor(6)

第四个输出将cuda变为cpu
6

第五个item将tensor变为整形

posted @ 2021-10-09 17:29  祥瑞哈哈哈  阅读(539)  评论(0)    收藏  举报