pytorch-day03

1、统计属性

 1 # 1、norm:范数
 2 a = torch.full([8], 1)
 3 b = a.view(2, 4)
 4 c = a.view(2, 2, 2)
 5 print(a.norm(1), b.norm(1), c.norm(1))  # 1范数 tensor(8.) tensor(8.) tensor(8.)
 6 print(a.norm(2), b.norm(2), c.norm(2))  # 2范数 tensor(2.8284) tensor(2.8284) tensor(2.8284)
 7 print(b.norm(1, dim=0))  # 行和1范数 tensor([2., 2., 2., 2.])
 8 print(b.norm(1, dim=1))  # 列和1范数 tensor([4., 4.])
 9 
10 # 2、mean sum prod max min argmin argmax
11 # 以下操作都是先打平
12 a = torch.arange(8).view(2, 4).float()
13 print(a.min(), a.max(), a.mean(), a.sum(), a.prod())  # prod:累乘
14 # 返回索引
15 print(a.argmax(), a.argmin())  # tensor(7) tensor(0)
16 # 编写不打平的情况,即指定维度dim, keepdim:保持以前的shape
17 a = torch.rand(4, 10)
18 print(a.argmax(), a.argmax(dim=1))  # tensor(38) tensor([2, 8, 5, 8])
19 print(a.max(dim=1))
20 print(a.max(dim=1, keepdim=True))
21 
22 # 3、topk
23 print(a.topk(3, dim=1))  # 从每一行中选出3列最大的
24 print(a.topk(3, dim=1, largest=False))  # 从每一行中选出3列最小的
25 # 4、kthvalue
26 print(a.kthvalue(8, dim=1))  # 返回第8小的(只能是小)数值,即第3大的数值
27 print(a.kthvalue(8, dim=1, keepdim=True))  # 返回第8小的(只能是小)数值,即第3大的数值
28 
29 # 5、compare
30 print(a > 0.5)
31 torch.equal(a, b)  # 比较每个位置的元素
32 torch.eq(a, b)  # 比较每个位置的元素

2、Tensor高阶

  where、gather

posted @ 2020-07-24 10:26  小吴的日常  阅读(87)  评论(0)    收藏  举报