Source code for nnabla_rl.algorithms.bcq

# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Union

import gym

import nnabla as nn
import nnabla.solvers as NS
import nnabla_rl.functions as RF
import nnabla_rl.model_trainers as MT
from nnabla_rl.algorithm import Algorithm, AlgorithmConfig, eval_api
from nnabla_rl.algorithms.common_utils import has_batch_dimension
from nnabla_rl.builders import ModelBuilder, SolverBuilder
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.model_trainers.model_trainer import ModelTrainer, TrainingBatch
from nnabla_rl.models import (BCQPerturbator, BCQVariationalAutoEncoder, DeterministicPolicy, Perturbator, QFunction,
                              TD3QFunction, VariationalAutoEncoder)
from nnabla_rl.utils import context
from nnabla_rl.utils.data import add_batch_dimension, marshal_experiences, set_data_to_variable
from nnabla_rl.utils.misc import create_variable, sync_model


@dataclass
class BCQConfig(AlgorithmConfig):
    """BCQConfig
    List of configurations for BCQ algorithm.

    Args:
        gamma (float): discount factor of reward. Defaults to 0.99.
        learning_rate (float): learning rate which is set to all solvers. \
            You can customize/override the learning rate for each solver by implementing the \
            (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`) by yourself. \
            Defaults to 0.001.
        batch_size (int): training batch size. Defaults to 100.
        tau (float): target network's parameter update coefficient. Defaults to 0.005.
        lmb (float): weight :math:`\\lambda` used for balancing the ratio between :math:`\\min{Q}` and :math:`\\max{Q}`\
            on target q value generation (i.e. :math:`\\lambda\\min{Q} + (1 - \\lambda)\\max{Q}`).\
            Defaults to 0.75.
        phi (float): action perturbator noise coefficient. Defaults to 0.05.
        num_q_ensembles (int): number of q function ensembles. Defaults to 2.
        num_action_samples (int): number of actions to sample for computing target q values. Defaults to 10.
    """

    gamma: float = 0.99
    learning_rate: float = 1.0*1e-3
    batch_size: int = 100
    tau: float = 0.005
    lmb: float = 0.75
    phi: float = 0.05
    num_q_ensembles: int = 2
    num_action_samples: int = 10

    def __post_init__(self):
        """__post_init__

        Check that the set values are in the valid range.
        """
        self._assert_between(self.tau, 0.0, 1.0, 'tau')
        self._assert_between(self.gamma, 0.0, 1.0, 'gamma')
        self._assert_positive(self.lmb, 'lmb')
        self._assert_positive(self.phi, 'phi')
        self._assert_positive(self.num_q_ensembles, 'num_q_ensembles')
        self._assert_positive(self.num_action_samples, 'num_action_samples')
        self._assert_positive(self.batch_size, 'batch_size')
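
# Example (editor's sketch, not part of nnabla_rl): how the configuration above can be
# customized. The function name and the chosen values are illustrative only; any field not
# passed keeps the default declared in the dataclass, and __post_init__ validates the
# ranges at construction time.
def _example_custom_config() -> BCQConfig:
    # Larger batch and a slightly higher lambda for the min/max target-q mixing.
    return BCQConfig(batch_size=256, lmb=0.8, num_action_samples=10)
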
class DefaultQFunctionBuilder(ModelBuilder[QFunction]):
    def build_model(self,  # type: ignore[override]
                    scope_name: str,
                    env_info: EnvironmentInfo,
                    algorithm_config: BCQConfig,
                    **kwargs) -> QFunction:
        return TD3QFunction(scope_name)


class DefaultVAEBuilder(ModelBuilder[VariationalAutoEncoder]):
    def build_model(self,  # type: ignore[override]
                    scope_name: str,
                    env_info: EnvironmentInfo,
                    algorithm_config: BCQConfig,
                    **kwargs) -> VariationalAutoEncoder:
        max_action_value = float(env_info.action_high[0])
        return BCQVariationalAutoEncoder(scope_name,
                                         env_info.state_dim,
                                         env_info.action_dim,
                                         env_info.action_dim*2,
                                         max_action_value)


class DefaultPerturbatorBuilder(ModelBuilder[Perturbator]):
    def build_model(self,  # type: ignore[override]
                    scope_name: str,
                    env_info: EnvironmentInfo,
                    algorithm_config: BCQConfig,
                    **kwargs) -> Perturbator:
        max_action_value = float(env_info.action_high[0])
        return BCQPerturbator(scope_name,
                              env_info.state_dim,
                              env_info.action_dim,
                              max_action_value)


class DefaultSolverBuilder(SolverBuilder):
    def build_solver(self,  # type: ignore[override]
                     env_info: EnvironmentInfo,
                     algorithm_config: BCQConfig,
                     **kwargs):
        return NS.Adam(alpha=algorithm_config.learning_rate)
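
# Example (editor's sketch, not part of nnabla_rl): a custom SolverBuilder that overrides the
# learning rate for whichever solver it is passed to, as mentioned in the BCQConfig docstring.
# The class name and the fixed learning rate are illustrative assumptions; it mirrors
# DefaultSolverBuilder above but ignores algorithm_config.learning_rate.
class _ExampleSmallLRSolverBuilder(SolverBuilder):
    def build_solver(self,  # type: ignore[override]
                     env_info: EnvironmentInfo,
                     algorithm_config: BCQConfig,
                     **kwargs):
        # Use a fixed, smaller learning rate instead of the shared one in the config.
        return NS.Adam(alpha=1e-4)

# Usage sketch: pass it only where you want the override, e.g.
#     BCQ(env, vae_solver_builder=_ExampleSmallLRSolverBuilder())
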
class BCQ(Algorithm):
    """Batch-Constrained Q-learning (BCQ) algorithm.

    This class implements the Batch-Constrained Q-learning (BCQ) algorithm
    proposed by S. Fujimoto, et al. in the paper: "Off-Policy Deep Reinforcement Learning without Exploration"
    For details see: https://arxiv.org/abs/1812.02900

    This algorithm only supports offline training.

    Args:
        env_or_env_info \
            (gym.Env or :py:class:`EnvironmentInfo <nnabla_rl.environments.environment_info.EnvironmentInfo>`):
            the environment to train or environment info
        config (:py:class:`BCQConfig <nnabla_rl.algorithms.bcq.BCQConfig>`): configuration of the BCQ algorithm
        q_function_builder (:py:class:`ModelBuilder[QFunction] <nnabla_rl.builders.ModelBuilder>`):
            builder of q-function models
        q_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
            builder for q-function solvers
        vae_builder (:py:class:`ModelBuilder[VariationalAutoEncoder] <nnabla_rl.builders.ModelBuilder>`):
            builder of variational auto encoder models
        vae_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
            builder for variational auto encoder solvers
        perturbator_builder (:py:class:`ModelBuilder[Perturbator] <nnabla_rl.builders.ModelBuilder>`):
            builder of perturbator models
        perturbator_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
            builder for perturbator solvers
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _config: BCQConfig
    _q_ensembles: List[QFunction]
    _q_solvers: Dict[str, nn.solver.Solver]
    _target_q_ensembles: List[QFunction]
    _vae: VariationalAutoEncoder
    _vae_solver: nn.solver.Solver
    _xi: Perturbator
    _xi_solver: nn.solver.Solver
    _q_function_trainer: ModelTrainer
    _encoder_trainer: ModelTrainer
    _perturbator_trainer: ModelTrainer

    _eval_state_var: nn.Variable
    _eval_action: nn.Variable
    _eval_max_index: nn.Variable

    _encoder_trainer_state: Dict[str, Any]
    _q_function_trainer_state: Dict[str, Any]
    _perturbator_trainer_state: Dict[str, Any]

    def __init__(self,
                 env_or_env_info: Union[gym.Env, EnvironmentInfo],
                 config: BCQConfig = BCQConfig(),
                 q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(),
                 q_solver_builder: SolverBuilder = DefaultSolverBuilder(),
                 vae_builder: ModelBuilder[VariationalAutoEncoder] = DefaultVAEBuilder(),
                 vae_solver_builder: SolverBuilder = DefaultSolverBuilder(),
                 perturbator_builder: ModelBuilder[Perturbator] = DefaultPerturbatorBuilder(),
                 perturbator_solver_builder: SolverBuilder = DefaultSolverBuilder()):
        super(BCQ, self).__init__(env_or_env_info, config=config)

        with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)):
            self._q_ensembles = []
            self._q_solvers = {}
            self._target_q_ensembles = []
            for i in range(self._config.num_q_ensembles):
                q = q_function_builder(scope_name=f"q{i}",
                                       env_info=self._env_info,
                                       algorithm_config=self._config)
                target_q = q.deepcopy(f'target_q{i}')
                assert isinstance(target_q, QFunction)
                self._q_ensembles.append(q)
                self._q_solvers[q.scope_name] = q_solver_builder(env_info=self._env_info,
                                                                 algorithm_config=self._config)
                self._target_q_ensembles.append(target_q)

            self._vae = vae_builder(scope_name="vae", env_info=self._env_info, algorithm_config=self._config)
            self._vae_solver = vae_solver_builder(env_info=self._env_info, algorithm_config=self._config)
            self._xi = perturbator_builder(scope_name="xi", env_info=self._env_info, algorithm_config=self._config)
            self._xi_solver = perturbator_solver_builder(env_info=self._env_info, algorithm_config=self._config)
            self._target_xi = perturbator_builder(
                scope_name="target_xi", env_info=self._env_info, algorithm_config=self._config)

    @eval_api
    def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}):
        if has_batch_dimension(state, self._env_info):
            raise RuntimeError(f'{self.__class__.__name__} does not support batched state!')
        with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)):
            state = add_batch_dimension(state)
            if not hasattr(self, '_eval_state_var'):
                repeat_num = 100
                self._eval_state_var = create_variable(1, self._env_info.state_shape)
                if isinstance(self._eval_state_var, tuple):
                    state_var = tuple(RF.repeat(x=s_var, repeats=repeat_num, axis=0)
                                      for s_var in self._eval_state_var)
                else:
                    state_var = RF.repeat(x=self._eval_state_var, repeats=repeat_num, axis=0)
                    assert state_var.shape == (repeat_num, self._eval_state_var.shape[1])
                actions = self._vae.decode(z=None, state=state_var)
                noise = self._xi.generate_noise(state_var, actions, self._config.phi)
                self._eval_action = actions + noise
                q_values = self._q_ensembles[0].q(state_var, self._eval_action)
                self._eval_max_index = RF.argmax(q_values, axis=0)
            set_data_to_variable(self._eval_state_var, state)
            nn.forward_all([self._eval_action, self._eval_max_index])
            return self._eval_action.d[self._eval_max_index.d[0]]

    def _before_training_start(self, env_or_buffer):
        # set context globally to ensure that the training runs on configured gpu
        context.set_nnabla_context(self._config.gpu_id)
        self._encoder_trainer = self._setup_encoder_training(env_or_buffer)
        self._q_function_trainer = self._setup_q_function_training(env_or_buffer)
        self._perturbator_trainer = self._setup_perturbator_training(env_or_buffer)

    def _setup_encoder_training(self, env_or_buffer):
        trainer_config = MT.encoder_trainers.KLDVariationalAutoEncoderTrainerConfig()
        encoder_trainer = MT.encoder_trainers.KLDVariationalAutoEncoderTrainer(
            models=self._vae,
            solvers={self._vae.scope_name: self._vae_solver},
            env_info=self._env_info,
            config=trainer_config)
        return encoder_trainer

    def _setup_q_function_training(self, env_or_buffer):
        trainer_config = MT.q_value.BCQQTrainerConfig(reduction_method='mean',
                                                      num_action_samples=self._config.num_action_samples,
                                                      lmb=self._config.lmb)

        # This is a wrapper class which outputs the target action for the next state in q function training
        class PerturbedPolicy(DeterministicPolicy):
            def __init__(self, vae, perturbator, phi):
                self._vae = vae
                self._perturbator = perturbator
                self._phi = phi

            def pi(self, s):
                a = self._vae.decode(z=None, state=s)
                noise = self._perturbator.generate_noise(s, a, phi=self._phi)
                return a + noise

        target_policy = PerturbedPolicy(self._vae, self._target_xi, self._config.phi)
        q_function_trainer = MT.q_value.BCQQTrainer(
            train_functions=self._q_ensembles,
            solvers=self._q_solvers,
            target_functions=self._target_q_ensembles,
            target_policy=target_policy,
            env_info=self._env_info,
            config=trainer_config)
        for q, target_q in zip(self._q_ensembles, self._target_q_ensembles):
            sync_model(q, target_q, 1.0)
        return q_function_trainer

    def _setup_perturbator_training(self, env_or_buffer):
        trainer_config = MT.perturbator_trainers.BCQPerturbatorTrainerConfig(
            phi=self._config.phi
        )
        perturbator_trainer = MT.perturbator.BCQPerturbatorTrainer(
            models=self._xi,
            solvers={self._xi.scope_name: self._xi_solver},
            q_function=self._q_ensembles[0],
            vae=self._vae,
            env_info=self._env_info,
            config=trainer_config)
        sync_model(self._xi, self._target_xi, 1.0)
        return perturbator_trainer

    def _run_online_training_iteration(self, env):
        raise NotImplementedError('BCQ does not support online training')

    def _run_offline_training_iteration(self, buffer):
        self._bcq_training(buffer)

    def _bcq_training(self, replay_buffer):
        experiences, info = replay_buffer.sample(self._config.batch_size)
        (s, a, r, non_terminal, s_next, *_) = marshal_experiences(experiences)
        batch = TrainingBatch(batch_size=self._config.batch_size,
                              s_current=s,
                              a_current=a,
                              gamma=self._config.gamma,
                              reward=r,
                              non_terminal=non_terminal,
                              s_next=s_next,
                              weight=info['weights'])

        # Train vae
        self._encoder_trainer_state = self._encoder_trainer.train(batch)

        # Train the q function ensembles and update their targets
        self._q_function_trainer_state = self._q_function_trainer.train(batch)
        for q, target_q in zip(self._q_ensembles, self._target_q_ensembles):
            sync_model(q, target_q, tau=self._config.tau)
        td_errors = self._q_function_trainer_state['td_errors']
        replay_buffer.update_priorities(td_errors)

        # Train the perturbator and update its target
        self._perturbator_trainer_state = self._perturbator_trainer.train(batch)
        sync_model(self._xi, self._target_xi, tau=self._config.tau)

    def _models(self):
        models = [*self._q_ensembles, *self._target_q_ensembles,
                  self._vae, self._xi, self._target_xi]
        return {model.scope_name: model for model in models}

    def _solvers(self):
        solvers = {}
        solvers.update(self._q_solvers)
        solvers[self._vae.scope_name] = self._vae_solver
        solvers[self._xi.scope_name] = self._xi_solver
        return solvers
    @classmethod
    def is_supported_env(cls, env_or_env_info):
        env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
            else env_or_env_info
        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()
    @property
    def latest_iteration_state(self):
        latest_iteration_state = super(BCQ, self).latest_iteration_state
        if hasattr(self, '_encoder_trainer_state'):
            latest_iteration_state['scalar'].update(
                {'encoder_loss': float(self._encoder_trainer_state['encoder_loss'])})
        if hasattr(self, '_perturbator_trainer_state'):
            latest_iteration_state['scalar'].update(
                {'perturbator_loss': float(self._perturbator_trainer_state['perturbator_loss'])})
        if hasattr(self, '_q_function_trainer_state'):
            latest_iteration_state['scalar'].update({'q_loss': float(self._q_function_trainer_state['q_loss'])})
            latest_iteration_state['histogram'].update(
                {'td_errors': self._q_function_trainer_state['td_errors'].flatten()})
        return latest_iteration_state

    @property
    def trainers(self):
        return {
            "encoder": self._encoder_trainer,
            "q_function": self._q_function_trainer,
            "perturbator": self._perturbator_trainer,
        }
if __name__ == "__main__":
    import nnabla_rl.environments as E

    env = E.DummyContinuous()
    bcq = BCQ(env)
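
    # Editor's sketch (not part of the original demo): offline training on a pre-filled
    # replay buffer, followed by computing a greedy evaluation action. The ReplayBuffer
    # import path and constructor, the (s, a, r, non_terminal, s_next) experience layout,
    # and the train() entry point are assumptions inferred from this file and the
    # nnabla_rl Algorithm interface, not guarantees about a specific release.
    from nnabla_rl.replay_buffer import ReplayBuffer

    buffer = ReplayBuffer(capacity=10000)
    for _ in range(500):
        s = env.observation_space.sample()
        a = env.action_space.sample()
        s_next = env.observation_space.sample()
        # reward=0.0 and non_terminal=1.0 are placeholder values for this dummy dataset
        buffer.append((s, a, 0.0, 1.0, s_next))

    # BCQ is offline-only (see _run_online_training_iteration above), so train on the buffer
    bcq.train(buffer, total_iterations=10)

    # compute_eval_action expects a single (unbatched) state
    print(bcq.compute_eval_action(env.reset()))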