深度学习(matmul替换einsum)
onnx对于einsum算子是在opset12之后才支持的,但是有些芯片对于onnx量化的支持只到opset11版本,遇到这种情况可以使用matmul替换einsum。
流程通常是将tensor先view成三维,然后将第一个tensor待消掉维度permute到最后一维,第二个tensor待消掉的维度permute到倒数第二维,matmul相乘之后再view成最终维度即可。
下面是两个例子。
import torch b,d,n,m = 1,256,123,234 desc0 = torch.randn(b, d, n) desc1 = torch.randn(b, d, m) scores1 = torch.einsum('bdn,bdm->bnm', desc0, desc1) scores2 = torch.matmul(desc0.transpose(1,2), desc1) print(torch.allclose(scores1, scores2)) b,d,h,n,m = 2,3,4,50,60 desc0 = torch.randn(b, d, h, n) desc1 = torch.randn(b, d, h, m) scores1 = torch.einsum('bdhn,bdhm->bhnm', desc0, desc1) desc0 = desc0.permute(0, 2, 3, 1).contiguous() # (b, h, n, d) desc1 = desc1.permute(0, 2, 1, 3).contiguous() # (b, h, d, m) desc0 = desc0.view(b*h, n, d) desc1 = desc1.view(b*h, d, m) scores2 = torch.matmul(desc0, desc1).view(b,h,n,m) # (b, h, n, m) print(torch.allclose(scores1, scores2))

浙公网安备 33010602011771号