# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
import nnabla as nn
import nnabla.functions as NF
import nnabla_rl.functions as RF
from nnabla_rl.models.model import Model
from nnabla_rl.models.q_function import QFunction
class ValueDistributionFunction(Model, metaclass=ABCMeta):
    """Base value distribution class.

    Computes the probabilities of q-value for each action.

    Value distribution function models the probabilities of q value for each action by dividing
    the values between the maximum q value and minimum q value into 'n_atom' number of bins and
    assigning the probability to each bin.

    Args:
        scope_name (str): scope name of the model
        n_action (int): Number of actions which used in target environment.
        n_atom (int): Number of bins.
        v_min (int): Minimum value of the distribution.
        v_max (int): Maximum value of the distribution.
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _n_action: int
    _n_atom: int
    _v_min: float
    _v_max: float
    _z: nn.Variable

    def __init__(self, scope_name: str, n_action: int, n_atom: int, v_min: float, v_max: float):
        super(ValueDistributionFunction, self).__init__(scope_name)
        self._n_action = n_action
        self._n_atom = n_atom
        self._v_min = v_min
        self._v_max = v_max
        # Precompute the atom values once; mark persistent so the values are kept on the graph.
        self._z = self._compute_z(n_atom, v_min, v_max)
        self._z.persistent = True

    def __deepcopy__(self, memodict: Optional[Dict[Any, Any]] = None):
        # nn.Variable cannot be deepcopied directly, so rebuild a fresh instance instead.
        # NOTE: memodict previously defaulted to a mutable `{}`; use None to avoid the
        # shared-mutable-default pitfall (the argument is not used here anyway).
        return self.__class__(self._scope_name, self._n_action, self._n_atom, self._v_min, self._v_max)

    @abstractmethod
    def probs(self, s: nn.Variable, a: nn.Variable) -> nn.Variable:
        """Compute probabilities of atoms for given state and action.

        Args:
            s (nn.Variable): state variable
            a (nn.Variable): action variable

        Returns:
            nn.Variable: probabilities of atoms for given state and action
        """
        raise NotImplementedError

    def all_probs(self, s: nn.Variable) -> nn.Variable:
        """Compute probabilities of atoms for all possible actions for given
        state.

        Args:
            s (nn.Variable): state variable

        Returns:
            nn.Variable: probabilities of atoms for all possible actions for given state
        """
        raise NotImplementedError

    def max_q_probs(self, s: nn.Variable) -> nn.Variable:
        """Compute probabilities of atoms for given state that maximizes the
        q_value.

        Args:
            s (nn.Variable): state variable

        Returns:
            nn.Variable: probabilities of atoms for given state that maximizes the q_value
        """
        raise NotImplementedError

    def as_q_function(self) -> QFunction:
        """Convert the value distribution function to QFunction.

        Returns:
            nnabla_rl.models.q_function.QFunction:
                QFunction instance which computes the q-values based on the probabilities.
        """
        raise NotImplementedError

    def _compute_z(self, n_atom: int, v_min: float, v_max: float) -> nn.Variable:
        # Atoms are evenly spaced over [v_min, v_max] inclusive, hence the (n_atom - 1) divisor.
        delta_z = (v_max - v_min) / (n_atom - 1)
        return v_min + delta_z * NF.arange(0, n_atom)
class DiscreteValueDistributionFunction(ValueDistributionFunction):
    """Base value distribution class for discrete action envs."""

    @abstractmethod
    def all_probs(self, s: nn.Variable) -> nn.Variable:
        raise NotImplementedError

    def probs(self, s: nn.Variable, a: nn.Variable) -> nn.Variable:
        probs = self.all_probs(s)
        return self._probabilities_of(probs, a)

    def max_q_probs(self, s: nn.Variable) -> nn.Variable:
        probs = self.all_probs(s)
        a_star = self._argmax_q_from_probabilities(probs)
        return self._probabilities_of(probs, a_star)

    def as_q_function(self) -> QFunction:
        """Convert the value distribution function to QFunction.

        Returns:
            nnabla_rl.models.q_function.QFunction:
                QFunction instance which computes the q-values based on the probabilities.
        """

        class Wrapper(QFunction):
            _value_distribution_function: "DiscreteValueDistributionFunction"

            def __init__(self, value_distribution_function: "DiscreteValueDistributionFunction"):
                super(Wrapper, self).__init__(value_distribution_function.scope_name)
                self._value_distribution_function = value_distribution_function

            def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable:
                q_values = self._value_distribution_function._state_to_q_values(s)
                one_hot = NF.one_hot(NF.reshape(a, (-1, 1), inplace=False), (q_values.shape[1],))
                q_value = NF.sum(q_values * one_hot, axis=1, keepdims=True)  # get q value of a
                return q_value

            def max_q(self, s: nn.Variable) -> nn.Variable:
                q_values = self._value_distribution_function._state_to_q_values(s)
                return NF.max(q_values, axis=1, keepdims=True)

            def argmax_q(self, s: nn.Variable) -> nn.Variable:
                probabilities = self._value_distribution_function.all_probs(s)
                greedy_action = self._value_distribution_function._argmax_q_from_probabilities(probabilities)
                return greedy_action

            def is_recurrent(self) -> bool:
                return self._value_distribution_function.is_recurrent()

            def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]:
                return self._value_distribution_function.internal_state_shapes()

            def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None):
                return self._value_distribution_function.set_internal_states(states)

            def get_internal_states(self) -> Dict[str, nn.Variable]:
                return self._value_distribution_function.get_internal_states()

        return Wrapper(self)

    def _argmax_q_from_probabilities(self, atom_probabilities: nn.Variable) -> nn.Variable:
        # Greedy action = argmax over the expected values computed from the atom probabilities.
        q_values = self._probabilities_to_q_values(atom_probabilities)
        return RF.argmax(q_values, axis=1, keepdims=True)

    def _state_to_q_values(self, s: nn.Variable) -> nn.Variable:
        probabilities = self.all_probs(s)
        return self._probabilities_to_q_values(probabilities)

    def _probabilities_of(self, probabilities: nn.Variable, a: nn.Variable) -> nn.Variable:
        batch_size = probabilities.shape[0]
        # (batch, n_action, n_atom) -> (batch, n_atom, n_action) so that the one-hot action
        # mask can select the per-atom probabilities of the chosen action along the last axis.
        probabilities = NF.transpose(probabilities, axes=(0, 2, 1))
        one_hot = self._to_one_hot(a)
        probabilities = probabilities * one_hot
        probabilities = NF.sum(probabilities, axis=2)
        # Shape check for consistency with DiscreteQuantileDistributionFunction._quantiles_of
        assert probabilities.shape == (batch_size, self._n_atom)
        return probabilities

    def _probabilities_to_q_values(self, atom_probabilities: nn.Variable) -> nn.Variable:
        batch_size = atom_probabilities.shape[0]
        assert atom_probabilities.shape == (batch_size, self._n_action, self._n_atom)
        # Expected value: sum_i z_i * p_i, broadcasting the atom values over batch and action.
        z = RF.expand_dims(self._z, axis=0)
        z = RF.expand_dims(z, axis=1)
        z = NF.broadcast(z, shape=(batch_size, self._n_action, self._n_atom))
        q_values = NF.sum(z * atom_probabilities, axis=2)
        assert q_values.shape == (batch_size, self._n_action)
        return q_values

    def _to_one_hot(self, a: nn.Variable) -> nn.Variable:
        batch_size = a.shape[0]
        a = NF.reshape(a, (-1, 1))
        assert a.shape[0] == batch_size
        one_hot = NF.one_hot(a, (self._n_action,))
        # Broadcast the (batch, n_action) one-hot over the atom axis.
        one_hot = RF.expand_dims(one_hot, axis=1)
        one_hot = NF.broadcast(one_hot, shape=(batch_size, self._n_atom, self._n_action))
        return one_hot
class ContinuousValueDistributionFunction(ValueDistributionFunction):
    """Base value distribution class for continuous action envs."""
    # No discrete-action helpers apply here; subclasses implement the abstract interface.
    pass
class QuantileDistributionFunction(Model, metaclass=ABCMeta):
    """Base quantile distribution class.

    Computes the quantiles of q-value for each action.

    Quantile distribution function models the quantiles of q value for each action by dividing
    the probability (which is between 0.0 and 1.0) into 'n_quantile' number of bins and
    assigning the n-quantile to n-th bin.

    Args:
        scope_name (str): scope name of the model
        n_quantile (int): Number of bins.
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _n_quantile: int
    _qj: float

    def __init__(self, scope_name: str, n_quantile: int):
        super(QuantileDistributionFunction, self).__init__(scope_name)
        self._n_quantile = n_quantile
        # Each quantile bin carries an equal probability mass of 1 / n_quantile.
        self._qj = 1 / n_quantile

    def all_quantiles(self, s: nn.Variable) -> nn.Variable:
        """Computes the quantiles of q-value for each action for the given
        state.

        Args:
            s (nn.Variable): state variable

        Returns:
            nn.Variable: quantiles of q-value for each action for the given state
        """
        raise NotImplementedError

    def quantiles(self, s: nn.Variable, a: nn.Variable) -> nn.Variable:
        """Computes the quantiles of q-value for given state and action.

        Args:
            s (nn.Variable): state variable
            a (nn.Variable): action variable

        Returns:
            nn.Variable: quantiles of q-value for given state and action.
        """
        raise NotImplementedError

    def max_q_quantiles(self, s: nn.Variable) -> nn.Variable:
        """Compute the quantiles of q-value for given state that maximizes the
        q_value.

        Args:
            s (nn.Variable): state variable

        Returns:
            nn.Variable: quantiles of q-value for given state that maximizes the q_value
        """
        raise NotImplementedError

    def as_q_function(self) -> QFunction:
        """Convert the quantile distribution function to QFunction.

        Returns:
            nnabla_rl.models.q_function.QFunction:
                QFunction instance which computes the q-values based on the quantiles.
        """
        raise NotImplementedError
