PyTorch: Concatenating and Splitting Tensors

1 Introduction

A quick note on the functions PyTorch provides for concatenating and splitting tensors.

2 Concatenation

PyTorch offers two main functions for combining tensors:
torch.cat()
torch.stack()

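torch.cat() joins the tensors along an existing dimension (dim=0 by default), so no new dimension is created; the inputs must have the same shape except in the concatenated dimension.
torch.stack() instead stacks the tensors along a new dimension, so all inputs must have exactly the same shape.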
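A minimal sketch of the shape rules behind these two functions, using two hypothetical 2×3 tensors a and b (plus a hypothetical 2×5 tensor c to show a shape mismatch): cat adds up the sizes along dim, while stack inserts a new dimension whose size is the number of inputs.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)

# cat: joins along an existing dimension, so the result stays 2-D
cat0 = torch.cat([a, b], dim=0)   # sizes along dim 0 add up: 2 + 2
cat1 = torch.cat([a, b], dim=1)   # sizes along dim 1 add up: 3 + 3

# stack: inserts a new dimension at position `dim`, so the result is 3-D
stack0 = torch.stack([a, b], dim=0)
stack2 = torch.stack([a, b], dim=2)

print(cat0.shape, cat1.shape)      # torch.Size([4, 3]) torch.Size([2, 6])
print(stack0.shape, stack2.shape)  # torch.Size([2, 2, 3]) torch.Size([2, 3, 2])

# cat only needs the inputs to match outside `dim` ...
c = torch.randn(2, 5)
print(torch.cat([a, c], dim=1).shape)  # torch.Size([2, 8])

# ... while stack needs identical shapes
try:
    torch.stack([a, c])
except RuntimeError as err:
    print("stack failed:", err)
```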

import torch

tensor1 = torch.randn(2,3)
tensor2 = torch.randn(2,3)
print(tensor1)
print(tensor2)
# out:
tensor([[ 1.3124, -0.6630, -1.1289],
        [-0.0913,  0.7382,  0.4581]])
tensor([[-0.8929, -1.3781, -0.6344],
        [-0.0994,  0.5217, -2.2306]])
 
tensor_cat = torch.cat([tensor1,tensor2])
print(f"tensor_out:{tensor_cat}")
print(f"size of tensor_out:{tensor_cat.size()}")
tensor_stack = torch.stack([tensor1,tensor2])
print(f"tensor_stack:{tensor_stack}")
print(f"size of tensor_stack:{tensor_stack.size()}")

# out:
tensor_out:tensor([[ 1.3124, -0.6630, -1.1289],
        [-0.0913,  0.7382,  0.4581],
        [-0.8929, -1.3781, -0.6344],
        [-0.0994,  0.5217, -2.2306]])
size of tensor_out:torch.Size([4, 3])
tensor_stack:tensor([[[ 1.3124, -0.6630, -1.1289],
         [-0.0913,  0.7382,  0.4581]],

        [[-0.8929, -1.3781, -0.6344],
         [-0.0994,  0.5217, -2.2306]]])
size of tensor_stack:torch.Size([2, 2, 3])
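One way to read the 2×2×3 result: stacking behaves like adding a size-1 dimension to every input with unsqueeze and then concatenating along it. The sketch below checks that equivalence for every valid dim on fresh random tensors; it illustrates the behaviour and is not a claim about how PyTorch implements stack internally.

```python
import torch

tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

for dim in range(3):  # a 2-D input allows the new dimension at position 0, 1 or 2
    stacked = torch.stack([tensor1, tensor2], dim=dim)
    # add a size-1 dimension at `dim` to every input, then concatenate along it
    via_cat = torch.cat([tensor1.unsqueeze(dim), tensor2.unsqueeze(dim)], dim=dim)
    assert torch.equal(stacked, via_cat)
    print(dim, stacked.shape)
```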

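torch.vstack() stacks tensors vertically (row-wise) in sequence.
For 2-D inputs such as tensor1 and tensor2 it produces the same result as torch.cat() with the default dim=0: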

tensor_vstack = torch.vstack([tensor1,tensor2])
print(f"tensor_vstack:{tensor_vstack}")
print(f"size of tensor_vstack:{tensor_vstack.size()}")

# out:
tensor_vstack:tensor([[ 1.3124, -0.6630, -1.1289],
        [-0.0913,  0.7382,  0.4581],
        [-0.8929, -1.3781, -0.6344],
        [-0.0994,  0.5217, -2.2306]])
size of tensor_vstack:torch.Size([4, 3])
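The equivalence with torch.cat holds for 2-D inputs like these. For 1-D tensors the two differ: the PyTorch docs describe vstack as concatenation along the first axis after each 1-D tensor has been reshaped by torch.atleast_2d, so a length-N tensor becomes a 1×N row. A small check with hypothetical integer tensors x and y:

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

v = torch.vstack([x, y])   # each 1-D tensor becomes a row -> shape (2, 3)
c = torch.cat([x, y])      # joined end to end -> shape (6,)

print(v)  # a 2x3 tensor with rows [1, 2, 3] and [4, 5, 6]
print(c)  # tensor([1, 2, 3, 4, 5, 6])

# vstack on 1-D inputs matches cat on their atleast_2d versions
assert torch.equal(v, torch.cat([torch.atleast_2d(x), torch.atleast_2d(y)], dim=0))
```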

torch.hstack() stacks tensors horizontally (column-wise) in sequence; for 2-D inputs this is concatenation along dim=1:

tensor_hstack = torch.hstack([tensor1,tensor2])
print(f"tensor_hstack:{tensor_hstack}")
print(f"size of tensor_hstack:{tensor_hstack.size()}")

# out:
tensor_hstack:tensor([[ 1.3124, -0.6630, -1.1289, -0.8929, -1.3781, -0.6344],
        [-0.0913,  0.7382,  0.4581, -0.0994,  0.5217, -2.2306]])
size of tensor_hstack:torch.Size([2, 6])
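Per the PyTorch docs, hstack concatenates along the first axis for 1-D tensors and along the second axis for all other tensors. A sketch of both cases, with hypothetical 1-D tensors x and y of different lengths:

```python
import torch

tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

# 2-D (or higher) inputs: hstack concatenates along dim=1
assert torch.equal(torch.hstack([tensor1, tensor2]), torch.cat([tensor1, tensor2], dim=1))

# 1-D inputs have no dim=1, so hstack concatenates along dim=0
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5])
h = torch.hstack([x, y])
print(h)  # tensor([1, 2, 3, 4, 5])
```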

3 Splitting

PyTorch offers two functions for splitting a tensor:
torch.split()
torch.chunk()

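torch.split() cuts a tensor along dim into chunks of a given size (or of a list of sizes); each chunk is a view of the original tensor.

torch.chunk() splits a tensor along dim into the requested number of chunks (the chunks argument) and returns them as a tuple; each chunk is likewise a view of the input.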

The signature and docstring of torch.split:

def split(
    tensor: Tensor,
    split_size_or_sections: Union[int, List[int]],
    dim: int = 0
) -> Tuple[Tensor, ...]:
    r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.

    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
    be split into equally sized chunks (if possible). Last chunk will be smaller if
    the tensor size along the given dimension :attr:`dim` is not divisible by
    :attr:`split_size`.

    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
    to :attr:`split_size_or_sections`.

    Args:
        tensor (Tensor): tensor to split.
        split_size_or_sections (int) or (list(int)): size of a single chunk or
            list of sizes for each chunk
        dim (int): dimension along which to split the tensor.
    """

The documentation of torch.chunk:

torch.chunk(input, chunks, dim=0) → Tuple of Tensors

Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.
Parameters:
    input (Tensor) – the tensor to split
    chunks (int) – number of chunks to return
    dim (int) – dimension along which to split the tensor
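Both docstrings stress that each chunk is a view of the input. A sketch of what that means in practice, using a hypothetical integer tensor t: writing into a chunk also changes the original tensor, because no data is copied, while .clone() gives an independent copy.

```python
import torch

t = torch.arange(10).reshape(5, 2)
first, second, last = torch.split(t, 2)  # row blocks of size 2, 2 and 1

first[0, 0] = 100  # in-place write into the first chunk
print(t[0, 0])     # tensor(100): the original tensor changed too

# chunk views share memory with the input in the same way
left, right = torch.chunk(t, 2, dim=1)
right.zero_()      # zero the second column through the view
print(t[:, 1])     # tensor([0, 0, 0, 0, 0])

# call .clone() on a chunk when an independent copy is needed
safe = last.clone()
safe += 1
print(t[4])        # tensor([8, 0]): unaffected by changes to the clone
```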

split:

tensor = torch.randn(10).reshape(5,2)
print(f"tensor:{tensor}")
# chunks of 2 rows along dim=0; the leftover row becomes the last chunk
torch.split(tensor,2)

# out:
tensor:tensor([[ 0.9619,  0.6095],
        [-1.8024, -0.1534],
        [ 1.7452,  0.4705],
        [-0.8512,  0.3175],
        [-0.0290, -0.1422]])

(tensor([[ 0.9619,  0.6095],
         [-1.8024, -0.1534]]),
 tensor([[ 1.7452,  0.4705],
         [-0.8512,  0.3175]]),
 tensor([[-0.0290, -0.1422]]))
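With an integer split_size the chunks along dim have that size, and the remainder forms a smaller last chunk (5 rows → 2 + 2 + 1 above). A sketch that checks the chunk sizes on a fresh random 5×2 tensor, also splits along dim=1, and confirms that torch.cat along the same dim reassembles the original:

```python
import torch

tensor = torch.randn(10).reshape(5, 2)

parts = torch.split(tensor, 2)           # along dim=0 by default
print([p.shape[0] for p in parts])       # [2, 2, 1]: the last chunk holds the remainder

cols = torch.split(tensor, 1, dim=1)     # one column per chunk
print([tuple(c.shape) for c in cols])    # [(5, 1), (5, 1)]

# concatenating the pieces along the same dim restores the original tensor
assert torch.equal(torch.cat(parts, dim=0), tensor)
assert torch.equal(torch.cat(cols, dim=1), tensor)
```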

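# pieces of 2 and 3 rows along dim=0
# (this and the following outputs come from a different random draw of tensor than the one printed above)
torch.split(tensor,[2,3])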

# out:
(tensor([[-1.5071, -0.0346],
         [-0.6429,  0.5917]]),
 tensor([[ 0.2722,  0.3824],
         [ 0.6135,  0.7926],
         [-0.5771, -0.4590]]))
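When split_size_or_sections is a list, its entries have to add up to the size of the tensor along dim (2 + 3 = 5 rows here). A sketch of what happens when they do not; the error message, and possibly the exception class, can vary across PyTorch versions, so as an assumption the sketch catches both RuntimeError and ValueError.

```python
import torch

tensor = torch.randn(10).reshape(5, 2)

top, bottom = torch.split(tensor, [2, 3])  # sizes sum to tensor.size(0) == 5
print(top.shape, bottom.shape)             # torch.Size([2, 2]) torch.Size([3, 2])

try:
    torch.split(tensor, [2, 2])            # 2 + 2 != 5
    failed = False
except (RuntimeError, ValueError) as err:  # exception class may vary by version
    failed = True
    print("split failed:", err)
```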

chunk:

# 2 chunks along dim=1: each gets one of the 2 columns
torch.chunk(tensor, 2, dim=1)

# out:
(tensor([[-1.5071],
         [-0.6429],
         [ 0.2722],
         [ 0.6135],
         [-0.5771]]),
 tensor([[-0.0346],
         [ 0.5917],
         [ 0.3824],
         [ 0.7926],
         [-0.4590]]))

# 2 chunks along dim=0: the 5 rows are split as 3 + 2
torch.chunk(tensor, 2, dim=0)

# out:
(tensor([[-1.5071, -0.0346],
         [-0.6429,  0.5917],
         [ 0.2722,  0.3824]]),
 tensor([[ 0.6135,  0.7926],
         [-0.5771, -0.4590]]))
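The 3 + 2 split above follows from chunk using a chunk size of ceil(size / chunks). A consequence, noted in the PyTorch docs, is that chunk may return fewer chunks than requested. A sketch with a hypothetical arange tensor t, plus a check that, for a non-empty dimension, chunk(n) yields the same pieces as split(ceil(size / n)):

```python
import math
import torch

t = torch.arange(6)

print(len(torch.chunk(t, 6)))  # 6: one element per chunk
print(len(torch.chunk(t, 4)))  # 3: chunk size is ceil(6 / 4) = 2, so only 3 pieces

# for a non-empty dim, chunk(n) matches split(ceil(size / n))
rows = torch.randn(5, 2)
n = 2
by_chunk = torch.chunk(rows, n, dim=0)
by_split = torch.split(rows, math.ceil(rows.size(0) / n), dim=0)
assert len(by_chunk) == len(by_split)
assert all(torch.equal(a, b) for a, b in zip(by_chunk, by_split))
print([c.shape[0] for c in by_chunk])  # [3, 2]
```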

posted @ 2024-06-19 16:01  liuliu55