torch.cat()

《动手pytorch》

  cat :cconcatenate,拼接

1 import torch
2 n_train,n_test = 100,100
3 features = torch.randn((n_train+n_test))
4 poly_features = torch.cat((features,torch.pow(features,2),torch.pow(features,3)),1)

  torch.cat((A,B),dim=0) : 按行拼接,列数不变

   torch.cat((A,B),dim=1) : 按列拼接,行数不变

posted @ 2020-08-16 20:16  此间一看客  阅读(841)  评论(0)    收藏  举报