Source code for nnabla_rl.algorithm

# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, Union, cast

import gym
import numpy as np

import nnabla as nn
import nnabla_rl as rl
from nnabla_rl.configuration import Configuration
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.exceptions import UnsupportedEnvironmentException, UnsupportedTrainingException
from nnabla_rl.hook import Hook
from nnabla_rl.logger import logger
from nnabla_rl.model_trainers.model_trainer import ModelTrainer
from nnabla_rl.replay_buffer import ReplayBuffer

F = TypeVar('F', bound=Callable[..., Any])


def eval_api(f: F) -> F:
    def wrapped_with_eval_scope(*args, **kwargs):
        with rl.eval_scope():
            return f(*args, **kwargs)
    return cast(F, wrapped_with_eval_scope)
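

# A minimal usage sketch (the subclass below is hypothetical): decorating an evaluation
# entry point with @eval_api runs the wrapped call inside rl.eval_scope(), so models can
# switch to their evaluation-time behavior.
#
#   class MyAlgorithm(Algorithm):
#       @eval_api
#       def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}):
#           ...  # policy forward pass, executed under rl.eval_scope()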


@dataclass
class AlgorithmConfig(Configuration):
    """List of algorithm common configuration.

    Args:
        gpu_id (int): id of the gpu to use. If negative, the training will run on cpu. Defaults to -1.
    """

    gpu_id: int = -1

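# Configuration sketch: concrete algorithm configs extend AlgorithmConfig, so gpu_id is
# available on every algorithm. DQNConfig from nnabla_rl.algorithms is used here only as
# an illustration.
#
#   from nnabla_rl.algorithms import DQNConfig
#   config = DQNConfig(gpu_id=0)  # train on gpu 0; gpu_id=-1 (default) keeps training on cpu
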
class Algorithm(metaclass=ABCMeta):
    """Base Algorithm class.

    Args:
        env_or_env_info \
            (gym.Env or :py:class:`EnvironmentInfo <nnabla_rl.environments.environment_info.EnvironmentInfo>`): \
            environment or environment info
        config (:py:class:`AlgorithmConfig <nnabla_rl.algorithm.AlgorithmConfig>`): configuration of the algorithm

    Note:
        Default functions, solvers and configurations are set to the configurations of each algorithm's
        original paper. Default functions may not work depending on the environment.
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variables and NOT class variables, unless marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _env_info: EnvironmentInfo
    _config: AlgorithmConfig
    _iteration_num: int
    _hooks: Sequence[Hook]

    def __init__(self, env_info, config=AlgorithmConfig()):
        if isinstance(env_info, gym.Env):
            env_info = EnvironmentInfo.from_env(env_info)
        self._env_info = env_info
        self._config = config
        self._iteration_num = 0
        self._hooks = []

        if not self.is_supported_env(env_info):
            raise UnsupportedEnvironmentException(
                "{} does not support the environment. "
                "See the algorithm catalog (https://github.com/sony/nnabla-rl/tree/master/nnabla_rl/algorithms) "
                "and confirm what kinds of environments are supported".format(self.__name__))

        if self._config.gpu_id < 0:
            logger.info('algorithm will run on cpu.')
        else:
            logger.info('algorithm will run on gpu: {}'.format(self._config.gpu_id))

    @property
    def __name__(self):
        return self.__class__.__name__

    @property
    def latest_iteration_state(self) -> Dict[str, Any]:
        """Return the latest iteration state, which is composed of items of the training process state.

        You can use this state for debugging (e.g. to plot a loss curve).
        See [IterationStateHook](./hooks/iteration_state_hook.py) for more details.

        Returns:
            Dict[str, Any]: Dictionary with items of the training process state.
        """
        latest_iteration_state: Dict[str, Any] = {}
        latest_iteration_state['scalar'] = {}
        latest_iteration_state['histogram'] = {}
        latest_iteration_state['image'] = {}
        return latest_iteration_state

    @property
    def iteration_num(self) -> int:
        """Current iteration number.

        Returns:
            int: Current iteration number of the running training.
        """
        return self._iteration_num

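    # Subclass skeleton (illustrative, names hypothetical): a concrete algorithm has to
    # provide the abstract members declared in this class, i.e. compute_eval_action,
    # is_supported_env, _run_online_training_iteration, _run_offline_training_iteration,
    # _models, _solvers and the trainers property.
    #
    #   class MyAlgorithm(Algorithm):
    #       @classmethod
    #       def is_supported_env(cls, env_or_env_info):
    #           return True
    #
    #       @eval_api
    #       def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}):
    #           ...
    #
    #       def _run_online_training_iteration(self, env): ...
    #       def _run_offline_training_iteration(self, buffer): ...
    #       def _models(self): return {}
    #       def _solvers(self): return {}
    #
    #       @property
    #       def trainers(self): return {}
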
    def train(self, env_or_buffer: Union[gym.Env, ReplayBuffer], total_iterations: int = sys.maxsize):
        """Train the policy with a reinforcement learning algorithm.

        Args:
            env_or_buffer (Union[gym.Env, ReplayBuffer]): Target environment to train the policy online,
                or replay buffer to train the policy offline.
            total_iterations (int): Total number of iterations to train the policy.

        Raises:
            UnsupportedTrainingException: Raised if this algorithm does not support the training method
                for the given parameter.
        """
        if self._is_env(env_or_buffer):
            env_or_buffer = cast(gym.Env, env_or_buffer)
            self.train_online(env_or_buffer, total_iterations)
        elif self._is_buffer(env_or_buffer):
            env_or_buffer = cast(ReplayBuffer, env_or_buffer)
            self.train_offline(env_or_buffer, total_iterations)
        else:
            raise UnsupportedTrainingException

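    # Usage sketch: train() dispatches on the argument type, so one call site covers both
    # training modes. DQN is illustrative; each algorithm's default networks may require a
    # matching observation type, so pick an environment the chosen algorithm supports.
    #
    #   import gym
    #   from nnabla_rl.algorithms import DQN
    #   env = gym.make("...")                        # an environment supported by DQN
    #   dqn = DQN(env)
    #   dqn.train(env, total_iterations=10000)       # gym.Env      -> train_online
    #   dqn.train(buffer, total_iterations=10000)    # ReplayBuffer -> train_offline
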
    def train_online(self, train_env: gym.Env, total_iterations: int = sys.maxsize):
        """Train the policy by interacting with the given environment.

        Args:
            train_env (gym.Env): Target environment to train the policy.
            total_iterations (int): Total number of iterations to train the policy.

        Raises:
            UnsupportedTrainingException: Raised if the algorithm does not support online training.
        """
        if self._has_rnn_models():
            self._assert_rnn_is_supported()

        self._before_training_start(train_env)
        self._setup_hooks(total_iterations)
        for _ in range(total_iterations):
            self._iteration_num += 1
            self._run_online_training_iteration(train_env)
            self._invoke_hooks()
        self._teardown_hooks(total_iterations)
        self._after_training_finish(train_env)

    def train_offline(self, replay_buffer: ReplayBuffer, total_iterations: int = sys.maxsize):
        """Train the policy using only the replay buffer.

        Args:
            replay_buffer (ReplayBuffer): Replay buffer to sample experiences from to train the policy.
            total_iterations (int): Total number of iterations to train the policy.

        Raises:
            UnsupportedTrainingException: Raised if the algorithm does not support offline training.
        """
        if self._has_rnn_models():
            self._assert_rnn_is_supported()

        self._before_training_start(replay_buffer)
        self._setup_hooks(total_iterations)
        for _ in range(total_iterations):
            self._iteration_num += 1
            self._run_offline_training_iteration(replay_buffer)
            self._invoke_hooks()
        self._teardown_hooks(total_iterations)
        self._after_training_finish(replay_buffer)

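    # Offline training sketch: the algorithm only samples from the given buffer and never
    # interacts with an environment. Filling the buffer (dataset conversion, previously
    # collected rollouts, etc.) is up to the caller, and the experience format must match
    # what the concrete algorithm expects. `algorithm` is a placeholder instance.
    #
    #   from nnabla_rl.replay_buffer import ReplayBuffer
    #   buffer = ReplayBuffer()
    #   # ... append previously collected experiences to `buffer` ...
    #   algorithm.train_offline(buffer, total_iterations=100000)
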
    def set_hooks(self, hooks: Sequence[Hook]):
        """Set hooks for running additional operations during training.

        Previously set hooks will be removed and replaced with the new hooks.

        Args:
            hooks (list of nnabla_rl.hook.Hook): Hooks to invoke during training.
        """
        self._hooks = hooks

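    # Hook sketch. This file only shows that hooks are invoked as hook(algorithm) each
    # iteration, with setup()/teardown() around training; the on_hook_called callback and
    # timing argument below are assumptions about nnabla_rl.hook.Hook, so verify them there.
    #
    #   class PrintIterationHook(Hook):
    #       def __init__(self):
    #           super().__init__(timing=100)  # assumed: invoke every 100 iterations
    #
    #       def on_hook_called(self, algorithm):
    #           print('iteration:', algorithm.iteration_num)
    #
    #   algorithm.set_hooks([PrintIterationHook()])
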
    def _invoke_hooks(self):
        for hook in self._hooks:
            hook(self)

    def _setup_hooks(self, total_iterations: int):
        for hook in self._hooks:
            hook.setup(self, total_iterations)

    def _teardown_hooks(self, total_iterations: int):
        for hook in self._hooks:
            hook.teardown(self, total_iterations)

    @abstractmethod
    def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}) -> np.ndarray:
        """Compute the action for a given state using the current best policy.

        This is usually used for evaluation.

        Args:
            state (np.ndarray): state to compute the action for.
            begin_of_episode (bool): Used for rnn state resetting. This flag informs the beginning of an episode.
            extra_info (Dict[str, Any]): Dictionary to provide extra information for computing the action.
                Most algorithms do not use this field.

        Returns:
            np.ndarray: Action for the given state, computed with the current trained policy.
        """
        raise NotImplementedError

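    # Evaluation loop sketch (classic gym 4-tuple step API assumed, matching the gym import
    # above). begin_of_episode lets recurrent algorithms reset their internal rnn state at
    # episode boundaries.
    #
    #   state = env.reset()
    #   done, begin_of_episode = False, True
    #   while not done:
    #       action = algorithm.compute_eval_action(state, begin_of_episode=begin_of_episode)
    #       state, reward, done, info = env.step(action)
    #       begin_of_episode = False
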
    def compute_trajectory(self, initial_trajectory: Sequence[Tuple[np.ndarray, Optional[np.ndarray]]]) \
            -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]:
        """Compute a trajectory (sequence of state and action tuples) from the given initial trajectory
        using the current policy.

        Most reinforcement learning algorithms do not implement this method.
        Only the optimal control algorithms implement it.

        Args:
            initial_trajectory (Sequence[Tuple[np.ndarray, Optional[np.ndarray]]]): initial trajectory.

        Returns:
            Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]:
                Sequence of state and action tuples and extra information (if any) at each timestep,
                computed with the current best policy. The extra information depends on the algorithm.
                The sequence length is the same as the length of the initial trajectory.
        """
        raise NotImplementedError

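    # Sketch for the optimal control algorithms that implement this method: the initial
    # trajectory is a sequence of (state, action) tuples where the action may be None
    # (e.g. for the final state), and the returned trajectory has the same length.
    # x0..xT and u0..uT below are placeholder state/control arrays.
    #
    #   initial_trajectory = [(x0, u0), (x1, u1), ..., (xT, None)]
    #   improved_trajectory, infos = algorithm.compute_trajectory(initial_trajectory)
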
    def _before_training_start(self, env_or_buffer):
        pass

    @abstractmethod
    def _run_online_training_iteration(self, env):
        raise NotImplementedError

    @abstractmethod
    def _run_offline_training_iteration(self, buffer):
        raise NotImplementedError

    def _after_training_finish(self, env_or_buffer):
        pass

    @abstractmethod
    def _models(self):
        """Model objects which are trained by the algorithm.

        Returns:
            Dict[str, nnabla_rl.model.Model]: Dictionary with items of model name as key and object as value.
        """
        raise NotImplementedError

    @abstractmethod
    def _solvers(self) -> Dict[str, nn.solver.Solver]:
        """Solver objects which are used for training the models by the algorithm.

        Returns:
            Dict[str, nn.solver.Solver]: Dictionary with items of solver name as key and object as value.
        """
        raise NotImplementedError

    def _is_env(self, env):
        return isinstance(env, gym.Env)

    def _is_buffer(self, env):
        return isinstance(env, ReplayBuffer)

    def _has_rnn_models(self):
        for model in self._models().values():
            if model.is_recurrent():
                return True
        return False

    def _assert_rnn_is_supported(self):
        if not self.is_rnn_supported():
            raise RuntimeError(f'{self.__name__} does not support rnn models but rnn models were given!')

    @classmethod
    @abstractmethod
    def is_supported_env(cls, env_or_env_info: Union[gym.Env, EnvironmentInfo]) -> bool:
        """Check whether the algorithm supports the environment or not.

        Args:
            env_or_env_info \
                (gym.Env or :py:class:`EnvironmentInfo <nnabla_rl.environments.environment_info.EnvironmentInfo>`) \
                : environment or environment info

        Returns:
            bool: True if the algorithm supports the environment. Otherwise False.
        """
        raise NotImplementedError

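    # Support-check sketch: is_supported_env is a classmethod, so compatibility can be
    # verified before constructing the algorithm (which otherwise raises
    # UnsupportedEnvironmentException in __init__). DQN is illustrative.
    #
    #   from nnabla_rl.algorithms import DQN
    #   if DQN.is_supported_env(env):
    #       algorithm = DQN(env)
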
    @classmethod
    def is_rnn_supported(cls) -> bool:
        """Check whether the algorithm supports rnn models or not.

        Returns:
            bool: True if the algorithm supports rnn models. Otherwise False.
        """
        return False

    @property
    @abstractmethod
    def trainers(self) -> Dict[str, ModelTrainer]:
        raise NotImplementedError