pytorch数据类型与转换
torch定义了7种cpu tensor类型和8中gpu tensor类型

使用时,直接传入数字,就是按照形状初始化
torch.FloatTensor(2,3)
torch.DoubleTensor(2,3)
torch.ByteTensor(2,3)
torch.CharTensor(2,3)
torch.ShortTensor(2,3)
torch.IntTensor(2,3)
torch.LongTensor(2,3)
注意在torch中对不同dtype的tensor原则上是不能混合运算的。
可以对tensor进行强行转换。
1 直接在tensor后面接.dtype()进行转换
import torch
a=torch.Tensor(3.1415)
a=a.float()
a=a.double()
a=a.int()
也可以对模型进行骚操作,即对模型的输出结果进行强制转化
self.lstm=nn.LSTM(input_size=self.emb_dim,hidden_size=n_lstm_units,droput=1-keep_prob).double()
这样就直接将lstm的模型进行了强转化。但是不会影响梯度的backward
2 使用to转换
torch.lzeros(1,2).to(torch.double)
3 使用type()进行转化
torch.zeros(1,2).type(torch.double)
可以设置获取默认tensor类型
torch.set_default_tensor_type(torch.double)
torch.get_default_tensor_type()

浙公网安备 33010602011771号