torch expand

expand 就是对那个维度为一的进行扩张

import torch

a=torch.tensor([[3.0000, 3.0000],
        [3.0000, 4.0000],
        [3.6000, 3.0000],
        [3.5000, 3.0000]])


a1=a.reshape([a.shape[0],a.shape[1] ,1])

#a1.shape==[4, 2, 1]

#对那个维度为一的进行扩张
a2=a1.expand([4,2,9])

print(a2.shape)

#torch.Size([4, 2, 9])


posted @ 2022-08-19 22:49  luoganttcc  阅读(9)  评论(0)    收藏  举报