Models¶
All models are derived from nnabla_rl.models.Model.
Model¶
- class nnabla_rl.models.model.Model(scope_name: str)[source]¶
Model Class.
- Parameters:
scope_name (str) – the scope name of the model
- deepcopy(new_scope_name: str) T [source]¶
Create a (deep) copy of the model. All the model parameters (if any exist) will be copied to the new model and new_scope_name will be assigned.
- Parameters:
new_scope_name (str) – scope_name of the parameters for the newly created model
- Returns:
copied model
- Return type:
Model
- Raises:
ValueError – if the given scope name is the same as this model's or already exists.
- get_internal_states() Dict[str, Variable] [source]¶
Get the internal state variables of the RNN cells. Models which use an LSTM, GRU, and/or any other recurrent network component must implement this method.
- Returns:
Value of each internal state. The key is the name of each internal state.
- Return type:
Dict[str, nn.Variable]
- get_parameters(grad_only: bool = True) Dict[str, Variable] [source]¶
Retrieve the parameters associated with this model.
- Parameters:
grad_only (bool) – Retrieve only the parameters with need_grad = True. Defaults to True.
- Returns:
Parameter map.
- Return type:
parameters (OrderedDict)
- internal_state_shapes() Dict[str, Tuple[int, ...]] [source]¶
Return the internal state shape as a tuple of ints for each internal state (excluding the batch size). This method is called by RNNModelTrainer (nnabla_rl.model_trainers.model_trainer.RNNModelTrainer) and its subclasses to set up training variables. Models which use an LSTM, GRU, and/or any other recurrent network component must implement this method.
- Returns:
Internal state shapes. The key is the name of each internal state.
- Return type:
Dict[str, Tuple[int, …]]
- is_recurrent() bool [source]¶
Check whether the model uses a recurrent network component or not.
Models which use an LSTM, GRU, and/or any other recurrent network component must return True.
- Returns:
True if the model uses a recurrent network component. Otherwise False.
- Return type:
bool
- load_parameters(filepath: str | Path) None [source]¶
Load model parameters from the given filepath.
- Parameters:
filepath (str or pathlib.Path) – parameter file path
- reset_internal_states()[source]¶
Reset the internal state variables of the RNN cells to zero.
- save_parameters(filepath: str | Path) None [source]¶
Save model parameters to the given filepath.
- Parameters:
filepath (str or pathlib.Path) – parameter file path
- property scope_name: str¶
Get the scope name of this model.
- Returns:
scope name of the model
- Return type:
scope_name (str)
- set_internal_states(states: Dict[str, Variable] | None = None)[source]¶
Set the internal state variables of the RNN cells to the given states.
Models which use an LSTM, GRU, and/or any other recurrent network component must implement this method.
- Parameters:
states (None or Dict[str, nn.Variable]) – If None, reset all internal states to zero. If states are provided, set them as the internal states.
- shallowcopy() T [source]¶
Create a (shallow) copy of the model. Unlike deepcopy, shallowcopy KEEPS sharing the original network parameters by using the same scope_name as the original model. However, all the class members will be (deep) copied to the new instance. Do NOT use this method unless you understand what it does.
- Returns:
(shallow) copied model
- Return type:
Model
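A minimal usage sketch follows. The MyValueFunction class, its layer sizes, and the file name are illustrative, not part of the library; by convention, subclasses build their network under the model's parameter scope so that get_parameters, save_parameters, and deepcopy can find the parameters.

```python
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

from nnabla_rl.models import Model


class MyValueFunction(Model):
    """Hypothetical model that outputs a scalar value for a state."""

    def v(self, s: nn.Variable) -> nn.Variable:
        # Building under self.scope_name lets get_parameters() and
        # save_parameters() locate this model's parameters.
        with nn.parameter_scope(self.scope_name):
            h = F.relu(PF.affine(s, 64, name="fc1"))
            return PF.affine(h, 1, name="fc2")


model = MyValueFunction(scope_name="value")
s = nn.Variable((1, 4))  # batch of one 4-dimensional state
v = model.v(s)           # builds the network and creates its parameters

target = model.deepcopy("target_value")  # independent parameter copy
model.save_parameters("value.h5")        # snapshot to disk
model.load_parameters("value.h5")        # restore later
```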
List of Models¶
- class nnabla_rl.models.Perturbator(scope_name)[source]¶
Bases:
Model
Abstract class for perturbators.
A perturbator generates noise to append to the current state's action.
- class nnabla_rl.models.DeterministicPolicy(scope_name: str)[source]¶
Bases:
Policy
Abstract class for deterministic policies.
This policy returns an action for the given state.
- class nnabla_rl.models.StochasticPolicy(scope_name: str)[source]¶
Bases:
Policy
Abstract class for stochastic policies.
This policy returns a probability distribution of actions for the given state.
- abstract pi(s: Variable) Distribution [source]¶
- Parameters:
s (nn.Variable) – state variable
- Returns:
Probability distribution of the action for the given state
- Return type:
Distribution
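A sketch of a concrete stochastic policy, assuming a Gaussian head: the network, its layer sizes, and the 2-dimensional action are hypothetical, and Gaussian is assumed to be nnabla_rl.distributions.Gaussian constructed from a mean and a log-variance.

```python
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

from nnabla_rl.distributions import Distribution, Gaussian
from nnabla_rl.models import StochasticPolicy


class MyGaussianPolicy(StochasticPolicy):
    """Hypothetical policy over a 2-dimensional continuous action."""

    def pi(self, s: nn.Variable) -> Distribution:
        with nn.parameter_scope(self.scope_name):
            h = F.relu(PF.affine(s, 64, name="fc1"))
            mean = PF.affine(h, 2, name="mean")
            ln_var = PF.affine(h, 2, name="ln_var")  # log-variance head
        return Gaussian(mean, ln_var)


policy = MyGaussianPolicy("policy")
distribution = policy.pi(nn.Variable((1, 4)))
action = distribution.sample()  # nn.Variable holding sampled actions
```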
- class nnabla_rl.models.QFunction(scope_name: str)[source]¶
Bases:
Model
Base QFunction Class.
- all_q(s: Variable) Variable [source]¶
Compute Q-values for each action for the given state.
- Parameters:
s (nn.Variable) – state variable
- Returns:
Q-values for each action for the given state
- Return type:
nn.Variable
- argmax_q(s: Variable) Variable [source]¶
Compute the action which maximizes the Q-value for the given state.
- Parameters:
s (nn.Variable) – state variable
- Returns:
action which maximizes the Q-value for the given state
- Return type:
nn.Variable
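For a discrete-action environment, a subclass can implement all_q as a network head with one output per action and derive argmax_q from it. A sketch; the layer sizes and the 3-action assumption are illustrative, not part of the library.

```python
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

from nnabla_rl.models import QFunction


class MyDiscreteQFunction(QFunction):
    """Hypothetical Q-function for an environment with 3 discrete actions."""

    def all_q(self, s: nn.Variable) -> nn.Variable:
        with nn.parameter_scope(self.scope_name):
            h = F.relu(PF.affine(s, 64, name="fc1"))
            return PF.affine(h, 3, name="q")  # shape: (batch_size, 3)

    def argmax_q(self, s: nn.Variable) -> nn.Variable:
        # Greedy action: index of the maximum Q-value per sample.
        return F.max(self.all_q(s), axis=1, keepdims=True, only_index=True)
```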
- class nnabla_rl.models.ValueDistributionFunction(scope_name: str, n_action: int, n_atom: int, v_min: float, v_max: float)[source]¶
Bases:
Model
Base value distribution class.
Computes the probabilities of the q-value for each action. A value distribution function models the probabilities of the q-value for each action by dividing the interval between the minimum and maximum q-values into n_atom bins and assigning a probability to each bin.
- Parameters:
scope_name (str) – scope name of the model
n_action (int) – Number of actions used in the target environment.
n_atom (int) – Number of bins.
v_min (float) – Minimum value of the distribution.
v_max (float) – Maximum value of the distribution.
- all_probs(s: Variable) Variable [source]¶
Compute the probabilities of the atoms for all possible actions for the given state.
- Parameters:
s (nn.Variable) – state variable
- Returns:
probabilities of the atoms for all possible actions for the given state
- Return type:
nn.Variable
- as_q_function() QFunction [source]¶
Convert the value distribution function to a QFunction.
- Returns:
QFunction instance which computes the q-values based on the probabilities.
- Return type:
QFunction
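Presumably the q-value behind as_q_function is the expectation of the atoms under each action's distribution, \(Q(s, a) = \sum_i p_i(s, a) z_i\). A NumPy sketch of that reduction with dummy values; all shapes and numbers are hypothetical.

```python
import numpy as np

v_min, v_max, n_action, n_atom = -10.0, 10.0, 4, 51
z = np.linspace(v_min, v_max, n_atom)              # atom positions (bin values)
probs = np.full((n_action, n_atom), 1.0 / n_atom)  # dummy uniform distribution
q_values = (probs * z).sum(axis=1)                 # Q(s, a) = sum_i p_i(s, a) * z_i
greedy_action = q_values.argmax()
```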
- class nnabla_rl.models.QuantileDistributionFunction(scope_name: str, n_quantile: int)[source]¶
Bases:
Model
Base quantile distribution class.
Computes the quantiles of the q-value for each action. A quantile distribution function models the quantiles of the q-value for each action by dividing the probability range (between 0.0 and 1.0) into n_quantile bins and assigning the n-th quantile to the n-th bin.
- Parameters:
scope_name (str) – scope name of the model
n_quantile (int) – Number of bins.
- all_quantiles(s: Variable) Variable [source]¶
Computes the quantiles of q-value for each action for the given state.
- Parameters:
s (nn.Variable) – state variable
- Returns:
quantiles of q-value for each action for the given state
- Return type:
nn.Variable
- as_q_function() QFunction [source]¶
Convert the quantile distribution function to a QFunction.
- Returns:
QFunction instance which computes the q-values based on the quantiles.
- Return type:
QFunction
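Here the q-value behind as_q_function is presumably the mean of the quantiles, \(Q(s, a) = \frac{1}{N}\sum_{j} \theta_j(s, a)\). A NumPy sketch with dummy values; the shapes are hypothetical.

```python
import numpy as np

n_action, n_quantile = 4, 8
quantiles = np.random.randn(n_action, n_quantile)  # dummy theta_j(s, a)
q_values = quantiles.mean(axis=1)  # Q(s, a) = (1/N) * sum_j theta_j(s, a)
```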
- class nnabla_rl.models.StateActionQuantileFunction(scope_name: str, n_action: int, K: int, risk_measure_function: Callable[[Variable], Variable] = <function risk_neutral_measure>)[source]¶
Bases:
Model
State-action quantile function class.
Computes return samples of the q-value for each action. A state-action quantile function computes return samples of the q-value for each action using a sampled quantile threshold (e.g. \(\tau\sim U([0,1])\)) for the given state.
- Parameters:
scope_name (str) – scope name of the model
n_action (int) – Number of actions used in the target environment.
K (int) – Number of samples for quantile threshold \(\tau\).
risk_measure_function (Callable[[nn.Variable], nn.Variable]) – Risk measure function which modifies the weightings of tau. Defaults to the risk-neutral measure, which does not change the taus.
- all_quantile_values(s: Variable, tau: Variable) Variable [source]¶
Compute the return samples for all actions for the given state and quantile threshold.
- Parameters:
s (nn.Variable) – state variable.
tau (nn.Variable) – quantile threshold.
- Returns:
return samples from the implicit return distribution for the given state using tau.
- Return type:
nn.Variable
- as_q_function() QFunction [source]¶
Convert the state-action quantile function to a QFunction.
- Returns:
QFunction instance which computes the q-values based on return samples.
- Return type:
QFunction
- max_q_quantile_values(s: Variable, tau: Variable) Variable [source]¶
Compute the return samples from the distribution that maximizes the q-value for the given state using the quantile threshold.
- Parameters:
s (nn.Variable) – state variable.
tau (nn.Variable) – quantile threshold.
- Returns:
return samples from the implicit return distribution that maximizes q for the given state, using tau.
- Return type:
nn.Variable
- quantile_values(s: Variable, a: Variable, tau: Variable) Variable [source]¶
Compute the return samples for the given state and action.
- Parameters:
s (nn.Variable) – state variable.
a (nn.Variable) – action variable.
tau (nn.Variable) – quantile threshold.
- Returns:
return samples from the implicit return distribution for the given state and action, using tau.
- Return type:
nn.Variable
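A usage sketch, assuming quantile_function is a concrete subclass instance; the state dimension and the shape chosen for tau are assumptions.

```python
import numpy as np
import nnabla as nn

# `quantile_function` is assumed to be a concrete subclass instance.
batch_size, K = 1, 32
s = nn.Variable((batch_size, 4))                # hypothetical 4-dim state
tau = nn.Variable((batch_size, K))
tau.d = np.random.uniform(0.0, 1.0, tau.shape)  # tau ~ U([0, 1])

return_samples = quantile_function.all_quantile_values(s, tau)
q_function = quantile_function.as_q_function()  # q-values from return samples
```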
- class nnabla_rl.models.reward_function.RewardFunction(scope_name: str)[source]¶
Bases:
Model
Base reward function class.
- abstract r(s_current: Variable, a_current: Variable, s_next: Variable) Variable [source]¶
Compute the reward for the given state, action, and next state. One (or more) of the input variables may not be used in the actual computation.
- Parameters:
s_current (nnabla.Variable) – State variable
a_current (nnabla.Variable) – Action variable
s_next (nnabla.Variable) – Next state variable
- Returns:
Reward for the given state, action and next state.
- Return type:
nnabla.Variable
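A sketch of a concrete subclass, e.g. for adversarial imitation settings; the architecture and the use of all three inputs are illustrative, not part of the library.

```python
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

from nnabla_rl.models.reward_function import RewardFunction


class MyRewardFunction(RewardFunction):
    """Hypothetical reward conditioned on (s, a, s')."""

    def r(self, s_current, a_current, s_next):
        with nn.parameter_scope(self.scope_name):
            h = F.concatenate(s_current, a_current, s_next, axis=1)
            h = F.relu(PF.affine(h, 64, name="fc1"))
            return PF.affine(h, 1, name="reward")  # scalar reward per sample
```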
- class nnabla_rl.models.VariationalAutoEncoder(scope_name: str)[source]¶
Bases:
Encoder
- abstract decode(z: Variable | None, **kwargs) Variable [source]¶
Reconstruct the variable from the given latent representation.
- Parameters:
z (nn.Variable, optional) – latent variable. If None, a random sample will be used instead.
- Returns:
reconstructed variable
- Return type:
nn.Variable
- abstract decode_multiple(z: Variable | None, decode_num: int, **kwargs)[source]¶
Reconstruct multiple variables from the latent representation.
- Parameters:
z (nn.Variable, optional) – latent variable. If None, random samples will be used instead.
decode_num (int) – number of variables to decode.
- Returns:
Reconstructed input and latent distribution
- Return type:
nn.Variable
- abstract encode_and_decode(x: Variable, **kwargs) Tuple[Distribution, Variable] [source]¶
Encode the input variable and reconstruct it.
- Parameters:
x (nn.Variable) – encoder input.
- Returns:
latent distribution and reconstructed input
- Return type:
Tuple[Distribution, nn.Variable]
- abstract latent_distribution(x: Variable, **kwargs) Distribution [source]¶
Compute the latent distribution \(p(z|x)\).
- Parameters:
x (nn.Variable) – encoder input.
- Returns:
latent distribution
- Return type:
Distribution
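A usage sketch, assuming vae is a concrete VariationalAutoEncoder subclass instance and a 16-dimensional input; both are assumptions, and only methods documented above are used.

```python
import nnabla as nn

# `vae` is assumed to be a concrete VariationalAutoEncoder subclass instance.
x = nn.Variable((1, 16))           # hypothetical encoder input
latent_distribution, reconstructed = vae.encode_and_decode(x)
z = latent_distribution.sample()   # sample from p(z|x)
decoded = vae.decode(z)            # decode an explicit latent
random_decoded = vae.decode(None)  # decode from a random latent sample
```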