class DiscreteQuantileDistributionFunction(QuantileDistributionFunction):
    """Base quantile distribution class for discrete action envs."""

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _n_action: int

    def __init__(self, scope_name: str, n_action: int, n_quantile: int):
        super().__init__(scope_name, n_quantile)
        self._n_action = n_action

    @abstractmethod
    def all_quantiles(self, s: nn.Variable) -> nn.Variable:
        raise NotImplementedError

    def quantiles(self, s: nn.Variable, a: nn.Variable) -> nn.Variable:
        all_action_quantiles = self.all_quantiles(s)
        return self._quantiles_of(all_action_quantiles, a)

    def max_q_quantiles(self, s: nn.Variable) -> nn.Variable:
        all_action_quantiles = self.all_quantiles(s)
        greedy_action = self._argmax_q_from_quantiles(all_action_quantiles)
        return self._quantiles_of(all_action_quantiles, greedy_action)

    def as_q_function(self) -> QFunction:
        """Convert the quantile distribution function to QFunction.

        Returns:
            nnabla_rl.models.q_function.QFunction:
                QFunction instance which computes the q-values based on the quantiles.
        """

        class Wrapper(QFunction):
            _quantile_distribution_function: "DiscreteQuantileDistributionFunction"

            def __init__(self, distribution_function: "DiscreteQuantileDistributionFunction"):
                super(Wrapper, self).__init__(distribution_function.scope_name)
                self._quantile_distribution_function = distribution_function

            def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable:
                q_values = self._quantile_distribution_function._state_to_q_values(s)
                action_mask = NF.one_hot(NF.reshape(a, (-1, 1), inplace=False), (q_values.shape[1],))
                # Mask out every action except a, then reduce to the selected q value.
                return NF.sum(q_values * action_mask, axis=1, keepdims=True)

            def max_q(self, s: nn.Variable) -> nn.Variable:
                q_values = self._quantile_distribution_function._state_to_q_values(s)
                return NF.max(q_values, axis=1, keepdims=True)

            def argmax_q(self, s: nn.Variable) -> nn.Variable:
                all_action_quantiles = self._quantile_distribution_function.all_quantiles(s)
                return self._quantile_distribution_function._argmax_q_from_quantiles(all_action_quantiles)

            def is_recurrent(self) -> bool:
                return self._quantile_distribution_function.is_recurrent()

            def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]:
                return self._quantile_distribution_function.internal_state_shapes()

            def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None):
                return self._quantile_distribution_function.set_internal_states(states)

            def get_internal_states(self) -> Dict[str, nn.Variable]:
                return self._quantile_distribution_function.get_internal_states()

        return Wrapper(self)

    def _argmax_q_from_quantiles(self, quantiles: nn.Variable) -> nn.Variable:
        # Greedy action = argmax over expected values computed from the quantiles.
        return RF.argmax(self._quantiles_to_q_values(quantiles), axis=1, keepdims=True)

    def _quantiles_to_q_values(self, quantiles: nn.Variable) -> nn.Variable:
        # Expected value: each quantile carries equal probability mass _qj.
        return NF.sum(quantiles * self._qj, axis=2)

    def _state_to_q_values(self, s: nn.Variable) -> nn.Variable:
        return self._quantiles_to_q_values(self.all_quantiles(s))

    def _quantiles_of(self, quantiles: nn.Variable, a: nn.Variable) -> nn.Variable:
        batch_size = quantiles.shape[0]
        # (batch, n_action, n_quantile) -> (batch, n_quantile, n_action) so the one-hot
        # action mask selects the chosen action's quantiles along the last axis.
        transposed = NF.transpose(quantiles, axes=(0, 2, 1))
        masked = transposed * self._to_one_hot(a)
        selected = NF.sum(masked, axis=2)
        assert selected.shape == (batch_size, self._n_quantile)
        return selected

    def _to_one_hot(self, a: nn.Variable) -> nn.Variable:
        batch_size = a.shape[0]
        a = NF.reshape(a, (-1, 1))
        assert a.shape[0] == batch_size
        # Broadcast the (batch, n_action) one-hot over the quantile axis.
        one_hot = NF.one_hot(a, (self._n_action,))
        one_hot = RF.expand_dims(one_hot, axis=1)
        return NF.broadcast(one_hot, shape=(batch_size, self._n_quantile, self._n_action))
