Reparameterized vs. non-reparameterized sampling (rsample vs. sample)

The two methods below belong to a tanh-squashed Normal distribution class: sample() returns a value that gradients cannot (and should not) flow through, while rsample() uses the reparameterization trick so that gradients flow back into the mean and standard deviation.

def sample(self, return_pretanh_value=False):
    """
    Gradients will and should *not* pass through this operation.

    See https://github.com/pytorch/pytorch/issues/4620 for discussion.
    """
    z = self.normal.sample().detach()

    if return_pretanh_value:
        return torch.tanh(z), z
    else:
        return torch.tanh(z)
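
For intuition, here is a minimal standalone check (plain PyTorch, not part of the original class) showing why sample() blocks gradients: the draw happens outside the autograd graph, so the result does not depend on the distribution parameters as far as backpropagation is concerned.

import torch
from torch.distributions import Normal

mean = torch.zeros(3, requires_grad=True)
std = torch.ones(3, requires_grad=True)

# sample() draws from N(mean, std) but cuts the graph: the squashed draw
# carries no gradient back to mean or std.
a = torch.tanh(Normal(mean, std).sample())
print(a.requires_grad)  # False, so a loss on `a` cannot update mean/std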

def rsample(self, return_pretanh_value=False):
    """
    Sampling in the reparameterization case.
    """
    # Reparameterization trick: z = mean + std * eps, where eps ~ N(0, I)
    # is drawn outside the graph, so z stays differentiable with respect
    # to normal_mean and normal_std.
    z = (
        self.normal_mean +
        self.normal_std *
        Normal(
            torch.zeros(self.normal_mean.size(), device=self.normal_mean.device),
            torch.ones(self.normal_std.size(), device=self.normal_mean.device)
        ).sample()
    )
    z.requires_grad_()

    if return_pretanh_value:
        return torch.tanh(z), z
    else:
        return torch.tanh(z)
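
As a contrast, a minimal sketch of the reparameterization trick in plain PyTorch (standalone, not using the class above): the noise eps is sampled once, and the output is an explicit differentiable function of mean and std, so gradients reach both parameters.

import torch
from torch.distributions import Normal

mean = torch.zeros(3, requires_grad=True)
std = torch.ones(3, requires_grad=True)

# z = mean + std * eps, with eps ~ N(0, I) sampled outside the graph,
# so tanh(z) is differentiable with respect to mean and std.
eps = Normal(torch.zeros(3), torch.ones(3)).sample()
z = mean + std * eps
loss = torch.tanh(z).sum()
loss.backward()
print(mean.grad, std.grad)  # both gradients are populated

Note that torch.distributions.Normal also exposes a built-in rsample() that performs this same mean + std * eps expansion.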


Source: offlinerl/neorl