Pytorch张量高阶操作

1.Broadcasting

Broadcasting能够实现Tensor自动维度增加(unsqueeze)与维度扩展(expand),以使两个Tensor的shape一致,从而完成某些操作,主要按照如下步骤进行:

  • 从最后面的维度开始匹配(一般后面理解为小维度);
  • 前面插入若干维度,进行unsqueeze操作;
  • 将维度的size从1通过expand变到和某个Tensor相同的维度。

举例:

Feature maps:[4, 32, 14, 14]

Bias:[32, 1, 1](Tip:后面的两个1是手动unsqueeze插入的维度)->[1, 32, 1, 1]->[4, 32, 14, 14]

32

匹配规则(从最后面的维度开始匹配):

  • if current dim=1,expand to same
  • if either has no dim,insert one dim and expand to same
  • otherwise,NOT broadcasting-able

A的维度[4, 32, 8],B的维度[1],[1]->[1, 1, 1]->[4, 32, 8],对应情况1

A的维度[4, 32, 8],B的维度[8],[1]->[1, 1, 8]->[4, 32, 8],对应情况2

A的维度[4, 32, 8],B的维度[4],对应情况3,不能broadcasting

2.拼接与拆分

cat拼接操作

  • 功能:通过dim指定维度,在当前指定维度上直接拼接
  • 默认是dim=0
  • 指定的dim上,维度可以不相同,其他dim上维度必须相同,不然会报错
1 a1=torch.rand(4,3,32,32)
2 a2=torch.rand(5,3,32,32)
3 print(torch.cat([a1,a2],dim=0).shape)    #torch.Size([9, 3, 32, 32])
4 
5 a3=torch.rand(4,1,32,32)
6 print(torch.cat([a1,a3],dim=1).shape)    #torch.Size([4, 4, 32, 32])
7 
8 a4=torch.rand(4,3,16,32)
9 print(torch.cat([a1,a4],dim=2).shape)    #torch.Size([4, 3, 48, 32])

