Torch和Numpy——查看形状类型
输入
1 import numpy as np 2 import torch 3 4 a = np.array([[1,2],[3,4]]) 5 print(a.shape,np.shape(a),a.dtype) 6 print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') 7 8 a = a.astype(np.float32) 9 print(a.dtype) 10 print('***************************************************') 11 12 b = torch.tensor([[1,2],[3,4]]) 13 print(b.shape,b.size(),b.type()) 14 print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&") 15 16 b = b.float() 17 print(b.dtype)
输出
1 (2, 2) (2, 2) int32 2 @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 3 float32 4 *************************************************** 5 torch.Size([2, 2]) torch.Size([2, 2]) torch.LongTensor 6 &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&& 7 torch.float32

浙公网安备 33010602011771号