sample(): draws a random sample from the probability distribution. We cannot backpropagate through the result, not because it is random as such, but because the draw is not recorded in the computation graph (the graph is cut off at the sample).
See the source code of sample in torch.distributions.normal.Normal:
def sample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    with torch.no_grad():
        return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
torch.normal returns a tensor of random numbers. Also, the torch.no_grad() context prevents the computation graph from growing any further.
So we cannot backprop: the tensor returned by sample() contains just some numbers, with no computational graph attached to it.
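To see this in practice, here is a small check (a sketch; the values of loc and scale are arbitrary choices for illustration). Even though both parameters require gradients, the tensor returned by sample() does not:

```python
import torch
from torch.distributions import Normal

# Arbitrary illustrative parameters that we would like gradients for.
loc = torch.tensor(0.0, requires_grad=True)
scale = torch.tensor(1.0, requires_grad=True)

x = Normal(loc, scale).sample()

print(x.requires_grad)  # False
print(x.grad_fn)        # None

# Calling backward on a tensor without a graph raises a RuntimeError.
try:
    x.backward()
except RuntimeError as err:
    print("backward failed:", err)
```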
So, what is rsample()?
By using rsample, we can backpropagate, because it keeps the computation graph alive.
How? By putting the randomness aside in a separate variable. This is called the “reparameterization trick”.
rsample: sampling using the reparameterization trick.
There is eps in the source code:
def rsample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
    return self.loc + eps * self.scale
# `self.loc` is the mean and `self.scale` is the standard deviation.
eps is the separate variable responsible for the randomness of the sampling. It is drawn from a standard normal distribution (mean 0, standard deviation 1).
Look at the return: mean + eps * standard deviation
eps does not depend on the parameters you want to differentiate with respect to.
So, now you can freely backpropagate (i.e. differentiate), because eps does not change when the parameters change.
(If we change the parameters, the distribution of the reparameterized samples does change because self.loc and self.scale change, but the distribution of the eps does not change.)
Note that the randomness of the sampling comes entirely from the random draw of eps. There is no randomness in the computation graph itself: once eps is drawn, its values are fixed, and the graph treats them as constants.
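A minimal sketch of this, with arbitrary values for loc and scale. The first part checks that rsample() returns a tensor that is still attached to the graph. The second part redoes the same computation by hand (using torch.randn for eps, which is my stand-in for the internal _standard_normal helper) so that eps is visible and the gradients can be compared with it:

```python
import torch
from torch.distributions import Normal

loc = torch.tensor(1.0, requires_grad=True)    # mean
scale = torch.tensor(2.0, requires_grad=True)  # standard deviation

# rsample() keeps the graph alive: the result has a grad_fn.
z = Normal(loc, scale).rsample()
print(z.requires_grad, z.grad_fn is not None)

# The same idea written out by hand, with eps visible.
eps = torch.randn(())          # drawn once; does not depend on loc or scale
z_manual = loc + eps * scale   # mean + eps * standard deviation
z_manual.backward()

print(loc.grad)    # derivative of z_manual w.r.t. loc, which is 1
print(scale.grad)  # derivative of z_manual w.r.t. scale, which is eps
print(eps)
```

Because eps is just a constant in the graph, the gradient with respect to scale is exactly the eps that was drawn.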
For example, in an implementation of the SAC (Soft Actor-Critic) algorithm in reinforcement learning, eps may hold one element per action component for a single minibatch of actions (and one action may consist of several components).
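As a rough illustration of those shapes (not a SAC implementation): the batch size, action dimension, and the stand-in "critic" below are all hypothetical, and mean and log_std are plain tensors standing in for what a policy network would output.

```python
import torch
from torch.distributions import Normal

batch_size, action_dim = 4, 2  # illustrative sizes only

# Hypothetical stand-ins for a policy network's outputs.
mean = torch.zeros(batch_size, action_dim, requires_grad=True)
log_std = torch.zeros(batch_size, action_dim, requires_grad=True)

dist = Normal(mean, log_std.exp())
actions = dist.rsample()  # shape (4, 2), so eps has 4 * 2 = 8 elements
print(actions.shape)

# Hypothetical stand-in for a critic; a real one would be a learned network.
fake_q = -(actions ** 2).sum(dim=-1)
loss = -fake_q.mean()
loss.backward()

# Gradients reach the distribution parameters through the sampled actions.
print(mean.grad.shape, log_std.grad.shape)
```

With sample() instead of rsample(), the backward call here would fail, because the loss would no longer be connected to mean or log_std.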