Tensor的组合与分块-02

Tensor的组合与分块


  组合操作是指将不同的Tensor叠加起来, 主要有torch.cat()torch.stack()两个函数。 catconcatenate的意思, 是指沿着已有的数据的某一维度进行拼接, 操作后数据的总维数不变, 在进行拼接时, 除了拼接的维度之外, 其他维度必须相同。 而torch.stack()函数指新增维度, 并按照指定的维度进行叠加,

 1 import torch 
 2 
 3 # 创建两个2×2的Tensor
 4 a = torch.Tensor([[1,2],[3,4]])
 5 print(a,a.shape)
 6 
 7 b = torch.Tensor([[5,6],[7,8]])
 8 print(b,b.shape)
 9 
10 # 以第一维进行拼接, 则变成4×2的矩阵
11 c = torch.cat([a,b],0)
12 print(c,c.shape)
13 
14 # 以第二维进行拼接, 则变成2*4的矩阵
15 d = torch.cat([a,b],1)
16 print(d,d.size())
View Code

结果输出:

 1 tensor([[1., 2.],
 2         [3., 4.]]) torch.Size([2, 2])
 3 tensor([[5., 6.],
 4         [7., 8.]]) torch.Size([2, 2])
 5 tensor([[1., 2.],
 6         [3., 4.],
 7         [5., 6.],
 8         [7., 8.]]) torch.Size([4, 2])
 9 tensor([[1., 2., 5., 6.],
10         [3., 4., 7., 8.]]) torch.Size([2, 4])
View Code

 

 1 import torch 
 2 
 3 # 创建两个2×2的Tensor
 4 a = torch.Tensor([[1,2],[3,4]])
 5 print(a,a.shape)
 6 
 7 >>   tensor([[1., 2.],
 8             [3., 4.]]) torch.Size([2, 2])
 9 
10 b = torch.Tensor([[5,6],[7,8]])
11 print(b,b.shape)
12 
13 >>   tensor([[5., 6.],
14             [7., 8.]]) torch.Size([2, 2])
15 
16 # 以第0维进行stack, 叠加的基本单位为序列本身, 即a与b, 因此输出[a, b], 输出维度为2×2×2
17 d=torch.stack([a,b],0)
18 print(d, d.size())
19 >>  tensor([[[1., 2.],
20          [3., 4.]],
21 
22         [[5., 6.],
23          [7., 8.]]]) torch.Size([2, 2, 2])
24 
25 # 以第1维进行stack, 叠加的基本单位为每一行, 输出维度为2×2×2
26 e=torch.stack([a,b],1)
27 print(e, e.shape)
28 
29 >> tensor([[[1., 2.],
30          [5., 6.]],
31 
32         [[3., 4.],
33          [7., 8.]]]) torch.Size([2, 2, 2])
34 
35 # 以第2维进行stack, 叠加的基本单位为每一行的每一个元素, 输出维度为2×2×2
36 f=torch.stack([a,b],2)
37 print(f, f.shape)
38 
39 >> tensor([[[1., 5.],
40          [2., 6.]],
41 
42         [[3., 7.],
43          [4., 8.]]]) torch.Size([2, 2, 2])
View Code

 

   分块则是与组合相反的操作, 指将Tensor分割成不同的子Tensor,主要有torch.chunk()torch.split()两个函数, 前者需要指定分块的数量,而后者则需要指定每一块的大小, 以整型或者list来表示。 具体示例如下 :

 1 import torch 
 2 
 3 a = torch.Tensor([[1,2,3], [4,5,6]])
 4 print(a, a.size())
 5 >> tensor([[1., 2., 3.],
 6         [4., 5., 6.]]) torch.Size([2, 3])
 7 
 8 # 使用chunk, 沿着第0维进行分块, 一共分两块, 因此分割成两个1×3的Tensor
 9 b = torch.chunk(a, 2, 0)
10 print(b)
11 >> (tensor([[1., 2., 3.]]), tensor([[4., 5., 6.]]))
12 
13 # 沿着第1维进行分块, 因此分割成两个Tensor, 当不能整除时, 最后一个的维数会小于前面的
14 # 因此第一个Tensor为2×2, 第二个为2×1
15 c = torch.chunk(a, 2, 1)
16 print(c)
17 >> (tensor([[1., 2.],
18         [4., 5.]]), tensor([[3.],
19         [6.]]))
20 
21 # 使用split, 沿着第0维分块, 每一块维度为2, 由于第一维维度总共为2, 因此相当于没有分割
22 d = torch.split(a, 2, 0)
23 print(d)
24 >> (tensor([[1., 2., 3.],
25         [4., 5., 6.]]),)
26 
27 # 沿着第1维分块, 每一块维度为2, 因此第一个Tensor为2×2, 第二个为2×1
28 e = torch.split(a, 2, 1)
29 print(e)
30 >> (tensor([[1., 2.],
31         [4., 5.]]), tensor([[3.],
32         [6.]]))
33  
34 # split也可以根据输入的list进行自动分块, list中的元素代表了每一个块占的维度
35 f = torch.split(a, [1,2], 1)
36 print(f)
37 >> (tensor([[1.],
38         [4.]]), tensor([[2., 3.],
39         [5., 6.]]))
View Code

 


 

posted @ 2020-09-01 18:32  赵家小伙儿  阅读(808)  评论(0)    收藏  举报