torch.cat()的类型转换
torch.cat((TensorA,TensorB))在连接两个不同类型的Tensor的时候会发生类型转换,转换表如下

表的行列按照优先级排列
需要注意的是这个优先级可能会导致数据的溢出,如
[In] torch.cat((torch.LongTensor([1<<31]),torch.HalfTensor([])))
[Out] tensor([inf], dtype=torch.float16)
附:
测试代码
import torch
import pandas as pd
all_types = [
torch.BoolTensor,
torch.ByteTensor,
torch.CharTensor,
torch.ShortTensor,
torch.IntTensor,
torch.LongTensor,
torch.HalfTensor,
torch.BFloat16Tensor,
torch.FloatTensor,
torch.DoubleTensor,
]
data = [[] for _ in range(len(all_types))]
n = len(all_types)
for i in range(n):
for j in range(n):
data[i].append(str(torch.cat((all_types[i](),all_types[j]())).dtype))
a = [str(i.dtype) for i in all_types]
pd.DataFrame(data,index=a,columns=a)

浙公网安备 33010602011771号