矩阵和 numpy.transpose

 
 
 
 
 
 
 
 
 
 

矩阵和 numpy.transpose

文章[Transformer源码详解(Pytorch版本) - 知乎 (zhihu.com)](https://zhuanlan.zhihu.com/p/398039366?utm_campaign=shareopn&utm_medium=social&utm_oi=1396930517548257280&utm_psn=1547199426296487936&utm_source=wechat_session)

代码 [harvardnlp/annotated-transformer: An annotated implementation of the Transformer paper. (github.com)](https://github.com/harvardnlp/annotated-transformer)

所引出的问题

部分受启发于 https://www.cnblogs.com/sunshinewang/p/6893503.html

转置有三种方式,transpose方法、T属性以及swapaxes方法。

矩阵

import numpy as np
x = np.arange(24).reshape((2,3,4))
print(x)
print(x.transpose((1,0,2))) # shape(3,2,4)

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

[[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
#############################################
[[[ 0  1  2  3]
  [12 13 14 15]]

 [[ 4  5  6  7]
  [16 17 18 19]]

 [[ 8  9 10 11]
  [20 21 22 23]]]

括号上 由外到内的一个层级

不要思考xyz了,直接用012的思路吧,012就是从内到外的一个矩阵层次划分,对应到矩阵表示中也是同理的

转置

主要是 numpy.transpose

主要是考虑角标? 毕竟矩阵表示中的 xyz 对应着每一个坐标的xyz,与数学中强调形状不同,计算机中矩阵的应用更强调于角标的变换,我个人的三维想象能力欠佳,所以只能以角标计算的方式理解代码中的矩阵转置

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

4行这个地方我目前倾向于写错了,首先这样的-1,-2的写法也需要满足维度,我用jupyter的实验中,维度不符合是不可以的。

下面是具体实验内容。

(-2,-1)

import numpy as np
x = np.arange(6).reshape(2,3)
print(x)
print("############################")
print(x.transpose(-2,-1)) 

[[0 1 2]
 [3 4 5]]
############################
[[0 1 2]
 [3 4 5]]

(-1,-2)

import numpy as np
x = np.arange(6).reshape(2,3)
print(x)
print("############################")
print(x.transpose(-1,-2))

[[0 1 2]
 [3 4 5]]
############################
[[0 3]
 [1 4]
 [2 5]]

另外在代码中的与qkv计算相关的部分

        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]

转折

后经查实 函数中使用的是torch.transpose 这个函数就是简单的换位,并不用知道是准确的哪一些。

posted @ 2022-08-28 13:22  CCCarloooo  阅读(91)  评论(0)    收藏  举报