# pytorch 1.6 目前pytorch.fft没有支持复数数据类型,如果支持就好写很多了
# 当onesided为True时,依据输入张量的嵌入维度奇偶性,可能出现维度不匹配的问题
from einops impor rearrange
def ccorr(a, b, sigdim=1, onesided = False):
"""
在a,b最后一个维度上做循环相关
a \ast b = \mathcal{F}^{-1}(\overline{\mathcal{F}(a)} \odot \mathcal{F}(b))
Parameter
---------
a: real valued array shape (*, N)
b: real valued array shape (*, N)
Returns
-------
c: real valued array (shape (*,N)), representingthe circular
correlation of a and b
"""
real1,imag1 = rearrange(torch.rfft(a,sigdim,onesided=onesided),'b embed c2 -> c2 b embed') # 目前torch.rfft返回一个形状为(*,2)的浮点数张量,那个2是因为有实数部分和虚数部分
imag1=-imag1
real2,imag2 = rearrange(torch.rfft(b,sigdim,onesided=onesided),'b embed c2 -> c2 b embed')
c = torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim = -1)
return torch.irfft(c,sigdim,onesided=onesided)