pytorch中的repeat和repeat_interleave
个人的简单理解:
repeat可以理解为多次复制张量后在指定维度上concate上去,即x.repeat(n,dim=k)等价成torch.cat([x for _ in range(n)],dim=k)
repeat_interleave实际上等价于repeat在高一维的基础上运算后再view,即x.repeat_interleave(n,dim=k)等价成x.repeat(n,dim=k+1).view(N0, N1, ..., n*Nk, Nk+1, ...),其中N0,N1, Nk, Nk+1分别指x的第0,1,k,k+1维的长度。当k是最后一维时自动unsqueeze(-1)

浙公网安备 33010602011771号