循环相关 pytorch

# pytorch 1.6 目前pytorch.fft没有支持复数数据类型,如果支持就好写很多了
# 当onesided为True时,依据输入张量的嵌入维度奇偶性,可能出现维度不匹配的问题

from einops impor rearrange

def ccorr(a, b, sigdim=1, onesided = False):
    """
    在a,b最后一个维度上做循环相关
    a \ast b = \mathcal{F}^{-1}(\overline{\mathcal{F}(a)} \odot \mathcal{F}(b))
    Parameter
    ---------
    a: real valued array shape (*, N)
    b: real valued array shape (*, N)
    Returns
    -------
    c: real valued array (shape (*,N)), representingthe circular
       correlation of a and b
    """
    real1,imag1 = rearrange(torch.rfft(a,sigdim,onesided=onesided),'b embed c2 -> c2 b embed')  # 目前torch.rfft返回一个形状为(*,2)的浮点数张量,那个2是因为有实数部分和虚数部分
    imag1=-imag1
    real2,imag2 = rearrange(torch.rfft(b,sigdim,onesided=onesided),'b embed c2 -> c2 b embed')
    c = torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim = -1)
    return torch.irfft(c,sigdim,onesided=onesided)

 

posted @ 2020-10-15 14:43  e-yi  阅读(18)  评论(0)    收藏  举报  来源