Source code for nnabla_rl.distributions.squashed_gaussian

# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import nnabla as nn
import nnabla.functions as NF
import nnabla_rl.functions as RF
from nnabla_rl.distributions import common_utils
from nnabla_rl.distributions.distribution import ContinuosDistribution


[docs]class SquashedGaussian(ContinuosDistribution): """Gaussian distribution which its output is squashed with tanh. :math:`z \\sim \\mathcal{N}(\\mu,\\,\\sigma^{2})`. :math:`out = \\tanh{z}`. Args: mean (nn.Variable): mean :math:`\\mu` of underlying gaussian distribution. ln_var (nn.Variable): logarithm of the variance :math:`\\sigma^{2}`. (i.e. ln_var is :math:`\\log{\\sigma^{2}}`) Note: The log probability and kl_divergence of this distribution is different from :py:class:`Gaussian distribution <nnabla_rl.distributions.Gaussian>` because the output is squashed. """ def __init__(self, mean, ln_var): super(SquashedGaussian, self).__init__() if not isinstance(mean, nn.Variable): mean = nn.Variable.from_numpy_array(mean) if not isinstance(ln_var, nn.Variable): ln_var = nn.Variable.from_numpy_array(ln_var) self._mean = mean self._var = NF.exp(ln_var) self._ln_var = ln_var self._ndim = mean.shape[-1] @property def ndim(self): return self._ndim
[docs] def sample(self, noise_clip=None): x = RF.sample_gaussian(mean=self._mean, ln_var=self._ln_var, noise_clip=noise_clip) return NF.tanh(x)
[docs] def sample_multiple(self, num_samples, noise_clip=None): x = RF.sample_gaussian_multiple(self._mean, self._ln_var, num_samples, noise_clip=noise_clip) return NF.tanh(x)
[docs] def sample_and_compute_log_prob(self, noise_clip=None): """NOTE: In order to avoid sampling different random values for sample and log_prob, you'll need to use nnabla.forward_all(sample, log_prob) If you forward the two variables independently, you'll get a log_prob for different sample, since different random variables are sampled internally.""" x = RF.sample_gaussian( mean=self._mean, ln_var=self._ln_var, noise_clip=noise_clip) log_prob = self._log_prob_internal( x, self._mean, self._var, self._ln_var) return NF.tanh(x), log_prob
[docs] def sample_multiple_and_compute_log_prob(self, num_samples, noise_clip=None): """NOTE: In order to avoid sampling different random values for sample and log_prob, you'll need to use nnabla.forward_all(sample, log_prob) If you forward the two variables independently, you'll get a log_prob for different sample, since different random variables are sampled internally.""" x = RF.sample_gaussian_multiple(self._mean, self._ln_var, num_samples=num_samples, noise_clip=noise_clip) mean = RF.expand_dims(self._mean, axis=1) var = RF.expand_dims(self._var, axis=1) ln_var = RF.expand_dims(self._ln_var, axis=1) assert mean.shape == (x.shape[0], 1, x.shape[-1]) assert var.shape == mean.shape assert ln_var.shape == mean.shape log_prob = self._log_prob_internal(x, mean, var, ln_var) return NF.tanh(x), log_prob
[docs] def choose_probable(self): return NF.tanh(self._mean)
[docs] def log_prob(self, x): x = NF.atanh(x) return self._log_prob_internal(x, self._mean, self._var, self._ln_var)
def _log_prob_internal(self, x, mean, var, ln_var): axis = len(x.shape) - 1 gaussian_part = common_utils.gaussian_log_prob(x, mean, var, ln_var) adjust_part = NF.sum(self._log_determinant_jacobian(x), axis=axis, keepdims=True) return gaussian_part - adjust_part def _log_determinant_jacobian(self, x): # arctanh(y)' = 1/(1 - y^2) (y=tanh(x)) # Below computes log(1 - tanh(x)^2) # For derivation see: # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py return 2.0 * (np.log(2.0) - x - NF.softplus(-2.0 * x))