pytorch中的repeat和repeat_interleave

个人的简单理解:

repeat可以理解为多次复制张量后在指定维度上concate上去,即x.repeat(n,dim=k)等价成torch.cat([x for _ in range(n)],dim=k)
repeat_interleave实际上等价于repeat在高一维的基础上运算后再view,即x.repeat_interleave(n,dim=k)等价成x.repeat(n,dim=k+1).view(N0, N1, ..., n*Nk, Nk+1, ...),其中N0,N1, Nk, Nk+1分别指x的第01kk+1维的长度。当k是最后一维时自动unsqueeze(-1)

posted @ 2024-07-04 10:12  kksk43  阅读(325)  评论(0)    收藏  举报
特效
黑夜
侧边栏隐藏