numpy轴操作

今天讲点干货的东西,关于numpy中轴操作(axes)的内容。
在numpy中,max、min、mean等函数都支持axis参数,用来指定沿哪个轴进行计算。但轴的处理很容易让人感觉云里雾里,尤其到了高维之后,结果更难直观地确定。
这里先以一个简单例子进行说明:

>>> import numpy as np
>>> np.random.seed(42)
>>> a = np.arange(120)                                                 
>>> a                                                                  
array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119])                                                 
>>> np.random.shuffle(a)                                               
>>> a                                                                  
array([ 44,  47,   4,  55,  26,  64,  73,  10,  40, 107,  18,  62,  11,
        36,  89,  91, 109,   0,  88, 104,  65,  45,  31,  70,  42,  12,
        15, 114,  76,  97,  24,  78,  22,  96,  56, 110,  30,  53, 118,
         9,  33,  25,  69,  28,  98,  85,   5,  90,  68,  39,  49,  35,
        16,  66,  34, 113,   7,  43,  72,  67,  83,  27,  19,  95, 100,
         8,  13,  84,   3,  17,  38, 117,   6,  77, 111,  94,  54,  50,
        80,  46,  81,  61, 116,  79,  93,  41,  58,  48, 101,  57,  75,
        32, 112,  59,  63, 105,  37,  29, 115,   1,  52,  21,   2,  23,
       103,  99,  87, 108,  74,  86,  82, 119,  20,  60,  71, 106,  14,
        92,  51, 102])                                                 

接下来对其进行重塑,使其成为4维数组:

>>> b = a.reshape(2,3,4,5)
>>> b.shape
(2, 3, 4, 5)

假设现在要进行max操作,那么每个轴其结果是什么呢?
带着这个问题,我们先将变量打印出来,再对其进行分解:

>>> b                               
array([[[[ 44,  47,   4,  55,  26], 
         [ 64,  73,  10,  40, 107], 
         [ 18,  62,  11,  36,  89], 
         [ 91, 109,   0,  88, 104]],
                                    
        [[ 65,  45,  31,  70,  42], 
         [ 12,  15, 114,  76,  97], 
         [ 24,  78,  22,  96,  56], 
         [110,  30,  53, 118,   9]],
                                    
        [[ 33,  25,  69,  28,  98], 
         [ 85,   5,  90,  68,  39], 
         [ 49,  35,  16,  66,  34], 
         [113,   7,  43,  72,  67]]],
                                    
                                    
       [[[ 83,  27,  19,  95, 100], 
         [  8,  13,  84,   3,  17], 
         [ 38, 117,   6,  77, 111], 
         [ 94,  54,  50,  80,  46]],
                                    
        [[ 81,  61, 116,  79,  93], 
         [ 41,  58,  48, 101,  57], 
         [ 75,  32, 112,  59,  63], 
         [105,  37,  29, 115,   1]],
                                    
        [[ 52,  21,   2,  23, 103], 
         [ 99,  87, 108,  74,  86], 
         [ 82, 119,  20,  60,  71], 
         [106,  14,  92,  51, 102]]]])
>>> b[0]                            
array([[[ 44,  47,   4,  55,  26],  
        [ 64,  73,  10,  40, 107],  
        [ 18,  62,  11,  36,  89],  
        [ 91, 109,   0,  88, 104]], 
                                    
       [[ 65,  45,  31,  70,  42],  
        [ 12,  15, 114,  76,  97],  
        [ 24,  78,  22,  96,  56],  
        [110,  30,  53, 118,   9]], 
                                    
       [[ 33,  25,  69,  28,  98],  
        [ 85,   5,  90,  68,  39],  
        [ 49,  35,  16,  66,  34],  
        [113,   7,  43,  72,  67]]])
>>> b[1]                            
array([[[ 83,  27,  19,  95, 100],  
        [  8,  13,  84,   3,  17],  
        [ 38, 117,   6,  77, 111],  
        [ 94,  54,  50,  80,  46]], 
                                    
       [[ 81,  61, 116,  79,  93],  
        [ 41,  58,  48, 101,  57],  
        [ 75,  32, 112,  59,  63],  
        [105,  37,  29, 115,   1]], 
                                    
       [[ 52,  21,   2,  23, 103],  
        [ 99,  87, 108,  74,  86],  
        [ 82, 119,  20,  60,  71],  
        [106,  14,  92,  51, 102]]])

