# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import gym
import nnabla as nn
import nnabla.solvers as NS
import nnabla_rl.functions as RF
import nnabla_rl.model_trainers as MT
from nnabla_rl.algorithm import Algorithm, AlgorithmConfig, eval_api
from nnabla_rl.algorithms.common_utils import has_batch_dimension
from nnabla_rl.builders import ModelBuilder, SolverBuilder
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.model_trainers.model_trainer import ModelTrainer, TrainingBatch
from nnabla_rl.models import (
BCQPerturbator,
BCQVariationalAutoEncoder,
DeterministicPolicy,
Perturbator,
QFunction,
TD3QFunction,
VariationalAutoEncoder,
)
from nnabla_rl.utils import context
from nnabla_rl.utils.data import add_batch_dimension, marshal_experiences, set_data_to_variable
from nnabla_rl.utils.misc import create_variable, sync_model
@dataclass
class BCQConfig(AlgorithmConfig):
"""BCQConfig List of configurations for BCQ algorithm.
Args:
gamma (float): discount factor of reward. Defaults to 0.99.
learning_rate (float): learning rate which is set to all solvers. \
You can customize/override the learning rate for each solver by implementing the \
(:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`) by yourself. \
Defaults to 0.001.
batch_size (int): training batch size. Defaults to 100.
tau (float): target network's parameter update coefficient. Defaults to 0.005.
lmb (float): weight :math:`\\lambda` used for balancing the ratio between :math:`\\min{Q}` and :math:`\\max{Q}`\
on target q value generation (i.e. :math:`\\lambda\\min{Q} + (1 - \\lambda)\\max{Q}`).\
Defaults to 0.75.
phi (float): action perturbator noise coefficient. Defaults to 0.05.
        num_q_ensembles (int): number of q function ensembles. Defaults to 2.
num_action_samples (int): number of actions to sample for computing target q values. Defaults to 10.
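
    Example:
        A minimal construction sketch (the values shown are the defaults):

        .. code-block:: python

            config = BCQConfig(gamma=0.99, batch_size=100, num_q_ensembles=2)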
"""
gamma: float = 0.99
learning_rate: float = 1.0 * 1e-3
batch_size: int = 100
tau: float = 0.005
lmb: float = 0.75
phi: float = 0.05
num_q_ensembles: int = 2
num_action_samples: int = 10
def __post_init__(self):
"""__post_init__
Check set values are in valid range.
"""
self._assert_between(self.tau, 0.0, 1.0, "tau")
self._assert_between(self.gamma, 0.0, 1.0, "gamma")
self._assert_positive(self.lmb, "lmb")
self._assert_positive(self.phi, "phi")
self._assert_positive(self.num_q_ensembles, "num_q_ensembles")
self._assert_positive(self.num_action_samples, "num_action_samples")
self._assert_positive(self.batch_size, "batch_size")
class DefaultQFunctionBuilder(ModelBuilder[QFunction]):
def build_model( # type: ignore[override]
self,
scope_name: str,
env_info: EnvironmentInfo,
algorithm_config: BCQConfig,
**kwargs,
) -> QFunction:
return TD3QFunction(scope_name)
class DefaultVAEBuilder(ModelBuilder[VariationalAutoEncoder]):
def build_model( # type: ignore[override]
self,
scope_name: str,
env_info: EnvironmentInfo,
algorithm_config: BCQConfig,
**kwargs,
) -> VariationalAutoEncoder:
max_action_value = float(env_info.action_high[0])
return BCQVariationalAutoEncoder(
scope_name, env_info.state_dim, env_info.action_dim, env_info.action_dim * 2, max_action_value
)
class DefaultPerturbatorBuilder(ModelBuilder[Perturbator]):
def build_model( # type: ignore[override]
self,
scope_name: str,
env_info: EnvironmentInfo,
algorithm_config: BCQConfig,
**kwargs,
) -> Perturbator:
max_action_value = float(env_info.action_high[0])
return BCQPerturbator(scope_name, env_info.state_dim, env_info.action_dim, max_action_value)
class DefaultSolverBuilder(SolverBuilder):
def build_solver(self, env_info: EnvironmentInfo, algorithm_config: BCQConfig, **kwargs): # type: ignore[override]
return NS.Adam(alpha=algorithm_config.learning_rate)
class BCQ(Algorithm):
"""Batch-Constrained Q-learning (BCQ) algorithm.
This class implements the Batch-Constrained Q-learning (BCQ) algorithm
    proposed by S. Fujimoto et al. in the paper: "Off-Policy Deep Reinforcement Learning without Exploration".
For details see: https://arxiv.org/abs/1812.02900
This algorithm only supports offline training.
Args:
env_or_env_info \
(gym.Env or :py:class:`EnvironmentInfo <nnabla_rl.environments.environment_info.EnvironmentInfo>`):
the environment to train or environment info
config (:py:class:`BCQConfig <nnabla_rl.algorithms.bcq.BCQConfig>`):
configuration of the BCQ algorithm
q_function_builder (:py:class:`ModelBuilder[QFunction] <nnabla_rl.builders.ModelBuilder>`):
builder of q-function models
q_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
builder for q-function solvers
vae_builder (:py:class:`ModelBuilder[VariationalAutoEncoder] <nnabla_rl.builders.ModelBuilder>`):
builder of variational auto encoder models
vae_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
builder for variational auto encoder solvers
        perturbator_builder (:py:class:`ModelBuilder[Perturbator] <nnabla_rl.builders.ModelBuilder>`):
builder of perturbator models
perturbator_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
builder for perturbator solvers
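
    Example:
        A minimal offline-training sketch, not the library's canonical usage: it
        assumes the replay buffer has been pre-filled with the offline dataset and
        that ``train`` dispatches to offline training when given a buffer:

        .. code-block:: python

            import nnabla_rl.environments as E
            from nnabla_rl.replay_buffer import ReplayBuffer

            env = E.DummyContinuous()
            buffer = ReplayBuffer(capacity=100000)
            # ... fill the buffer with the offline dataset beforehand ...
            bcq = BCQ(env)
            bcq.train(buffer, total_iterations=10000)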
"""
# type declarations to type check with mypy
# NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
# See https://mypy.readthedocs.io/en/stable/class_basics.html for details
_config: BCQConfig
_q_ensembles: List[QFunction]
_q_solvers: Dict[str, nn.solver.Solver]
_target_q_ensembles: List[QFunction]
_vae: VariationalAutoEncoder
_vae_solver: nn.solver.Solver
    _xi: Perturbator
    _target_xi: Perturbator
    _xi_solver: nn.solver.Solver
_q_function_trainer: ModelTrainer
_encoder_trainer: ModelTrainer
_perturbator_trainer: ModelTrainer
_eval_state_var: nn.Variable
_eval_action: nn.Variable
_eval_max_index: nn.Variable
_encoder_trainer_state: Dict[str, Any]
_q_function_trainer_state: Dict[str, Any]
_perturbator_trainer_state: Dict[str, Any]
def __init__(
self,
env_or_env_info: Union[gym.Env, EnvironmentInfo],
config: BCQConfig = BCQConfig(),
q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(),
q_solver_builder: SolverBuilder = DefaultSolverBuilder(),
vae_builder: ModelBuilder[VariationalAutoEncoder] = DefaultVAEBuilder(),
vae_solver_builder: SolverBuilder = DefaultSolverBuilder(),
perturbator_builder: ModelBuilder[Perturbator] = DefaultPerturbatorBuilder(),
perturbator_solver_builder: SolverBuilder = DefaultSolverBuilder(),
):
super(BCQ, self).__init__(env_or_env_info, config=config)
with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)):
self._q_ensembles = []
self._q_solvers = {}
self._target_q_ensembles = []
for i in range(self._config.num_q_ensembles):
q = q_function_builder(scope_name=f"q{i}", env_info=self._env_info, algorithm_config=self._config)
target_q = q.deepcopy(f"target_q{i}")
assert isinstance(target_q, QFunction)
self._q_ensembles.append(q)
self._q_solvers[q.scope_name] = q_solver_builder(env_info=self._env_info, algorithm_config=self._config)
self._target_q_ensembles.append(target_q)
self._vae = vae_builder(scope_name="vae", env_info=self._env_info, algorithm_config=self._config)
self._vae_solver = vae_solver_builder(env_info=self._env_info, algorithm_config=self._config)
self._xi = perturbator_builder(scope_name="xi", env_info=self._env_info, algorithm_config=self._config)
self._xi_solver = perturbator_solver_builder(env_info=self._env_info, algorithm_config=self._config)
self._target_xi = perturbator_builder(
scope_name="target_xi", env_info=self._env_info, algorithm_config=self._config
)
@eval_api
def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}):
if has_batch_dimension(state, self._env_info):
raise RuntimeError(f"{self.__name__} does not support batched state!")
with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)):
state = add_batch_dimension(state)
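            # BCQ action selection: replicate the state, decode as many candidate
            # actions with the VAE, perturb each candidate with the learned
            # perturbator, and return the candidate with the highest q value.
            # The evaluation graph below is built only on the first call.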
if not hasattr(self, "_eval_state_var"):
repeat_num = 100
self._eval_state_var = create_variable(1, self._env_info.state_shape)
if isinstance(self._eval_state_var, tuple):
state_var = tuple(RF.repeat(x=s_var, repeats=repeat_num, axis=0) for s_var in self._eval_state_var)
else:
state_var = RF.repeat(x=self._eval_state_var, repeats=repeat_num, axis=0)
assert state_var.shape == (repeat_num, self._eval_state_var.shape[1])
actions = self._vae.decode(z=None, state=state_var)
noise = self._xi.generate_noise(state_var, actions, self._config.phi)
self._eval_action = actions + noise
q_values = self._q_ensembles[0].q(state_var, self._eval_action)
self._eval_max_index = RF.argmax(q_values, axis=0)
set_data_to_variable(self._eval_state_var, state)
nn.forward_all([self._eval_action, self._eval_max_index])
return self._eval_action.d[self._eval_max_index.d[0]]
def _before_training_start(self, env_or_buffer):
# set context globally to ensure that the training runs on configured gpu
context.set_nnabla_context(self._config.gpu_id)
self._encoder_trainer = self._setup_encoder_training(env_or_buffer)
self._q_function_trainer = self._setup_q_function_training(env_or_buffer)
self._perturbator_trainer = self._setup_perturbator_training(env_or_buffer)
def _setup_encoder_training(self, env_or_buffer):
trainer_config = MT.encoder_trainers.KLDVariationalAutoEncoderTrainerConfig()
encoder_trainer = MT.encoder_trainers.KLDVariationalAutoEncoderTrainer(
models=self._vae,
solvers={self._vae.scope_name: self._vae_solver},
env_info=self._env_info,
config=trainer_config,
)
return encoder_trainer
def _setup_q_function_training(self, env_or_buffer):
trainer_config = MT.q_value.BCQQTrainerConfig(
reduction_method="mean", num_action_samples=self._config.num_action_samples, lmb=self._config.lmb
)
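        # Target q value: r + gamma * max over the num_action_samples candidate
        # actions of (lmb * min(Q) + (1 - lmb) * max(Q)) taken over the q ensemble.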
        # This is a wrapper class that outputs the perturbed target action for the next state in q-function training.
class PerturbedPolicy(DeterministicPolicy):
def __init__(self, vae, perturbator, phi):
self._vae = vae
self._perturbator = perturbator
self._phi = phi
def pi(self, s):
a = self._vae.decode(z=None, state=s)
noise = self._perturbator.generate_noise(s, a, phi=self._phi)
return a + noise
target_policy = PerturbedPolicy(self._vae, self._target_xi, self._config.phi)
q_function_trainer = MT.q_value.BCQQTrainer(
train_functions=self._q_ensembles,
solvers=self._q_solvers,
target_functions=self._target_q_ensembles,
target_policy=target_policy,
env_info=self._env_info,
config=trainer_config,
)
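        # Initialize the target networks as exact copies of the train networks (tau=1.0 is a hard copy).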
for q, target_q in zip(self._q_ensembles, self._target_q_ensembles):
sync_model(q, target_q, 1.0)
return q_function_trainer
def _setup_perturbator_training(self, env_or_buffer):
trainer_config = MT.perturbator_trainers.BCQPerturbatorTrainerConfig(phi=self._config.phi)
        perturbator_trainer = MT.perturbator_trainers.BCQPerturbatorTrainer(
models=self._xi,
solvers={self._xi.scope_name: self._xi_solver},
q_function=self._q_ensembles[0],
vae=self._vae,
env_info=self._env_info,
config=trainer_config,
)
sync_model(self._xi, self._target_xi, 1.0)
return perturbator_trainer
def _run_online_training_iteration(self, env):
raise NotImplementedError("BCQ does not support online training")
def _run_offline_training_iteration(self, buffer):
self._bcq_training(buffer)
def _bcq_training(self, replay_buffer):
experiences, info = replay_buffer.sample(self._config.batch_size)
(s, a, r, non_terminal, s_next, *_) = marshal_experiences(experiences)
batch = TrainingBatch(
batch_size=self._config.batch_size,
s_current=s,
a_current=a,
gamma=self._config.gamma,
reward=r,
non_terminal=non_terminal,
s_next=s_next,
weight=info["weights"],
)
        # Train the VAE that models the behavior policy's action distribution
self._encoder_trainer_state = self._encoder_trainer.train(batch)
self._q_function_trainer_state = self._q_function_trainer.train(batch)
for q, target_q in zip(self._q_ensembles, self._target_q_ensembles):
sync_model(q, target_q, tau=self._config.tau)
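        # Report td errors so that prioritized replay buffers can refresh their sample priorities.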
td_errors = self._q_function_trainer_state["td_errors"]
replay_buffer.update_priorities(td_errors)
        self._perturbator_trainer_state = self._perturbator_trainer.train(batch)
        sync_model(self._xi, self._target_xi, tau=self._config.tau)
def _models(self):
models = [*self._q_ensembles, *self._target_q_ensembles, self._vae, self._xi, self._target_xi]
return {model.scope_name: model for model in models}
def _solvers(self):
solvers = {}
solvers.update(self._q_solvers)
solvers[self._vae.scope_name] = self._vae_solver
solvers[self._xi.scope_name] = self._xi_solver
return solvers
    @classmethod
def is_supported_env(cls, env_or_env_info):
env_info = (
EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info
)
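        # BCQ is designed for continuous control: discrete and tuple action spaces are not supported.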
return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()
@property
def latest_iteration_state(self):
latest_iteration_state = super(BCQ, self).latest_iteration_state
if hasattr(self, "_encoder_trainer_state"):
latest_iteration_state["scalar"].update(
{"encoder_loss": float(self._encoder_trainer_state["encoder_loss"])}
)
if hasattr(self, "_perturbator_trainer_state"):
latest_iteration_state["scalar"].update(
{"perturbator_loss": float(self._perturbator_trainer_state["perturbator_loss"])}
)
if hasattr(self, "_q_function_trainer_state"):
latest_iteration_state["scalar"].update({"q_loss": float(self._q_function_trainer_state["q_loss"])})
latest_iteration_state["histogram"].update(
{"td_errors": self._q_function_trainer_state["td_errors"].flatten()}
)
return latest_iteration_state
@property
def trainers(self):
return {
"encoder": self._encoder_trainer,
"q_function": self._q_function_trainer,
"perturbator": self._perturbator_trainer,
}
if __name__ == "__main__":
import nnabla_rl.environments as E
env = E.DummyContinuous()
bcq = BCQ(env)
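    # Offline training would then proceed from a pre-filled replay buffer, e.g.:
    # bcq.train(buffer, total_iterations=10000)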