class ContinuousQuantileDistributionFunction(QuantileDistributionFunction):
    """Base quantile distribution class for continuous action envs."""

    def as_q_function(self) -> QFunction:
        """Convert the quantile distribution function to QFunction.

        Returns:
            nnabla_rl.models.q_function.QFunction:
                QFunction instance which computes the q-values based on the quantiles.
        """

        class Wrapper(QFunction):
            _quantile_distribution_function: "ContinuousQuantileDistributionFunction"

            def __init__(self, distribution_function: "ContinuousQuantileDistributionFunction"):
                super(Wrapper, self).__init__(distribution_function.scope_name)
                self._quantile_distribution_function = distribution_function

            def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable:
                action_quantiles = self._quantile_distribution_function.quantiles(s, a)
                # The q value is the mean of the quantiles along the last axis.
                last_axis = len(action_quantiles.shape) - 1
                return NF.mean(action_quantiles, axis=last_axis, keepdims=True)

            def max_q(self, s: nn.Variable) -> nn.Variable:
                raise NotImplementedError

            def argmax_q(self, s: nn.Variable) -> nn.Variable:
                raise NotImplementedError

            def is_recurrent(self) -> bool:
                return self._quantile_distribution_function.is_recurrent()

            def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]:
                return self._quantile_distribution_function.internal_state_shapes()

            def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None):
                return self._quantile_distribution_function.set_internal_states(states)

            def get_internal_states(self) -> Dict[str, nn.Variable]:
                return self._quantile_distribution_function.get_internal_states()

        return Wrapper(self)