对于第1维(axis=0),b.max(axis=0)的过程实际上就是在b[0]和b[1]这两个子数组的相同位置上取较大的值。比如b[0]中第1行第1个元素是44,而b[1]相同位置上的元素是83,因此结果中第1行第1个元素自然是83。同理,第1行第2个元素取47和27中的较大者47,以此类推,最后1个元素取67和102中的较大者102,从而得到最终结果:

>>> b.max(axis=0)
array([[[ 83,  47,  19,  95, 100],
        [ 64,  73,  84,  40, 107],
        [ 38, 117,  11,  77, 111],
        [ 94, 109,  50,  88, 104]],

       [[ 81,  61, 116,  79,  93],
        [ 41,  58, 114, 101,  97],
        [ 75,  78, 112,  96,  63],
        [110,  37,  53, 118,   9]],

       [[ 52,  25,  69,  28, 103],
        [ 99,  87, 108,  74,  86],
        [ 82, 119,  20,  66,  71],
        [113,  14,  92,  72, 102]]])

那么按照第2维(axis=1)进行max处理呢?此时我们取出b[0]中每一组(即b[0][0]、b[0][1]、b[0][2])的第1行:

>>> b[0][:,0]
array([[44, 47,  4, 55, 26],
       [65, 45, 31, 70, 42],
       [33, 25, 69, 28, 98]])

那么第2维的比较就是看同一列的44、65、33中谁的值最大,自然是65胜出。接着下一列是47、45、25,最大为47;再下一列是4、31、69,最大为69,以此类推,得到结果第1行的各个值。其余各行(b[0][:,1]、b[0][:,2]、b[0][:,3])也按同样方式比较。b[1]同理进行相同的操作。

>>> b.max(axis=1)
array([[[ 65,  47,  69,  70,  98],
        [ 85,  73, 114,  76, 107],
        [ 49,  78,  22,  96,  89],
        [113, 109,  53, 118, 104]],

       [[ 83,  61, 116,  95, 103],
        [ 99,  87, 108, 101,  86],
        [ 82, 119, 112,  77, 111],
        [106,  54,  92, 115, 102]]])

之后是第3维(axis=2)的比较,继续分解下去,于是有:

>>> b[0][0]
array([[ 44,  47,   4,  55,  26],
       [ 64,  73,  10,  40, 107],
       [ 18,  62,  11,  36,  89],
       [ 91, 109,   0,  88, 104]])

此时要比较的是这4行中同一列的值:第1列是44、64、18、91,显然91最大;第2列的最大值是109,第3列是11,第4列和第5列分别是88和107。因此其结果为:

>>> b.max(axis=2)
array([[[ 91, 109,  11,  88, 107],
        [110,  78, 114, 118,  97],
        [113,  35,  90,  72,  98]],

       [[ 94, 117,  84,  95, 111],
        [105,  61, 116, 115,  93],
        [106, 119, 108,  74, 103]]])

最后是第4维(axis=3)的比较,也就是直接比较每一行中各列的值,取其中的最大值:

>>> b[0][0][0]
array([44, 47,  4, 55, 26])
>>> b[0][0][1]
array([ 64,  73,  10,  40, 107])

可以看到,这两行的最大值分别为55和107。因此其结果为:

>>> b.max(axis=3)
array([[[ 55, 107,  89, 109],
        [ 70, 114,  96, 118],
        [ 98,  90,  66, 113]],

       [[100,  84, 117,  94],
        [116, 101, 112, 115],
        [103, 108, 119, 106]]])

换句话说,整个过程就是找到需要相互比较的那一组元素,逐步分解,就能轻松得到结果。
同样的方法也适用于mean、min等函数。至于结果的形状,只需把shape中对应轴的那一位去掉即可,如:

>>> b.max(axis=0).shape == b.shape[1:]
True
>>> b.max(axis=1).shape == b.shape[:1] + b.shape[2:]
True
>>> b.max(axis=2).shape == b.shape[:2] + b.shape[3:]
True
>>> b.max(axis=3).shape == b.shape[:3]
True

原因很简单:指定轴上的所有元素都参与了比较,被归约成了一个值,因此结果中该轴就被消去了,维度减少1。(如果希望保留该轴,可以传入keepdims=True,此时该轴在结果中的长度为1。)

posted @ 2025-04-21 21:52  月薪几千的牛马  阅读(32)  评论(0)    收藏  举报