vqvae的loss计算

loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
z_q是codebook 找到的最接近z的向量.
z是encoder生成的向量.
L对z求导 = 2(z_q.detach()-z)*(-1)=2(z - z_q.detach())     # 这个部分对于encoder做了训练.
L对z_q求导=2(z_q - z.detach())                                       #这个部分对于codebook做了训练.
所以这个detach对于变量x虽然对x不求导,但是计算其他变量时候参与计算.
 
很早之前,在RVQ那篇文章里说到过,VQ-VAE中是通过在codebook中选择欧式距离最近的embedding对应的index作为离散token的。即其中涉及到argmin操作,该操作是不可导的。因此重建loss的梯度是无法传递到encoder网络的。
如果我们写成 loss= torch.mean((z-z_q)**2)
那么L对z_q求导=2(z_q-z)
对z求导=2(z_q-z)*(-1)=2(z-z_q). 这里面的两个导数是算不了的.因为argmin不可导. 导数没法从z_q传到变量x上.(x是输入网络参数)
所以我们只能用上面的方法来计算.
 
我们上面的方法.
L对z求导 = 2(z_q.detach()-z)*(-1)=2(z - z_q.detach())     # 这个部分对于encoder做了训练. 
L对z_q求导=2(z_q - z.detach())                                       #这个部分对于codebook做了训练. 我们再计算z_q 对x求导.即可. z是没法对x求导的.
 
参考这个: https://zhuanlan.zhihu.com/p/644091516  讲的很好.
 
 

 把一个图像看做高维空间中的一个点. 首先把 h,w,c 展开成一个向量. 比如是(a1,....an), 那么他可以看做n维空间中的一个点.

 
 

posted on 2023-11-24 17:28  张博的博客  阅读(535)  评论(0)    收藏  举报

导航