[报错]-RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
模型输入的数据类型要与模型参数的数据类型一致。
torch.cuda.HalfTensor:对应
np.array(x, dtype = 'float32')
torch.cuda.FloatTensor:对应
np.array(x, dtype = 'float16')

浙公网安备 33010602011771号