pytorch数据类型与转换

torch定义了7种cpu tensor类型和8中gpu tensor类型

 

 使用时,直接传入数字,就是按照形状初始化

torch.FloatTensor(2,3)

torch.DoubleTensor(2,3)

torch.ByteTensor(2,3)

torch.CharTensor(2,3)

torch.ShortTensor(2,3)

torch.IntTensor(2,3)

torch.LongTensor(2,3)

注意在torch中对不同dtype的tensor原则上是不能混合运算的。

可以对tensor进行强行转换。

1   直接在tensor后面接.dtype()进行转换

import torch

a=torch.Tensor(3.1415)

a=a.float()

a=a.double()

a=a.int()

也可以对模型进行骚操作,即对模型的输出结果进行强制转化

self.lstm=nn.LSTM(input_size=self.emb_dim,hidden_size=n_lstm_units,droput=1-keep_prob).double()

这样就直接将lstm的模型进行了强转化。但是不会影响梯度的backward

 2   使用to转换

torch.lzeros(1,2).to(torch.double)

3 使用type()进行转化

torch.zeros(1,2).type(torch.double)

 

可以设置获取默认tensor类型

torch.set_default_tensor_type(torch.double)

torch.get_default_tensor_type()

posted @ 2021-08-26 11:46  大大的海棠湾  阅读(942)  评论(0)    收藏  举报