>>> p
tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
>>> torch.stack([p,p],dim=0)
tensor([[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]])
>>> torch.stack([p,p],dim=0).shape
torch.Size([2, 4, 5])
>>> p.shape
torch.Size([4, 5])
>>> torch.stack([p,p],dim=1).shape
torch.Size([4, 2, 5])
>>> torch.stack([p,p],dim=2).shape
torch.Size([4, 5, 2])