torch.cat()的类型转换

torch.cat((TensorA,TensorB))在连接两个不同类型的Tensor的时候会发生类型转换,转换表如下
torch.cat()转换
表的行列按照优先级排列

需要注意的是这个优先级可能会导致数据的溢出,如

[In] torch.cat((torch.LongTensor([1<<31]),torch.HalfTensor([])))
[Out] tensor([inf], dtype=torch.float16)

附:
测试代码

import torch
import pandas as pd

all_types = [
    torch.BoolTensor,
    torch.ByteTensor,
    torch.CharTensor,
    torch.ShortTensor,
    torch.IntTensor,
    torch.LongTensor,
    torch.HalfTensor,
    torch.BFloat16Tensor,
    torch.FloatTensor,
    torch.DoubleTensor,
]
data = [[] for _ in range(len(all_types))]
n = len(all_types)
for i in range(n):
    for j in range(n):
        data[i].append(str(torch.cat((all_types[i](),all_types[j]())).dtype))
a = [str(i.dtype) for i in all_types]

pd.DataFrame(data,index=a,columns=a)
posted @ 2021-09-26 19:41  e-yi  阅读(73)  评论(0)    收藏  举报  来源