def sample(self, return_pretanh_value=False):
    """
    Draw a tanh-squashed sample without reparameterization.

    Gradients will and should *not* pass through this operation.
    See https://github.com/pytorch/pytorch/issues/4620 for discussion.

    :param return_pretanh_value: if True, also return the pre-tanh sample.
    :return: ``tanh(z)``, or the pair ``(tanh(z), z)`` when
        ``return_pretanh_value`` is True.
    """
    # Detach explicitly so the pre-tanh value is cut from any autograd graph.
    pre_tanh = self.normal.sample().detach()
    squashed = torch.tanh(pre_tanh)
    if not return_pretanh_value:
        return squashed
    return squashed, pre_tanh
def rsample(self, return_pretanh_value=False):
    """
    Sampling in the reparameterization case.

    Computes the pre-tanh value ``z = normal_mean + normal_std * eps``,
    where ``eps`` is drawn from a standard normal with no gradient. When
    ``normal_mean`` / ``normal_std`` require grad, gradients therefore flow
    back to them through ``z``.

    :param return_pretanh_value: if True, also return the pre-tanh value ``z``.
    :return: ``tanh(z)``, or the pair ``(tanh(z), z)`` when
        ``return_pretanh_value`` is True.
    """
    mean = self.normal_mean
    # Build the noise in the parameters' dtype (and, as before, mean's
    # device). Without an explicit dtype, torch.zeros/ones use the global
    # default (usually float32), which truncates the noise for float64
    # parameters and up-casts z for half-precision ones.
    eps = Normal(
        torch.zeros(mean.size(), dtype=mean.dtype, device=mean.device),
        torch.ones(self.normal_std.size(), dtype=mean.dtype, device=mean.device),
    ).sample()
    z = mean + self.normal_std * eps
    # Kept for backward compatibility. This is a no-op when z already
    # requires grad. Otherwise it marks z as a leaf tensor that requires grad.
    z.requires_grad_()
    squashed = torch.tanh(z)
    if return_pretanh_value:
        return squashed, z
    return squashed
# from: offlinerl/neorl