# Source code for nnabla_rl.models.model

# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import pathlib
from typing import Dict, Optional, Tuple, TypeVar, Union

import nnabla as nn
from nnabla_rl.logger import logger

# Type variable bound to Model. It is used as the `self` annotation and return type of
# Model.deepcopy / Model.shallowcopy so that copying a subclass instance is typed as
# returning that same subclass rather than the base Model.
T = TypeVar('T', bound='Model')


class Model(object):
    """Model Class.

    Args:
        scope_name (str): the scope name of model
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _scope_name: str

    def __init__(self, scope_name: str):
        self._scope_name = scope_name

    @property
    def scope_name(self) -> str:
        """Get the scope name of this model.

        Returns:
            scope_name (str): scope name of the model
        """
        return self._scope_name

    def get_parameters(self, grad_only: bool = True) -> Dict[str, nn.Variable]:
        """Retrieve the parameters associated with this model.

        The lookup is performed inside this model's parameter scope (``scope_name``).

        Args:
            grad_only (bool): If True, retrieve only the parameters with need_grad = True.
                Defaults to True.

        Returns:
            parameters (OrderedDict): Parameter map.
        """
        with nn.parameter_scope(self.scope_name):
            parameters: Dict[str, nn.Variable] = nn.get_parameters(grad_only=grad_only)
        return parameters

    def clear_parameters(self) -> None:
        """Clear all parameters associated with this model.

        Calls ``nn.clear_parameters()`` inside this model's parameter scope (``scope_name``).
        """
        with nn.parameter_scope(self.scope_name):
            nn.clear_parameters()

    def is_recurrent(self) -> bool:
        """Check whether the model uses a recurrent network component or not.

        A model which uses LSTM, GRU and/or any other recurrent network component must return True.

        Returns:
            bool: True if the model uses a recurrent network component. Otherwise False.
        """
        return False

    def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]:
        """Return the shape of each internal state (excluding the batch_size) as a tuple of ints.

        This method will be called by
        (:py:class:`RNNModelTrainer <nnabla_rl.model_trainers.model_trainer.RNNModelTrainer>`)
        and its subclasses to set up training variables.
        A model which uses LSTM, GRU and/or any other recurrent network component must implement this method.

        Returns:
            Dict[str, Tuple[int, ...]]: internal state shapes. The key is the name of each internal state.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None):
        """Set the internal state variables of the recurrent cells to the given states.

        A model which uses LSTM, GRU and/or any other recurrent network component must implement this method.

        Args:
            states (None or Dict[str, nn.Variable]): If None, reset all internal states to zero.
                If states are provided, set the provided states as the internal states.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def reset_internal_states(self):
        """Reset the internal state variables of the recurrent cells to zero.

        Delegates to ``set_internal_states(None)``.
        """
        self.set_internal_states(None)

    def get_internal_states(self) -> Dict[str, nn.Variable]:
        """Get the internal state variables of the recurrent cells.

        A model which uses LSTM, GRU and/or any other recurrent network component must implement this method.

        Returns:
            Dict[str, nn.Variable]: Value of each internal state. The key is the name of each internal state.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def save_parameters(self, filepath: Union[str, pathlib.Path]) -> None:
        """Save the model parameters to the given filepath.

        Only the parameters under this model's scope name are saved.

        Args:
            filepath (str or pathlib.Path): parameter file path
        """
        if isinstance(filepath, pathlib.Path):
            filepath = str(filepath)
        with nn.parameter_scope(self.scope_name):
            nn.save_parameters(path=filepath)

    def load_parameters(self, filepath: Union[str, pathlib.Path]) -> None:
        """Load the model parameters from the given filepath.

        The parameters are loaded into this model's parameter scope.

        Args:
            filepath (str or pathlib.Path): parameter file path
        """
        if isinstance(filepath, pathlib.Path):
            filepath = str(filepath)
        with nn.parameter_scope(self.scope_name):
            nn.load_parameters(path=filepath)

    def deepcopy(self: T, new_scope_name: str) -> T:
        """Create a (deep) copy of the model.

        All the model's parameters (if already created) are copied under new_scope_name,
        preserving each parameter's value and need_grad flag, and new_scope_name is assigned
        to the copied model.

        Args:
            new_scope_name (str): scope_name of parameters for newly created model

        Returns:
            Model: copied model

        Raises:
            AssertionError: new_scope_name is the same as this model's scope name.
            RuntimeError: a parameter with the same name already exists under new_scope_name.
                In that case no parameter is created under new_scope_name.
        """
        if new_scope_name == self._scope_name:
            # Explicit raise instead of `assert` so the check is not stripped by `python -O`.
            # AssertionError is kept for backward compatibility with existing callers.
            raise AssertionError('Can not use same scope_name!')

        params = self.get_parameters(grad_only=False)

        # Validate every name before creating anything, so that a name collision does not
        # leave a partially copied set of parameters behind under new_scope_name.
        with nn.parameter_scope(new_scope_name):
            for param_name in params:
                if nn.parameter.get_parameter(param_name) is not None:
                    raise RuntimeError(f'Model with scope_name: {new_scope_name} already exists!!')

        copied = copy.deepcopy(self)
        copied._scope_name = new_scope_name

        # copy current parameter if is already created
        with nn.parameter_scope(new_scope_name):
            for param_name, param in params.items():
                logger.info(
                    f'copying param with name: {self.scope_name}/{param_name} ---> {new_scope_name}/{param_name}')
                # Pass need_grad through: params were fetched with grad_only=False, and the
                # default (need_grad=True) would make frozen/non-grad parameters trainable in the copy.
                nn.parameter.get_parameter_or_create(param_name,
                                                     shape=param.shape,
                                                     initializer=param.d,
                                                     need_grad=param.need_grad)
        return copied

    def shallowcopy(self: T) -> T:
        """Create a (shallow) copy of the model.

        Unlike deepcopy, shallowcopy will KEEP sharing the original network parameters by using
        the same scope_name as the original model. However, all the class members will be (deep)
        copied to the new instance. Do NOT use this method unless you understand what this method does.

        Returns:
            Model: (shallow) copied model
        """
        copied = copy.deepcopy(self)
        return copied