torch.cat()
《动手pytorch》
cat :cconcatenate,拼接
1 import torch 2 n_train,n_test = 100,100 3 features = torch.randn((n_train+n_test)) 4 poly_features = torch.cat((features,torch.pow(features,2),torch.pow(features,3)),1)
torch.cat((A,B),dim=0) : 按行拼接,列数不变
torch.cat((A,B),dim=1) : 按列拼接,行数不变

浙公网安备 33010602011771号