pytorch-day03
1、统计属性
1 # 1、norm:范数 2 a = torch.full([8], 1) 3 b = a.view(2, 4) 4 c = a.view(2, 2, 2) 5 print(a.norm(1), b.norm(1), c.norm(1)) # 1范数 tensor(8.) tensor(8.) tensor(8.) 6 print(a.norm(2), b.norm(2), c.norm(2)) # 2范数 tensor(2.8284) tensor(2.8284) tensor(2.8284) 7 print(b.norm(1, dim=0)) # 行和1范数 tensor([2., 2., 2., 2.]) 8 print(b.norm(1, dim=1)) # 列和1范数 tensor([4., 4.]) 9 10 # 2、mean sum prod max min argmin argmax 11 # 以下操作都是先打平 12 a = torch.arange(8).view(2, 4).float() 13 print(a.min(), a.max(), a.mean(), a.sum(), a.prod()) # prod:累乘 14 # 返回索引 15 print(a.argmax(), a.argmin()) # tensor(7) tensor(0) 16 # 编写不打平的情况,即指定维度dim, keepdim:保持以前的shape 17 a = torch.rand(4, 10) 18 print(a.argmax(), a.argmax(dim=1)) # tensor(38) tensor([2, 8, 5, 8]) 19 print(a.max(dim=1)) 20 print(a.max(dim=1, keepdim=True)) 21 22 # 3、topk 23 print(a.topk(3, dim=1)) # 从每一行中选出3列最大的 24 print(a.topk(3, dim=1, largest=False)) # 从每一行中选出3列最小的 25 # 4、kthvalue 26 print(a.kthvalue(8, dim=1)) # 返回第8小的(只能是小)数值,即第3大的数值 27 print(a.kthvalue(8, dim=1, keepdim=True)) # 返回第8小的(只能是小)数值,即第3大的数值 28 29 # 5、compare 30 print(a > 0.5) 31 torch.equal(a, b) # 比较每个位置的元素 32 torch.eq(a, b) # 比较每个位置的元素
2、Tensor高阶
where、gather

浙公网安备 33010602011771号