PyTorch 实现 VAE

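变分自编码器（Variational Autoencoder, VAE）的 PyTorch 实现：全连接网络，输入为展平后的 784 维向量（例如 28×28 的 MNIST 图像）。代码依赖 `import torch`、`import torch.nn as nn` 和 `import torch.nn.functional as F`。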

 1 class VAE(nn.Module):
 2     def __init__(self):
 3         super(VAE, self).__init__()
 4 
 5         self.fc1 = nn.Linear(784, 400)
 6         self.fc21 = nn.Linear(400, 20)
 7         self.fc22 = nn.Linear(400, 20)
 8         self.fc3 = nn.Linear(20, 400)
 9         self.fc4 = nn.Linear(400, 784)
10 
11     def encode(self, x):
12         h1 = F.relu(self.fc1(x))
13         return self.fc21(h1), self.fc22(h1)
14 
15     def reparameterize(self, mu, logvar):
16         std = torch.exp(0.5*logvar)
17         eps = torch.randn_like(std)
18         return mu + eps*std
19 
20     def decode(self, z):
21         h3 = F.relu(self.fc3(z))
22         return torch.sigmoid(self.fc4(h3))
23 
24     def forward(self, x):
25         mu, logvar = self.encode(x.view(-1, 784))
26         z = self.reparameterize(mu, logvar)
27         return self.decode(z), mu, logvar
28     
29     def loss_function_original(recon_x, x, mu, logvar):
30         BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
31         KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
32         return BCE + KLD

 

CVAE（条件变分自编码器，Conditional VAE）相关参考资料：

https://www.cnblogs.com/amazingter/p/14696251.html

https://www.cnblogs.com/boyknight/p/16290582.html

https://baileyswu.github.io/2019/11/disentangling-disentanglement-in-vae/

https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/116246208

posted @ 2022-09-25 13:04  zxcayumi  阅读(200)  评论(0)  编辑  收藏  举报