深度学习(matmul替换einsum)

onnx对于einsum算子是在opset12之后才支持的,但是有些芯片对于onnx量化的支持只到opset11版本,遇到这种情况可以使用matmul替换einsum。

流程通常是将tensor先view成三维,然后将第一个tensor待消掉维度permute到最后一维,第二个tensor待消掉的维度permute到倒数第二维,matmul相乘之后再view成最终维度即可。

下面是两个例子。

import torch

b,d,n,m = 1,256,123,234
desc0 = torch.randn(b, d, n)
desc1 = torch.randn(b, d, m)

scores1 = torch.einsum('bdn,bdm->bnm', desc0, desc1)
scores2 = torch.matmul(desc0.transpose(1,2), desc1)  

print(torch.allclose(scores1, scores2))

b,d,h,n,m = 2,3,4,50,60
desc0 = torch.randn(b, d, h, n)  
desc1 = torch.randn(b, d, h, m)
scores1 = torch.einsum('bdhn,bdhm->bhnm', desc0, desc1)

desc0 = desc0.permute(0, 2, 3, 1).contiguous()  # (b, h, n, d)
desc1 = desc1.permute(0, 2, 1, 3).contiguous()  # (b, h, d, m)

desc0 = desc0.view(b*h, n, d)  
desc1 = desc1.view(b*h, d, m) 

scores2 = torch.matmul(desc0, desc1).view(b,h,n,m) # (b, h, n, m)

print(torch.allclose(scores1, scores2))
posted @ 2025-01-26 22:31  Dsp Tian  阅读(125)  评论(0)    收藏  举报