Source code for nnabla_rl.algorithms.srsac

# Copyright 2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass
from typing import Sequence, Union

import gym

import nnabla as nn
import nnabla_rl.model_trainers as MT
from nnabla_rl.algorithms.sac import (SAC, DefaultExplorerBuilder, DefaultPolicyBuilder, DefaultQFunctionBuilder,
                                      DefaultReplayBufferBuilder, DefaultSolverBuilder, SACConfig)
from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, ReplayBufferBuilder, SolverBuilder
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.model_trainers.model_trainer import TrainingBatch
from nnabla_rl.models import Model, QFunction, StochasticPolicy
from nnabla_rl.utils import context
from nnabla_rl.utils.data import marshal_experiences


@dataclass
class SRSACConfig(SACConfig):
    """SRSACConfig
    List of configurations for SRSAC algorithm.

    Args:
        gamma (float): discount factor of rewards. Defaults to 0.99.
        learning_rate (float): learning rate which is set to all solvers. \
            You can customize/override the learning rate for each solver by implementing the \
            (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`) by yourself. \
            Defaults to 0.0003.
        batch_size (int): training batch size. Defaults to 256.
        tau (float): target network's parameter update coefficient. Defaults to 0.005.
        environment_steps (int): Number of steps to interact with the environment on each iteration. Defaults to 1.
        gradient_steps (int): Number of parameter updates to perform on each iteration. Defaults to 1. \
            Keep this value to 1 and use replay_ratio to control the number of updates in SRSAC.
        target_entropy (float, optional): Target entropy value. Defaults to None.
        initial_temperature (float, optional): Initial value of temperature parameter. Defaults to None.
        fix_temperature (bool): If true, the temperature parameter will not be trained. Defaults to False.
        start_timesteps (int): the timestep when training starts. \
            The algorithm will collect experiences from the environment by acting randomly until this timestep. \
            Defaults to 10000.
        replay_buffer_size (int): capacity of the replay buffer. Defaults to 1000000.
        num_steps (int): number of steps for N-step Q targets. Defaults to 1.
        actor_unroll_steps (int): Number of steps to unroll actor's training network. \
            The network will be unrolled even though the provided model doesn't have RNN layers. \
            Defaults to 1.
        actor_burn_in_steps (int): Number of burn-in steps to initialize actor's recurrent layer states \
            during training. This flag does not take effect if the given model is not an RNN model. \
            Defaults to 0.
        actor_reset_rnn_on_terminal (bool): Reset actor's recurrent internal states to zero during training \
            if the episode ends. This flag does not take effect if the given model is not an RNN model. \
            Defaults to False.
        critic_unroll_steps (int): Number of steps to unroll critic's training network. \
            The network will be unrolled even though the provided model doesn't have RNN layers. \
            Defaults to 1.
        critic_burn_in_steps (int): Number of burn-in steps to initialize critic's recurrent layer states \
            during training. This flag does not take effect if the given model is not an RNN model. \
            Defaults to 0.
        critic_reset_rnn_on_terminal (bool): Reset critic's recurrent internal states to zero during training \
            if the episode ends. This flag does not take effect if the given model is not an RNN model. \
            Defaults to False.
        replay_ratio (int): Number of updates per environment step. Defaults to 1.
        reset_interval (int): Parameters will be reset every this number of updates. Defaults to 2560000.
    """

    replay_ratio: int = 1
    reset_interval: int = 2560000  # 2.56 * 10^6

    def __post_init__(self):
        super().__post_init__()
        self._assert_positive(self.replay_ratio, 'replay_ratio')
        self._assert_positive(self.reset_interval, 'reset_interval')
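
# Illustrative sketch (not part of the original module): how replay_ratio and
# reset_interval interact. The algorithm performs replay_ratio updates per
# environment step, and parameters are reset after reset_interval updates,
# i.e. roughly every reset_interval / replay_ratio environment steps
# (2_560_000 / 16 = 160_000 steps in this hypothetical setting).
def _example_srsac_config():  # hypothetical helper, for illustration only
    config = SRSACConfig(replay_ratio=16, reset_interval=2_560_000)
    env_steps_between_resets = config.reset_interval // config.replay_ratio
    assert env_steps_between_resets == 160_000
    return config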
class SRSAC(SAC):
    """Scaled-by-Resetting Soft Actor-Critic (SRSAC) algorithm implementation.

    This class implements the Scaled-by-Resetting Soft Actor-Critic (SRSAC) algorithm
    proposed by P. D'Oro, et al. in the paper:
    "Sample-Efficient Reinforcement Learning by Breaking the Replay Ratio Barrier".
    For details see: https://openreview.net/forum?id=OpC-9aBBVJe

    This algorithm periodically resets the models' and optimizers' parameters for stable and efficient learning.

    Args:
        env_or_env_info \
            (gym.Env or :py:class:`EnvironmentInfo <nnabla_rl.environments.environment_info.EnvironmentInfo>`):
            the environment to train or environment info
        config (:py:class:`SRSACConfig <nnabla_rl.algorithms.srsac.SRSACConfig>`):
            configuration of the SRSAC algorithm
        q_function_builder (:py:class:`ModelBuilder[QFunction] <nnabla_rl.builders.ModelBuilder>`):
            builder of q function models
        q_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
            builder of q function solvers
        policy_builder (:py:class:`ModelBuilder[StochasticPolicy] <nnabla_rl.builders.ModelBuilder>`):
            builder of actor models
        policy_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
            builder of policy solvers
        temperature_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
            builder of temperature solvers
        replay_buffer_builder (:py:class:`ReplayBufferBuilder <nnabla_rl.builders.ReplayBufferBuilder>`):
            builder of replay_buffer
        explorer_builder (:py:class:`ExplorerBuilder <nnabla_rl.builders.ExplorerBuilder>`):
            builder of environment explorer
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _config: SRSACConfig

    def __init__(self,
                 env_or_env_info: Union[gym.Env, EnvironmentInfo],
                 config: SRSACConfig = SRSACConfig(),
                 q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(),
                 q_solver_builder: SolverBuilder = DefaultSolverBuilder(),
                 policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(),
                 policy_solver_builder: SolverBuilder = DefaultSolverBuilder(),
                 temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(),
                 replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(),
                 explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()):
        super(SRSAC, self).__init__(
            env_or_env_info=env_or_env_info,
            config=config,
            q_function_builder=q_function_builder,
            q_solver_builder=q_solver_builder,
            policy_builder=policy_builder,
            policy_solver_builder=policy_solver_builder,
            temperature_solver_builder=temperature_solver_builder,
            replay_buffer_builder=replay_buffer_builder,
            explorer_builder=explorer_builder,
        )

    def _run_online_training_iteration(self, env):
        for _ in range(self._config.environment_steps):
            self._run_environment_step(env)
        for _ in range(self._config.gradient_steps):
            if self._config.start_timesteps < self.iteration_num:
                self._run_gradient_step(self._replay_buffer)

    def _run_offline_training_iteration(self, buffer):
        self._run_gradient_step(buffer)

    def _run_gradient_step(self, replay_buffer):
        for _ in range(self._config.replay_ratio):
            self._sac_training(replay_buffer)
        num_updates = (self.iteration_num * self._config.replay_ratio) % self._config.reset_interval
        num_updates += self._config.replay_ratio
        if self._config.reset_interval <= num_updates:
            self._reset_model_parameters(self._models().values())
            self._reconstruct_training_graphs()
            self._reconstruct_actors()

    def _reset_model_parameters(self, models: Sequence[Model]):
        solvers = self._solvers()
        for model in models:
            model.clear_parameters()
            solver: nn.solvers.Solver = solvers[model.scope_name]
            solver.clear_parameters()

    def _reconstruct_training_graphs(self):
        self._temperature = self._setup_temperature_model()
        self._policy_trainer = self._setup_policy_training(env_or_buffer=None)
        self._q_function_trainer = self._setup_q_function_training(env_or_buffer=None)

    def _reconstruct_actors(self):
        self._evaluation_actor = self._setup_evaluation_actor()
        self._exploration_actor = self._setup_exploration_actor()
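
# Usage sketch (not part of the original module), assuming a continuous-control
# Gym environment id such as 'Pendulum-v1' and the standard nnabla_rl
# Algorithm.train interface; the environment id and numbers are illustrative
# assumptions, not part of this module.
def _example_train_srsac():  # hypothetical helper, for illustration only
    env = gym.make('Pendulum-v1')  # assumed environment; replace with your own
    config = SRSACConfig(start_timesteps=1000, replay_ratio=8)
    srsac = SRSAC(env, config=config)
    # Each online training iteration runs environment_steps (default 1)
    # environment steps and replay_ratio parameter updates.
    srsac.train(env, total_iterations=100000)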
@dataclass
class EfficientSRSACConfig(SRSACConfig):
    """EfficientSRSACConfig
    List of configurations for EfficientSRSAC algorithm.

    Args:
        gamma (float): discount factor of rewards. Defaults to 0.99.
        learning_rate (float): learning rate which is set to all solvers. \
            You can customize/override the learning rate for each solver by implementing the \
            (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`) by yourself. \
            Defaults to 0.0003.
        batch_size (int): training batch size. Defaults to 256.
        tau (float): target network's parameter update coefficient. Defaults to 0.005.
        environment_steps (int): Number of steps to interact with the environment on each iteration. Defaults to 1.
        gradient_steps (int): Number of parameter updates to perform on each iteration. Defaults to 1. \
            Keep this value to 1 and use replay_ratio to control the number of updates in SRSAC.
        target_entropy (float, optional): Target entropy value. Defaults to None.
        initial_temperature (float, optional): Initial value of temperature parameter. Defaults to None.
        fix_temperature (bool): If true, the temperature parameter will not be trained. Defaults to False.
        start_timesteps (int): the timestep when training starts. \
            The algorithm will collect experiences from the environment by acting randomly until this timestep. \
            Defaults to 10000.
        replay_buffer_size (int): capacity of the replay buffer. Defaults to 1000000.
        num_steps (int): Not supported. This configuration does not take effect in the training.
        actor_unroll_steps (int): Not supported. This configuration does not take effect in the training.
        actor_burn_in_steps (int): Not supported. This configuration does not take effect in the training.
        actor_reset_rnn_on_terminal (bool): Not supported. This configuration does not take effect in the training.
        critic_unroll_steps (int): Not supported. This configuration does not take effect in the training.
        critic_burn_in_steps (int): Not supported. This configuration does not take effect in the training.
        critic_reset_rnn_on_terminal (bool): Not supported. This configuration does not take effect in the training.
        replay_ratio (int): Number of updates per environment step.
        reset_interval (int): Parameters will be reset every this number of updates.
    """

    actor_reset_rnn_on_terminal: bool = False
    critic_reset_rnn_on_terminal: bool = False

    def __post_init__(self):
        super().__post_init__()

        def fill_warning_message(config_name, config_value, expected_value):
            return (f'{config_name} is set to {config_value} (!= {expected_value}) '
                    'but this value does not take any effect on EfficientSRSAC.')

        if 1 != self.num_steps:
            warnings.warn(fill_warning_message('num_steps', self.num_steps, 1))
        if 0 != self.actor_burn_in_steps:
            warnings.warn(fill_warning_message('actor_burn_in_steps', self.actor_burn_in_steps, 0))
        if 1 != self.actor_unroll_steps:
            warnings.warn(fill_warning_message('actor_unroll_steps', self.actor_unroll_steps, 1))
        if self.actor_reset_rnn_on_terminal:
            warnings.warn(
                fill_warning_message('actor_reset_rnn_on_terminal', self.actor_reset_rnn_on_terminal, False))
        if 0 != self.critic_burn_in_steps:
            warnings.warn(fill_warning_message('critic_burn_in_steps', self.critic_burn_in_steps, 0))
        if 1 != self.critic_unroll_steps:
            warnings.warn(fill_warning_message('critic_unroll_steps', self.critic_unroll_steps, 1))
        if self.critic_reset_rnn_on_terminal:
            warnings.warn(
                fill_warning_message('critic_reset_rnn_on_terminal', self.critic_reset_rnn_on_terminal, False))
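
# Sketch (not part of the original module): EfficientSRSACConfig only warns,
# rather than raises, when unsupported options such as num_steps are set to
# non-default values; those values are then ignored during training.
def _example_efficient_srsac_config_warning():  # hypothetical helper, for illustration only
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        EfficientSRSACConfig(num_steps=3)  # num_steps != 1 is not supported here
    assert any('num_steps' in str(w.message) for w in caught)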
class EfficientSRSAC(SRSAC):
    """Efficient implementation of Scaled-by-Resetting Soft Actor-Critic (SRSAC) algorithm.

    This class implements a computationally efficient version of the Scaled-by-Resetting Soft Actor-Critic (SRSAC)
    algorithm proposed by P. D'Oro, et al. in the paper:
    "Sample-Efficient Reinforcement Learning by Breaking the Replay Ratio Barrier".
    For details see: https://openreview.net/forum?id=OpC-9aBBVJe

    This implementation does not support recurrent networks. For recurrent network support, use the SRSAC class.

    Args:
        env_or_env_info \
            (gym.Env or :py:class:`EnvironmentInfo <nnabla_rl.environments.environment_info.EnvironmentInfo>`):
            the environment to train or environment info
        config (:py:class:`EfficientSRSACConfig <nnabla_rl.algorithms.srsac.EfficientSRSACConfig>`):
            configuration of the EfficientSRSAC algorithm
        q_function_builder (:py:class:`ModelBuilder[QFunction] <nnabla_rl.builders.ModelBuilder>`):
            builder of q function models
        q_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
            builder of q function solvers
        policy_builder (:py:class:`ModelBuilder[StochasticPolicy] <nnabla_rl.builders.ModelBuilder>`):
            builder of actor models
        policy_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
            builder of policy solvers
        temperature_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
            builder of temperature solvers
        replay_buffer_builder (:py:class:`ReplayBufferBuilder <nnabla_rl.builders.ReplayBufferBuilder>`):
            builder of replay_buffer
        explorer_builder (:py:class:`ExplorerBuilder <nnabla_rl.builders.ExplorerBuilder>`):
            builder of environment explorer
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _config: EfficientSRSACConfig

    def __init__(self,
                 env_or_env_info: Union[gym.Env, EnvironmentInfo],
                 config: EfficientSRSACConfig = EfficientSRSACConfig(),
                 q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(),
                 q_solver_builder: SolverBuilder = DefaultSolverBuilder(),
                 policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(),
                 policy_solver_builder: SolverBuilder = DefaultSolverBuilder(),
                 temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(),
                 replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(),
                 explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()):
        super().__init__(env_or_env_info=env_or_env_info,
                         config=config,
                         q_function_builder=q_function_builder,
                         q_solver_builder=q_solver_builder,
                         policy_builder=policy_builder,
                         policy_solver_builder=policy_solver_builder,
                         temperature_solver_builder=temperature_solver_builder,
                         replay_buffer_builder=replay_buffer_builder,
                         explorer_builder=explorer_builder)
    @classmethod
    def is_rnn_supported(cls):
        return False
    def _run_offline_training_iteration(self, buffer):
        self._run_gradient_step(buffer)

    def _before_training_start(self, env_or_buffer):
        # set context globally to ensure that the training runs on configured gpu
        context.set_nnabla_context(self._config.gpu_id)
        self._environment_explorer = self._setup_environment_explorer(env_or_buffer)
        self._actor_critic_trainer = self._setup_actor_critic_training(env_or_buffer)

    def _setup_actor_critic_training(self, env_or_buffer):
        actor_critic_trainer_config = MT.hybrid_trainers.SRSACActorCriticTrainerConfig(
            fixed_temperature=self._config.fix_temperature,
            target_entropy=self._config.target_entropy,
            replay_ratio=self._config.replay_ratio,
            tau=self._config.tau)
        actor_critic_trainer = MT.hybrid_trainers.SRSACActorCriticTrainer(
            pi=self._pi,
            pi_solvers={self._pi.scope_name: self._pi_solver},
            q_functions=self._train_q_functions,
            q_solvers=self._train_q_solvers,
            target_q_functions=self._target_q_functions,
            temperature=self._temperature,
            temperature_solver=self._temperature_solver,
            env_info=self._env_info,
            config=actor_critic_trainer_config)
        return actor_critic_trainer

    def _run_gradient_step(self, replay_buffer):
        self._efficient_srsac_training(replay_buffer)
        num_updates = (self.iteration_num * self._config.replay_ratio) % self._config.reset_interval
        num_updates += self._config.replay_ratio
        if self._config.reset_interval <= num_updates:
            self._reset_model_parameters(self._models().values())
            self._reconstruct_training_graphs()
            self._reconstruct_actors()

    def _efficient_srsac_training(self, replay_buffer):
        num_steps = self._config.replay_ratio
        experiences_tuple = []
        info_tuple = []
        for _ in range(num_steps):
            experiences, info = replay_buffer.sample(self._config.batch_size)
            experiences_tuple.append(experiences)
            info_tuple.append(info)
        assert len(experiences_tuple) == num_steps

        batch = None
        for experiences, info in zip(experiences_tuple, info_tuple):
            (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences)
            rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {}
            batch = TrainingBatch(batch_size=self._config.batch_size,
                                  s_current=s,
                                  a_current=a,
                                  gamma=self._config.gamma,
                                  reward=r,
                                  non_terminal=non_terminal,
                                  s_next=s_next,
                                  weight=info['weights'],
                                  next_step_batch=batch,
                                  rnn_states=rnn_states)

        self._actor_critic_trainer_state = self._actor_critic_trainer.train(batch)
        td_errors = self._actor_critic_trainer_state['td_errors']
        replay_buffer.update_priorities(td_errors)

    def _reconstruct_training_graphs(self):
        self._temperature = self._setup_temperature_model()
        self._actor_critic_trainer = self._setup_actor_critic_training(env_or_buffer=None)

    @property
    def latest_iteration_state(self):
        latest_iteration_state = super(SAC, self).latest_iteration_state
        if hasattr(self, '_actor_critic_trainer_state'):
            latest_iteration_state['scalar'].update({'pi_loss': float(self._actor_critic_trainer_state['pi_loss'])})
            latest_iteration_state['scalar'].update({'q_loss': float(self._actor_critic_trainer_state['q_loss'])})
            latest_iteration_state['histogram'].update(
                {'td_errors': self._actor_critic_trainer_state['td_errors'].flatten()})
        return latest_iteration_state
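
# Usage sketch (not part of the original module). EfficientSRSAC batches the
# replay_ratio updates of each iteration into a single fused trainer call,
# which is why it does not support recurrent models. The environment id and
# numbers below are illustrative assumptions.
def _example_train_efficient_srsac():  # hypothetical helper, for illustration only
    env = gym.make('Pendulum-v1')  # assumed environment; replace with your own
    config = EfficientSRSACConfig(start_timesteps=1000, replay_ratio=32)
    algorithm = EfficientSRSAC(env, config=config)
    algorithm.train(env, total_iterations=50000)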