stack拼接操作

  • 与cat不同的是,stack是在拼接的同时,在指定dim处插入维度后拼接(create new dim
  • stack需要保证两个Tensor的shape是一致的,这就像是有两类东西,它们的其它属性都是一样的(比如男的一张表,女的一张表)。使用stack时候要指定一个维度位置,在那个位置前会插入一个新的维度,因为是两类东西合并过来所以这个新的维度size是2,通过指定这个维度是0或者1来选择性别是男还是女。

  • 默认dim=0
1 a1=torch.rand(4,3,32,32)
2 a2=torch.rand(4,3,32,32)
3 print(torch.stack([a1,a2],dim=1).shape)  #torch.Size([4, 2, 3, 32, 32])  
左边起第二个维度取0时,取上半部分即a1,左边起第二个维度取1时,取下半部分即a2
4 print(torch.stack([a1,a2],dim=2).shape) #torch.Size([4, 3, 2, 32, 32])

split分割操作

  • 指定拆分dim
  • 按长度拆分,给定拆分后的数据大小
1 c=torch.rand(3,32,8)
2 
3 aa,bb=c.split([1,2],dim=0)     
4 print(aa.shape,bb.shape)            #torch.Size([1, 32, 8]) torch.Size([2, 32, 8])
5 
6 aa,bb,cc=c.split([1,1,1],dim=0)     #或者写成aa,bb,cc=c.split(1,dim=0) 
7 print(aa.shape,bb.shape,cc.shape)   #torch.Size([1, 32, 8]) torch.Size([1, 32, 8]) torch.Size([1, 32, 8])

chunk分割操作

  • chunk是在指定dim下按个数拆分,给定平均拆分的个数
  • 如果给定个数不能平均拆分当前维度,则会取比给定个数小的,能平均拆分数据的,最大的个数
  • dim默认是0
1 c=torch.rand(3,32,8)
2 d=torch.rand(2,32,8)
3 aa,bb=d.chunk(2,dim=0)
4 print(aa.shape,bb.shape)            #torch.Size([1, 32, 8]) torch.Size([1, 32, 8])
5 
6 aa,bb=c.chunk(2,dim=0)
7 print(aa.shape,bb.shape)            #torch.Size([2, 32, 8]) torch.Size([1, 32, 8])

3.基本运算

加法(a+b、torch.add(a,b))

减法(a-b、torch.sub(a,b))

乘法(*、torch.mul(a,b))对应元素相乘

除法(/、torch.div(a,b))对应元素相除,//整除

1 a = torch.rand(3, 4)
2 b = torch.rand(4)
3  
4 c1 = a + b
5 c2 = torch.add(a, b)
6 print(c1.shape, c2.shape)                #torch.Size([3, 4]) torch.Size([3, 4])
7 print(torch.all(torch.eq(c1, c2)))       #tensor(True)

矩阵乘法

torch.mm(only for 2d,不推荐使用)

torch.matmul(推荐)

@

1 a=torch.rand(2,1)
2 b=torch.rand(1,2)
3 print(torch.mm(a,b).shape)          #torch.Size([2, 2])
4 print(torch.matmul(a,b).shape)      #torch.Size([2, 2])
5 print((a@b).shape)                  #torch.Size([2, 2])

应用于矩阵降维

1 x=torch.rand(4,784)
2 w=torch.rand(512,784)             #channel-out对应512,channel-in对应784
3 print((x@w.t()).shape)            #torch.Size([4, 512]) Tip:.t()只适用于二维

多维矩阵相乘

对于高维的Tensor(dim>2),定义其矩阵乘法仅在最后的两个维度上,要求前面的维度必须保持一致,就像矩阵的索引一样并且运算操作符只有torch.matmul()。

1 a=torch.rand(4,3,28,64)
2 b=torch.rand(4,3,64,32)
3 print(torch.matmul(a,b).shape)    #torch.Size([4, 3, 28, 32])
4        
5 c=torch.rand(4, 1, 64, 32)
6 print(torch.matmul(a,c).shape)    #torch.Size([4, 3, 28, 32])
7 
8 d=torch.rand(4,64,32)
9 print(torch.matmul(a,d).shape)    #报错

Tip:这种情形下的矩阵相乘,"矩阵索引维度"如果符合Broadcasting机制,也会自动做广播,然后相乘。

次方pow、**操作

1 a = torch.full([2, 2], 3)  
2 b = a.pow(2)                 #也可以a**2
3 print(b)
4 #tensor([[9., 9.],
5 #        [9., 9.]])

开方sqrt、**操作

1 #接上面
2 c = b.sqrt()   #也可以a**(0.5)
3 print(c)
4 #tensor([[3., 3.],
5 #        [3., 3.]])
6 d = b.rsqrt()  #平方根的倒数
7 print(d)
8 #tensor([[0.3333, 0.3333],
9 #        [0.3333, 0.3333]])

指数exp与对数log运算

log是以自然对数为底数的,以2为底的用log2,以10为底的用log10。

1 a = torch.exp(torch.ones(2, 2))  #得到2*2的全是e的Tensor
2 print(a)
3 #tensor([[2.7183, 2.7183],
4 #        [2.7183, 2.7183]])
5 print(torch.log(a))              #取自然对数
6 #tensor([[1., 1.],
7 #        [1., 1.]])

近似值运算

1 a = torch.tensor(3.14)
2 print(a.floor(), a.ceil(), a.trunc(), a.frac())  #取下,取上,取整数,取小数
3 #tensor(3.) tensor(4.) tensor(3.) tensor(0.1400)
4 b = torch.tensor(3.49)
5 c = torch.tensor(3.5)
6 print(b.round(), c.round())                      #四舍五入tensor(3.) tensor(4.)

裁剪运算clamp

对Tensor中的元素进行范围过滤,不符合条件的可以把它变换到范围内部(边界)上,常用于梯度裁剪(gradient clipping),即在发生梯度离散或者梯度爆炸时对梯度的处理,实际使用时可以查看梯度的(L2范数)模来看看需不需要做处理:w.grad.norm(2)

 1 grad = torch.rand(2, 3) * 15      #0~15随机生成
 2 print(grad.max(), grad.min(), grad.median())  #tensor(12.9533) tensor(1.5625) tensor(11.1101)
 3  
 4 print(grad)
 5 #tensor([[12.7630, 12.9533,  7.6125],
 6 #        [11.1101, 12.4215,  1.5625]])
 7 print(grad.clamp(10))             #最小是10,小于10的都变成10
 8 #tensor([[12.7630, 12.9533, 10.0000],
 9 #        [11.1101, 12.4215, 10.0000]])
10 print(grad.clamp(3, 10))          #最小是3,小于3的都变成3;最大是10,大于10的都变成10
11 #tensor([[10.0000, 10.0000,  7.6125],
12 #        [10.0000, 10.0000,  3.0000]])

4.统计属性

范数norm

Vector norm 和matrix norm区别

 1 a=torch.full([8],1)
 2 b=a.view(2,4)
 3 c=a.view(2,2,2)
 4 print(b)
 5 #tensor([[1., 1., 1., 1.],
 6 #        [1., 1., 1., 1.]])
 7 print(c)
 8 #tensor([[[1., 1.],
 9 #         [1., 1.]],
10 #        [[1., 1.],
11 #         [1., 1.]]])
12 
13 #求L1范数(所有元素绝对值求和)
14 print(a.norm(1),b.norm(1),c.norm(1))            #tensor(8.) tensor(8.) tensor(8.)
15 #求L2范数(所有元素的平方和再开根)
16 print(a.norm(2),b.norm(2),c.norm(2))            #tensor(2.8284) tensor(2.8284) tensor(2.8284)
17 
18 # 在b的1号维度上求L1范数
19 print(b.norm(1, dim=1))            #tensor([4., 4.])
20 # 在b的1号维度上求L2范数
21 print(b.norm(2, dim=1))            #tensor([2., 2.])
22  
23 # 在c的0号维度上求L1范数
24 print(c.norm(1, dim=0))
25 #tensor([[2., 2.],
26 #        [2., 2.]])
27 # 在c的0号维度上求L2范数
28 print(c.norm(2, dim=0))
29 #tensor([[1.4142, 1.4142],
30 #        [1.4142, 1.4142]])

均值mean、累加sum、最小min、最大max、累积prod

最大值最小值索引argmax、argmin

 1 b = torch.arange(8).reshape(2, 4).float()
 2 print(b)
 3 #均值,累加,最小,最大,累积
 4 print(b.mean(), b.sum(), b.min(), b.max(), b.prod())       #tensor(3.5000) tensor(28.) tensor(0.) tensor(7.) tensor(0.)  
 5 
 6 #不指定维度,输出打平后的最小最大值索引
 7 print(b.argmax(), b.argmin())                              #tensor(7) tensor(0)
 8 #指定维度1,输出每一行最大值所在的索引
 9 print(b.argmax(dim=1))                                     #tensor([3, 3])
10 #指定维度0,输出每一列最大值所在的索引
11 print(b.argmax(dim=0))                                     #tensor([1, 1, 1, 1])

Tip:上面的argmax、argmin操作默认会将Tensor打平后取最大值索引和最小值索引,如果不希望Tenosr打平,而是求给定维度上的索引,需要指定在哪一个维度上求最大值或最小值索引。

dim、keepdim

比方说shape=[4,10],dim=1时,保留第0个维度,即max输出会有4个值。

 1 a=torch.rand(4,10)
 2 print(a.max(dim=1))                                      #返回结果和索引
 3 # torch.return_types.max(
 4 # values=tensor([0.9770, 0.8467, 0.9866, 0.9064]),
 5 # indices=tensor([4, 2, 2, 4]))
 6 print(a.argmax(dim=1))                                   #tensor([4, 2, 2, 4])
 7 
 8 print(a.max(dim=1,keepdim=True))
 9 # torch.return_types.max(
10 # values=tensor([[0.9770],
11 #         [0.8467],
12 #         [0.9866],
13 #         [0.9064]]),
14 # indices=tensor([[4],
15 #         [2],
16 #         [2],
17 #         [4]]))
18 print(a.argmax(dim=1,keepdim=True))
19 # tensor([[4],
20 #         [2],
21 #         [2],
22 #         [4]])

Tip:使用keepdim=True可以保持应有的dim,即仅仅是将求最值的那个dim的size变成了1,返回的结果是符合原Tensor语义的。

取前k大topk(largest=True)/前k小(largest=False)的概率值及其索引

第k小(kthvalue)的概率值及其索引

 1 # 2个样本,分为10个类别的置信度
 2 d = torch.randn(2, 10)  
 3 # 最大概率的3个类别
 4 print(d.topk(3, dim=1))  
 5 # torch.return_types.topk(
 6 # values=tensor([[1.6851, 1.5693, 1.5647],
 7 #                [0.8492, 0.4311, 0.3302]]),
 8 # indices=tensor([[9, 1, 4],
 9 #                 [6, 2, 4]]))
10 
11 # 最小概率的3个类别
12 print(d.topk(3, dim=1, largest=False))  
13 # torch.return_types.topk(
14 # values=tensor([[-1.2583, -0.7617, -0.4518],
15 #         [-1.5011, -0.9987, -0.9042]]),
16 # indices=tensor([[6, 7, 2],
17 #         [3, 1, 9]]))
18 
19 # 求第8小概率的类别(一共10个那就是第3大,正好对应上面最大概率的3个类别的第3列)
20 print(d.kthvalue(8, dim=1))  
21 # torch.return_types.kthvalue(
22 # values=tensor([1.5647, 0.3302]),
23 # indices=tensor([4, 4]))

比较操作

>,>=,<,<=,!=,==

torch.eq(a,b)、torch.equal(a,b)

 1 a=torch.randn(2,3)
 2 b=torch.randn(2,3)
 3 print(a>0)
 4 print(torch.gt(a,0))
 5 # tensor([[False,  True,  True],
 6 #         [True, False, False]])
 7 
 8 
 9 print(torch.equal(a,a))        #True
10 print(torch.eq(a,a))
11 # tensor([[True, True, True],
12 #         [True, True, True]])

5.高阶操作

where

使用C=torch.where(condition,A,B)其中A,B,C,condition是shape相同的Tensor,C中的某些元素来自A,某些元素来自B,这由condition中对应位置的元素是1还是0来决定。如果condition对应位置元素是1,则C中的该位置的元素来自A中的该位置的元素,如果condition对应位置元素是0,则C中的该位置的元素来自B中的该位置的元素。

 1 cond=torch.tensor([[0.6,0.1],[0.8,0.7]])
 2 
 3 a=torch.tensor([[1,2],[3,4]])
 4 b=torch.tensor([[4,5],[6,7]])
 5 print(cond>0.5)
 6 # tensor([[ True, False],
 7 #         [ True,  True]])
 8 print(torch.where(cond>0.5,a,b))
 9 # tensor([[1, 5],
10 #         [3, 4]])

gather

torch.gather(input, dim, index, out=None)对元素实现一个查表映射的操作:

 1 prob=torch.randn(4,10)
 2 idx=prob.topk(dim=1,k=3)
 3 print(idx)
 4 # torch.return_types.topk(
 5 # values=tensor([[ 1.6009,  0.7975,  0.6671],
 6 #                [ 1.0937,  0.9888,  0.7749],
 7 #                [ 1.1727,  0.6124, -0.3543],
 8 #                [ 1.1406,  0.8465,  0.6256]]),
 9 # indices=tensor([[8, 5, 4],
10 #                 [6, 9, 8],
11 #                 [1, 3, 8],
12 #                 [6, 1, 3]]))
13 idx=idx[1]
14 
15 label=torch.arange(10)+100
16 print(torch.gather(label.expand(4,10), dim=1, index=idx.long()))
17 # tensor([[108, 105, 104],
18 #         [106, 109, 108],
19 #         [101, 103, 108],
20 #         [106, 101, 103]])

label=[[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],

      [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],

           [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],

           [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]]

index=[[8, 5, 4],

       [6, 9, 8],

            [1, 3, 8],

            [6, 1, 3]]

gather的含义就是利用index来索引input特定位置的数值。

补充scatter_

scatter_(dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中

细节再补充。。。

 

posted @ 2020-07-03 23:10  最咸的鱼  阅读(4208)  评论(0编辑  收藏  举报