def risk_neutral_measure(tau: nn.Variable) -> nn.Variable:
    """Risk neutral measure: return the quantile thresholds unchanged (identity)."""
    return tau
class StateActionQuantileFunction(Model, metaclass=ABCMeta):
    """State-action quantile function class.

    Computes the return samples of q-value for each action.

    State-action quantile function computes the return samples of q value for each action
    using sampled quantile threshold (e.g. :math:`\\tau\\sim U([0,1])`) for given state.

    Args:
        scope_name (str): scope name of the model
        n_action (int): Number of actions which used in target environment.
        K (int): Number of samples for quantile threshold :math:`\\tau`.
        risk_measure_function (Callable[[nn.Variable], nn.Variable]): Risk measure function which
            modifies the weightings of tau. Defaults to risk neutral measure which does not do any change to the taus.
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _n_action: int
    _K: int
    _risk_measure_function: Callable[[nn.Variable], nn.Variable]

    def __init__(
        self,
        scope_name: str,
        n_action: int,
        K: int,
        risk_measure_function: Callable[[nn.Variable], nn.Variable] = risk_neutral_measure,
    ):
        super(StateActionQuantileFunction, self).__init__(scope_name)
        self._n_action = n_action
        self._K = K
        self._risk_measure_function = risk_measure_function

    def all_quantile_values(self, s: nn.Variable, tau: nn.Variable) -> nn.Variable:
        """Compute the return samples for all action for given state and
        quantile threshold.

        Args:
            s (nn.Variable): state variable.
            tau (nn.Variable): quantile threshold.

        Returns:
            nn.Variable: return samples from implicit return distribution for given state using tau.
        """
        # NOTE: raise instead of silently returning None (was `pass`) for consistency
        # with the other base classes in this module.
        raise NotImplementedError

    def quantile_values(self, s: nn.Variable, a: nn.Variable, tau: nn.Variable) -> nn.Variable:
        """Compute the return samples for given state and action.

        Args:
            s (nn.Variable): state variable.
            a (nn.Variable): action variable.
            tau (nn.Variable): quantile threshold.

        Returns:
            nn.Variable: return samples from implicit return distribution for given state and action using tau.
        """
        raise NotImplementedError

    def max_q_quantile_values(self, s: nn.Variable, tau: nn.Variable) -> nn.Variable:
        """Compute the return samples from distribution that maximizes q value
        for given state using quantile threshold.

        Args:
            s (nn.Variable): state variable.
            tau (nn.Variable): quantile threshold.

        Returns:
            nn.Variable: return samples from implicit return distribution that maximizes q for given state using tau.
        """
        raise NotImplementedError

    def sample_tau(self, shape: Optional[Iterable] = None) -> nn.Variable:
        """Sample quantile thresholds from uniform distribution.

        Args:
            shape (Tuple[int] or None): shape of the quantile threshold to sample. If None the shape will be (1, K).

        Returns:
            nn.Variable: quantile thresholds
        """
        if shape is None:
            shape = (1, self._K)
        return NF.rand(low=0.0, high=1.0, shape=shape)

    def as_q_function(self) -> QFunction:
        """Convert the state action quantile function to QFunction.

        Returns:
            nnabla_rl.models.q_function.QFunction:
                QFunction instance which computes the q-values based on return samples.
        """
        raise NotImplementedError

    def _sample_risk_measured_tau(self, shape: Optional[Iterable]) -> nn.Variable:
        # Sample uniform thresholds and reweight them through the configured risk measure.
        tau = self.sample_tau(shape)
        return self._risk_measure_function(tau)
class DiscreteStateActionQuantileFunction(StateActionQuantileFunction):
    """Base state-action quantile function class for discrete action envs."""

    @abstractmethod
    def all_quantile_values(self, s: nn.Variable, tau: nn.Variable) -> nn.Variable:
        raise NotImplementedError

    def quantile_values(self, s: nn.Variable, a: nn.Variable, tau: nn.Variable) -> nn.Variable:
        return_samples = self.all_quantile_values(s, tau)
        return self._return_samples_of(return_samples, a)

    def max_q_quantile_values(self, s: nn.Variable, tau: nn.Variable) -> nn.Variable:
        if self.is_recurrent():
            raise RuntimeError("max_q_quantile_values should be reimplemented in inherited class to support RNN layers")
        batch_size = s.shape[0]
        tau_k = self._sample_risk_measured_tau(shape=(batch_size, self._K))
        # This implementation does not support RNNs because internal state will be overwritten by
        # the second call of all_quantile_values()
        _return_samples = self.all_quantile_values(s, tau_k)
        a_star = self._argmax_q_from_return_samples(_return_samples)
        # This will overwrite the internal state. So may not properly trained if this is called during training.
        return_samples = self.all_quantile_values(s, tau)
        return self._return_samples_of(return_samples, a_star)

    def as_q_function(self) -> QFunction:
        """Convert the state action quantile function to QFunction.

        Returns:
            nnabla_rl.models.q_function.QFunction:
                QFunction instance which computes the q-values based on the return_samples.
        """

        class Wrapper(QFunction):
            _quantile_function: "DiscreteStateActionQuantileFunction"

            def __init__(self, quantile_function: "DiscreteStateActionQuantileFunction"):
                super(Wrapper, self).__init__(quantile_function.scope_name)
                self._quantile_function = quantile_function

            def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable:
                q_values = self.all_q(s)
                one_hot = NF.one_hot(NF.reshape(a, (-1, 1), inplace=False), (q_values.shape[1],))
                q_value = NF.sum(q_values * one_hot, axis=1, keepdims=True)  # get q value of a
                return q_value

            def all_q(self, s: nn.Variable) -> nn.Variable:
                return self._quantile_function._state_to_q_values(s)

            def max_q(self, s: nn.Variable) -> nn.Variable:
                q_values = self._quantile_function._state_to_q_values(s)
                return NF.max(q_values, axis=1, keepdims=True)

            def argmax_q(self, s: nn.Variable) -> nn.Variable:
                batch_size = s.shape[0]
                tau = self._quantile_function._sample_risk_measured_tau(shape=(batch_size, self._quantile_function._K))
                samples = self._quantile_function.all_quantile_values(s, tau)
                greedy_action = self._quantile_function._argmax_q_from_return_samples(samples)
                return greedy_action

            def is_recurrent(self) -> bool:
                return self._quantile_function.is_recurrent()

            def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]:
                return self._quantile_function.internal_state_shapes()

            def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None):
                return self._quantile_function.set_internal_states(states)

            def get_internal_states(self) -> Dict[str, nn.Variable]:
                return self._quantile_function.get_internal_states()

        return Wrapper(self)

    def _return_samples_to_q_values(self, return_samples: nn.Variable) -> nn.Variable:
        """Compute the q values for each action for given return samples.

        Args:
            return_samples (nn.Variable): return samples.

        Returns:
            nn.Variable: q values for each action for given return samples.
        """
        # Average the return samples over the threshold axis to obtain expected values.
        samples = NF.transpose(return_samples, axes=(0, 2, 1))
        q_values = NF.mean(samples, axis=2)
        return q_values

    def _argmax_q_from_return_samples(self, return_samples: nn.Variable) -> nn.Variable:
        """Compute the action which maximizes the q value computed from given
        return samples.

        Args:
            return_samples (nn.Variable): return samples.

        Returns:
            nn.Variable: action which maximizes the q value for given return samples.
        """
        q_values = self._return_samples_to_q_values(return_samples)
        return RF.argmax(q_values, axis=1, keepdims=True)

    def _state_to_q_values(self, s: nn.Variable) -> nn.Variable:
        tau = self._sample_risk_measured_tau(shape=(1, self._K))
        samples = self.all_quantile_values(s, tau)
        return self._return_samples_to_q_values(samples)

    def _return_samples_of(self, return_samples: nn.Variable, a: nn.Variable) -> nn.Variable:
        # Mask all actions except a and reduce over the action axis.
        one_hot = self._to_one_hot(a, shape=return_samples.shape)
        samples = return_samples * one_hot
        samples = NF.sum(samples, axis=2)
        assert len(samples.shape) == 2
        return samples

    def _to_one_hot(self, a: nn.Variable, shape: Tuple[int, ...]) -> nn.Variable:
        # NOTE: `shape` was previously (incorrectly) annotated as nn.Variable; callers
        # pass a shape tuple (e.g. return_samples.shape).
        a = NF.reshape(a, (-1, 1))
        one_hot = NF.one_hot(a, (self._n_action,))
        one_hot = RF.expand_dims(one_hot, axis=1)
        one_hot = NF.broadcast(one_hot, shape=shape)
        return one_hot
class ContinuousStateActionQuantileFunction(StateActionQuantileFunction):
    """Base state-action quantile function class for continuous action envs."""
    # No discrete-action helpers apply here; subclasses implement the abstract interface.
    pass