拼接tensor
torch.cat(tensors, dim): 沿指定维度拼接张量。
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])
# dim=0 表示沿着第一个维度(行的方向)进行连接。
concatenated_tensor = torch.cat([tensor1, tensor2], dim=0)
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
# dim=1 表示沿着第二个维度(列的方向)进行连接。
concatenated_tensor = torch.cat([tensor1, tensor2], dim=1)
tensor([[ 1, 2, 3, 7, 8, 9],
[ 4, 5, 6, 10, 11, 12]])
浙公网安备 33010602011771号