# Copyright 2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Tuple, Union, cast
import gym
import numpy as np
import nnabla as nn
import nnabla.solvers as NS
import nnabla_rl.environment_explorers as EE
import nnabla_rl.model_trainers as MT
from nnabla_rl.algorithm import Algorithm, AlgorithmConfig, eval_api
from nnabla_rl.algorithms.common_utils import (
_EpsilonGreedyOptionSelector,
_GreedyOptionSelector,
_RandomOptionSelector,
_StochasticIntraPolicyActionSelector,
)
from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, ReplayBufferBuilder, SolverBuilder
from nnabla_rl.environment_explorer import EnvironmentExplorer
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.model_trainers.model_trainer import ModelTrainer, TrainingBatch
from nnabla_rl.models import (
AtariOptionCriticIntraPolicy,
AtariOptionCriticOptionVFunction,
AtariOptionCriticTerminationFunction,
OptionCriticSharedFunctionHead,
OptionValueFunction,
StochasticIntraPolicy,
StochasticTerminationFunction,
)
from nnabla_rl.replay_buffer import ReplayBuffer
from nnabla_rl.typing import Experience
from nnabla_rl.utils import context
from nnabla_rl.utils.data import marshal_experiences
from nnabla_rl.utils.misc import sync_model
[docs]@dataclass
class OptionCriticConfig(AlgorithmConfig):
"""List of configurations for Option Critic Architecture algorithm.
Args:
gamma (float): discount factor of rewards. Defaults to 0.99.
intra_policy_learning_rate (float): learning rate which is set to intra policy solver. \
You can customize/override the learning rate for each solver by implementing the \
(:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`) by yourself. \
Defaults to 0.00025.
termination_function_learning_rate (float): learning rate which is set to termination function solver. \
You can customize/override the learning rate for each solver by implementing the \
(:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`) by yourself. \
Defaults to 0.00025.
option_v_function__learning_rate (float): learning rate which is set to option value function sulver. \
You can customize/override the learning rate for each solver by implementing the \
(:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`) by yourself. \
Defaults to 0.00025.
option_v_batch_size (int): training batch size of option value function. Defaults to 32.
termination_function_batch_size (int): training batch size of termination function function. Defaults to 1.
intra_policy_batch_size (int): training batch size of intra policy. Defaults to 1.
learner_update_frequency (int): the interval of learner update. Defaults to 4.
target_update_frequency (int): the interval of target q-function update. Defaults to 10000.
start_timesteps (int): the timestep when training starts.\
The algorithm will collect experiences from the environment by acting randomly until this timestep.
Defaults to 50000.
replay_buffer_size (int): the capacity of replay buffer. Defaults to 1000000.
max_explore_steps (int): the number of steps decaying the epsilon value.\
The epsilon will be decayed linearly \
:math:`\\epsilon=\\epsilon_{init} - step\\times\\frac{\\epsilon_{init} - \
\\epsilon_{final}}{max\\_explore\\_steps}`.\
Defaults to 1000000.
initial_epsilon (float): the initial epsilon value for ε-greedy explorer. Defaults to 1.0.
final_epsilon (float): the last epsilon value for ε-greedy explorer. Defaults to 0.1.
test_epsilon (float): the epsilon value on testing. Defaults to 0.05.
advantage_offset (float): advantage offset value for termination function learning. Defaults to 0.01.
entropy_regularizer_coefficient (float): scalar of entropy regularization term of intra policy learning. \
Defaults to 0.01.
use_baseline (bool): If True, subtracting the baseline value from the q value in intra policy learning. \
Defaults to True.
num_options (int): number of options. Defaults to 8.
option_v_loss_reduction_method (str): The reduction method for option v function loss. Defaults to 'sum'.
intra_policy_loss_reduction_method (str): The reduction method for intra policy loss. Defaults to 'mean'.
termination_function_loss_reduction_method (str): The reduction method for termination function loss. \
Defaults to 'mean'.
deterministic_termination_in_eval (bool): If true, terminates deterministically at evalution. Defaults to False.
deterministic_intra_action_in_eval (bool): If true, act deterministically at evalution. Defaults to False.
"""
gamma: float = 0.99
intra_policy_learning_rate: float = 2.5e-4
termination_function_learning_rate: float = 2.5e-4
option_v_function_learning_rate: float = 2.5e-4
option_v_batch_size: int = 32
termination_function_batch_size: int = 1
intra_policy_batch_size: int = 1
learner_update_frequency: float = 4
target_update_frequency: float = 10000
start_timesteps: int = 50000
replay_buffer_size: int = 1000000
max_option_explore_steps: int = 1000000
initial_option_epsilon: float = 1.0
final_option_epsilon: float = 0.1
test_option_epsilon: float = 0.05
advantage_offset: float = 0.01
entropy_regularizer_coefficient: float = 0.01
use_baseline: bool = True
num_options: int = 8
option_v_loss_reduction_method: str = "sum"
intra_policy_loss_reduction_method: str = "mean"
termination_function_loss_reduction_method: str = "mean"
deterministic_termination_in_eval: bool = False
deterministic_intra_action_in_eval: bool = False
def __post_init__(self):
"""__post_init__
Check set values are in valid range.
"""
self._assert_between(self.gamma, 0.0, 1.0, "gamma")
self._assert_positive(self.intra_policy_learning_rate, "intra_policy_learning_rate")
self._assert_positive(self.option_v_function_learning_rate, "option_v_function_learning_rate")
self._assert_positive(self.termination_function_learning_rate, "termination_function_learning_rate")
self._assert_positive(self.option_v_batch_size, "option_v_batch_size")
self._assert_positive(self.termination_function_batch_size, "termination_function_batch_size")
self._assert_positive(self.intra_policy_batch_size, "intra_policy_batch_size")
self._assert_positive(self.learner_update_frequency, "learner_update_frequency")
self._assert_positive(self.target_update_frequency, "target_update_frequency")
self._assert_positive(self.start_timesteps, "start_timesteps")
self._assert_positive(self.replay_buffer_size, "replay_buffer_size")
self._assert_smaller_than(self.start_timesteps, self.replay_buffer_size, "start_timesteps")
self._assert_between(self.initial_option_epsilon, 0.0, 1.0, "initial_option_epsilon")
self._assert_between(self.final_option_epsilon, 0.0, 1.0, "final_option_epsilon")
self._assert_between(self.test_option_epsilon, 0.0, 1.0, "test_option_epsilon")
self._assert_positive(self.max_option_explore_steps, "max_option_explore_steps")
self._assert_positive(self.num_options, "num_options")
self._assert_positive_or_zero(self.advantage_offset, "advantage_offset")
self._assert_positive_or_zero(self.entropy_regularizer_coefficient, "entropy_regularizer_coefficient")
self._assert_one_of(self.option_v_loss_reduction_method, ["sum", "mean"], "option_v_loss_reduction_method")
self._assert_one_of(
self.intra_policy_loss_reduction_method, ["sum", "mean"], "intra_policy_loss_reduction_method"
)
self._assert_one_of(
self.termination_function_loss_reduction_method,
["sum", "mean"],
"termination_function_loss_reduction_method",
)
class DefaultOptionValueFunctionBuilder(ModelBuilder[OptionValueFunction]):
def build_model( # type: ignore[override]
self,
scope_name: str,
env_info: EnvironmentInfo,
algorithm_config: OptionCriticConfig,
**kwargs,
) -> OptionValueFunction:
# scope name is same as that of termination function and intra policy
# -> parameter is shared across models automatically
_shared_function_head = OptionCriticSharedFunctionHead(
scope_name="shared", state_shape=env_info.state_shape, action_dim=env_info.action_dim
)
return AtariOptionCriticOptionVFunction(
scope_name="shared", # shared feature function should be updated via option value function loss only
head=_shared_function_head,
num_options=algorithm_config.num_options,
)
class DefaultIntraPolicyBuilder(ModelBuilder[StochasticIntraPolicy]):
def build_model( # type: ignore[override]
self,
scope_name: str,
env_info: EnvironmentInfo,
algorithm_config: OptionCriticConfig,
**kwargs,
) -> StochasticIntraPolicy:
assert scope_name != "shared"
# scope name is same as that of option v function and termination function
# -> parameter is shared across models automatically
_shared_function_head = OptionCriticSharedFunctionHead(
scope_name="shared", state_shape=env_info.state_shape, action_dim=env_info.action_dim
)
return AtariOptionCriticIntraPolicy(
scope_name=scope_name,
head=_shared_function_head,
num_options=algorithm_config.num_options,
action_dim=env_info.action_dim,
)
class DefaultTerminationFunctionBuilder(ModelBuilder[StochasticTerminationFunction]):
def build_model( # type: ignore[override]
self,
scope_name: str,
env_info: EnvironmentInfo,
algorithm_config: OptionCriticConfig,
**kwargs,
) -> StochasticTerminationFunction:
assert scope_name != "shared"
# scope name is same as that of option v function and intra policy
# -> parameter is shared across models automatically
_shared_function_head = OptionCriticSharedFunctionHead(
scope_name="shared", state_shape=env_info.state_shape, action_dim=env_info.action_dim
)
return AtariOptionCriticTerminationFunction(
scope_name=scope_name, head=_shared_function_head, num_options=algorithm_config.num_options
)
class DefaultIntraPolicySolverBuilder(SolverBuilder):
def build_solver( # type: ignore[override]
self, env_info: EnvironmentInfo, algorithm_config: OptionCriticConfig, **kwargs
) -> nn.solver.Solver:
solver = NS.Sgd(lr=algorithm_config.intra_policy_learning_rate)
return solver
class DefaultTerminationFunctionSolverBuilder(SolverBuilder):
def build_solver( # type: ignore[override]
self, env_info: EnvironmentInfo, algorithm_config: OptionCriticConfig, **kwargs
) -> nn.solver.Solver:
solver = NS.Sgd(lr=algorithm_config.termination_function_learning_rate)
return solver
class DefaultOptionVFunctionSolverBuilder(SolverBuilder):
def build_solver( # type: ignore[override]
self, env_info: EnvironmentInfo, algorithm_config: OptionCriticConfig, **kwargs
) -> nn.solver.Solver:
# this decay is equivalent to 'gradient momentum' and 'squared gradient momentum' of the nature paper
decay: float = 0.95
momentum: float = 0.0
min_squared_gradient: float = 0.01
solver = NS.RMSpropGraves(
lr=algorithm_config.option_v_function_learning_rate,
decay=decay,
momentum=momentum,
eps=min_squared_gradient,
)
return solver
class DefaultReplayBufferBuilder(ReplayBufferBuilder):
def build_replay_buffer( # type: ignore[override]
self, env_info: EnvironmentInfo, algorithm_config: OptionCriticConfig, **kwargs
) -> ReplayBuffer:
return ReplayBuffer(capacity=algorithm_config.replay_buffer_size)
class DefaultExplorerBuilder(ExplorerBuilder):
def build_explorer( # type: ignore[override]
self,
env_info: EnvironmentInfo,
algorithm_config: OptionCriticConfig,
algorithm: "OptionCritic",
**kwargs,
) -> EnvironmentExplorer:
explorer_config = EE.LinearDecayEpsilonGreedyOptionExplorerConfig(
warmup_random_steps=algorithm_config.start_timesteps,
timelimit_as_terminal=True,
initial_step_num=algorithm.iteration_num,
initial_option_epsilon=algorithm_config.initial_option_epsilon,
final_option_epsilon=algorithm_config.final_option_epsilon,
max_option_explore_steps=algorithm_config.max_option_explore_steps,
num_options=algorithm_config.num_options,
append_explorer_info=True,
)
explorer = EE.LinearDecayEpsilonGreedyOptionExplorer(
env_info=env_info,
config=explorer_config,
random_option_selector=algorithm._exploration_random_option_selector,
greedy_option_selector=algorithm._exploration_greedy_option_selector,
intra_action_selector=algorithm._intra_action_selector,
)
return explorer
[docs]class OptionCritic(Algorithm):
"""Option Critic algorithm.
This class implements the Option Critic Architecture algorithm
proposed by Pierre-Luc Bacon, et al. in the paper: "The Option-Critic Architecture"
For details see: https://arxiv.org/abs/1609.05140
Args:
env_or_env_info\
(gym.Env or :py:class:`EnvironmentInfo <nnabla_rl.environments.environment_info.EnvironmentInfo>`):
the environment to train or environment info
config (:py:class:`OptionCriticConfig <nnabla_rl.algorithms.option_critic.OptionCriticConfig>`):\
configuration of Option Critic algorithm
option_v_func_builder (:py:class:`ModelBuilder[OptionValueFunction] \
<nnabla_rl.builders.ModelBuilder>`): buider of option value function model
option_v_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
builder for option value function solver
intra_policy_builder (:py:class:`ModelBuilder[IntraPolicy] \
<nnabla_rl.builders.ModelBuilder>`): buider of intra policy function model
intra_policy_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
builder for option value function solver
termination_function_builder (:py:class:`ModelBuilder[TerminationFunction] \
<nnabla_rl.builders.ModelBuilder>`): buider of termination function model
termination_function_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
builder for termination function solver
replay_buffer_builder (:py:class:`ReplayBufferBuilder <nnabla_rl.builders.ReplayBufferBuilder>`):
builder of replay_buffer
explorer_builder (:py:class:`ExplorerBuilder <nnabla_rl.builders.ExplorerBuilder>`):
builder of environment explorer
"""
# type declarations to type check with mypy
# NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
# See https://mypy.readthedocs.io/en/stable/class_basics.html for details
_config: OptionCriticConfig
_option_v_function: OptionValueFunction
_option_v_function_solver: nn.solver.Solver
_target_option_v_function: OptionValueFunction
_intra_policy: StochasticIntraPolicy
_intra_policy_solver: nn.solver.Solver
_termination_function: StochasticTerminationFunction
_termination_function_solver: nn.solver.Solver
_option_v_replay_buffer: ReplayBuffer
_explorer_builder: ExplorerBuilder
_environment_explorer: EnvironmentExplorer
_option_v_function_trainer: ModelTrainer
_option_v_function_trainer_state: Dict[str, Any]
_termination_function_trainer: ModelTrainer
_termination_function_trainer_state: Dict[str, Any]
_intra_policy_trainer: ModelTrainer
_intra_policy_trainer_state: Dict[str, Any]
_evaluation_greedy_option_actor: _GreedyOptionSelector
_exploration_greedy_option_actor: _GreedyOptionSelector
_evaluation_random_option_actor: _RandomOptionSelector
_exploration_random_option_actor: _RandomOptionSelector
_intra_action_actor: _StochasticIntraPolicyActionSelector
_evaluation_option_actor: _EpsilonGreedyOptionSelector
def __init__(
self,
env_or_env_info: Union[gym.Env, EnvironmentInfo],
config: OptionCriticConfig = OptionCriticConfig(),
option_v_func_builder: ModelBuilder[OptionValueFunction] = DefaultOptionValueFunctionBuilder(),
option_v_solver_builder: SolverBuilder = DefaultOptionVFunctionSolverBuilder(),
intra_policy_builder: ModelBuilder[StochasticIntraPolicy] = DefaultIntraPolicyBuilder(),
intra_policy_solver_builder: SolverBuilder = DefaultIntraPolicySolverBuilder(),
termination_func_builder: ModelBuilder[StochasticTerminationFunction] = DefaultTerminationFunctionBuilder(),
termination_solver_builder: SolverBuilder = DefaultTerminationFunctionSolverBuilder(),
replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(),
explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(),
):
super(OptionCritic, self).__init__(env_or_env_info, config=config)
self._explorer_builder = explorer_builder
with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)):
self._option_v_function = option_v_func_builder(
scope_name="shared", env_info=self._env_info, algorithm_config=self._config
)
self._option_v_function_solver = option_v_solver_builder(
env_info=self._env_info, algorithm_config=self._config
)
self._target_option_v_function = self._option_v_function.deepcopy(
"target_" + self._option_v_function.scope_name
)
self._termination_function = termination_func_builder(
scope_name="termination_func", env_info=self._env_info, algorithm_config=self._config
)
self._termination_function_solver = termination_solver_builder(
env_info=self._env_info, algorithm_config=self._config
)
self._intra_policy = intra_policy_builder(
scope_name="intra_policy", env_info=self._env_info, algorithm_config=self._config
)
self._intra_policy_solver = intra_policy_solver_builder(
env_info=self._env_info, algorithm_config=self._config
)
self._option_v_replay_buffer = replay_buffer_builder(env_info=self._env_info, algorithm_config=self._config)
self._termination_function_replay_buffer = ReplayBuffer(
capacity=self._config.termination_function_batch_size
)
self._intra_policy_replay_buffer = ReplayBuffer(capacity=self._config.intra_policy_batch_size)
self._environment_explorer = explorer_builder(
env_info=self._env_info, algorithm_config=self._config, algorithm=self
)
self._evaluation_greedy_option_actor = _GreedyOptionSelector(
self._config.num_options,
self._env_info,
self._option_v_function,
self._termination_function,
deterministic_termination=self._config.deterministic_termination_in_eval,
)
self._exploration_greedy_option_actor = _GreedyOptionSelector(
self._config.num_options, self._env_info, self._option_v_function, self._termination_function
)
self._evaluation_random_option_actor = _RandomOptionSelector(
self._config.num_options,
self._env_info,
self._termination_function,
deterministic_termination=self._config.deterministic_termination_in_eval,
)
self._exploration_random_option_actor = _RandomOptionSelector(
self._config.num_options,
self._env_info,
self._termination_function,
)
self._intra_action_actor = _StochasticIntraPolicyActionSelector(
self._env_info, deterministic=self._config.deterministic_intra_action_in_eval, policy=self._intra_policy
)
self._evaluation_option_actor = _EpsilonGreedyOptionSelector(
greedy_option_selector=self._evaluation_greedy_option_selector,
random_option_selector=self._evaluation_random_option_selector,
epsilon=self._config.test_option_epsilon,
)
@eval_api
def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}):
with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)):
option, _ = self._evaluation_option_actor(state, begin_of_episode=begin_of_episode)
action, _ = self._intra_action_selector(state, option, begin_of_episode=begin_of_episode)
return action
def _before_training_start(self, env_or_buffer):
# set context globally to ensure that the training runs on configured gpu
context.set_nnabla_context(self._config.gpu_id)
self._environment_explorer = self._setup_environment_explorer(env_or_buffer)
self._option_v_function_trainer = self._setup_option_v_function_training(env_or_buffer)
self._intra_policy_function_trainer = self._setup_intra_policy_training(env_or_buffer)
self._termination_function_trainer = self._setup_termination_function_training(env_or_buffer)
def _setup_environment_explorer(self, env_or_buffer):
return self._explorer_builder(self._env_info, self._config, self)
def _setup_option_v_function_training(self, env_or_buffer):
trainer_config = MT.option_value_trainers.OptionCriticOptionValueTrainerConfig(
reduction_method=self._config.option_v_loss_reduction_method
)
option_v_function_trainer = MT.option_value_trainers.OptionCriticOptionValueTrainer(
train_functions=self._option_v_function,
solvers={self._option_v_function.scope_name: self._option_v_function_solver},
target_function=self._target_option_v_function,
env_info=self._env_info,
termination_functions=self._termination_function,
config=trainer_config,
)
sync_model(self._option_v_function, self._target_option_v_function)
return option_v_function_trainer
def _setup_intra_policy_training(self, env_or_buffer):
trainer_config = MT.intra_policy_trainers.OptionCriticIntraPolicyTrainerConfig(
entropy_coefficient=self._config.entropy_regularizer_coefficient,
reduction_method=self._config.intra_policy_loss_reduction_method,
)
intra_policy_trainer = MT.intra_policy_trainers.OptionCriticIntraPolicyTrainer(
models=self._intra_policy,
solvers={self._intra_policy.scope_name: self._intra_policy_solver},
env_info=self._env_info,
termination_functions=self._termination_function,
target_option_v_function=self._target_option_v_function,
option_v_functions=self._option_v_function,
config=trainer_config,
)
return intra_policy_trainer
def _setup_termination_function_training(self, env_or_buffer):
trainer_config = MT.termination_trainers.OptionCriticTerminationFunctionTrainerConfig(
advantage_offset=self._config.advantage_offset,
reduction_method=self._config.termination_function_loss_reduction_method,
)
termination_function_trainer = MT.termination_trainers.OptionCriticTerminationFunctionTrainer(
models=self._termination_function,
solvers={self._termination_function.scope_name: self._termination_function_solver},
env_info=self._env_info,
option_v_functions=self._option_v_function,
config=trainer_config,
)
return termination_function_trainer
def _run_online_training_iteration(self, env):
experiences = self._environment_explorer.step(env)
for e in experiences:
s, a, r, non_terminal, n_s, info = e
assert "option" in info
self._option_v_replay_buffer.append((s, a, r, non_terminal, n_s, info["option"]))
if self._config.start_timesteps < self.iteration_num:
self._intra_policy_replay_buffer.append((s, a, r, non_terminal, n_s, info["option"]))
self._termination_function_replay_buffer.append((s, a, r, non_terminal, n_s, info["option"]))
if self._config.start_timesteps < self.iteration_num:
if len(self._intra_policy_replay_buffer) >= self._config.intra_policy_batch_size:
self._intra_policy_training(self._intra_policy_replay_buffer)
# Clear buffer after training
self._intra_policy_replay_buffer = ReplayBuffer(capacity=self._config.intra_policy_batch_size)
if len(self._termination_function_replay_buffer) >= self._config.intra_policy_batch_size:
self._termination_training(self._termination_function_replay_buffer)
# Clear buffer after training
self._termination_function_replay_buffer = ReplayBuffer(
capacity=self._config.termination_function_batch_size
)
if self.iteration_num % self._config.learner_update_frequency == 0:
# off-policy training
self._option_v_training(self._option_v_replay_buffer)
def _run_offline_training_iteration(self, buffer):
raise NotImplementedError
def _evaluation_greedy_option_selector(self, s, option, *, begin_of_episode=False):
return self._evaluation_greedy_option_actor(s, option, begin_of_episode=begin_of_episode)
def _evaluation_random_option_selector(self, s, option, *, begin_of_episode=False):
return self._evaluation_random_option_actor(s, option, begin_of_episode=begin_of_episode)
def _exploration_greedy_option_selector(self, s, option, *, begin_of_episode=False):
return self._exploration_greedy_option_actor(s, option, begin_of_episode=begin_of_episode)
def _exploration_random_option_selector(self, s, option, *, begin_of_episode=False):
return self._exploration_random_option_actor(s, option, begin_of_episode=begin_of_episode)
def _intra_action_selector(self, s, option, *, begin_of_episode=False):
return self._intra_action_actor(s, option, begin_of_episode=begin_of_episode)
def _option_v_training(self, replay_buffer: ReplayBuffer):
experiences, _ = replay_buffer.sample(self._config.option_v_batch_size)
experiences = cast(Tuple[Experience], experiences)
s, a, r, non_terminal, s_next, option = marshal_experiences(experiences)
batch = TrainingBatch(
batch_size=self._config.option_v_batch_size,
s_current=s,
a_current=a,
gamma=self._config.gamma,
reward=r,
non_terminal=non_terminal,
s_next=s_next,
extra={"option": option},
)
self._option_v_function_trainer_state = self._option_v_function_trainer.train(batch)
if self.iteration_num % self._config.target_update_frequency == 0:
sync_model(self._option_v_function, self._target_option_v_function)
def _termination_training(self, replay_buffer: ReplayBuffer):
experiences, _ = replay_buffer.sample_indices(np.arange(self._config.termination_function_batch_size).tolist())
experiences = cast(Tuple[Experience], experiences)
s, a, r, non_terminal, s_next, option = marshal_experiences(experiences)
batch = TrainingBatch(
batch_size=len(experiences),
s_current=s,
a_current=a,
gamma=self._config.gamma,
reward=r,
non_terminal=non_terminal,
s_next=s_next,
extra={"option": option},
)
self._termination_function_trainer_state = self._termination_function_trainer.train(batch)
def _intra_policy_training(self, replay_buffer: ReplayBuffer):
experiences, _ = replay_buffer.sample_indices(np.arange(self._config.intra_policy_batch_size).tolist())
experiences = cast(Tuple[Experience], experiences)
s, a, r, non_terminal, s_next, option = marshal_experiences(experiences)
batch = TrainingBatch(
batch_size=len(experiences),
s_current=s,
a_current=a,
gamma=self._config.gamma,
reward=r,
non_terminal=non_terminal,
s_next=s_next,
extra={"option": option},
)
self._intra_policy_trainer_state = self._intra_policy_function_trainer.train(batch)
def _models(self):
models = {}
models[self._option_v_function.scope_name] = self._option_v_function
models[self._termination_function.scope_name] = self._termination_function
models[self._intra_policy.scope_name] = self._intra_policy
return models
def _solvers(self):
solvers = {}
solvers[self._option_v_function.scope_name] = self._option_v_function_solver
solvers[self._termination_function.scope_name] = self._termination_function_solver
solvers[self._intra_policy.scope_name] = self._intra_policy_solver
return solvers
[docs] @classmethod
def is_supported_env(cls, env_or_env_info):
env_info = (
EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info
)
return not env_info.is_continuous_action_env() and not env_info.is_tuple_action_env()
[docs] @classmethod
def is_rnn_supported(self):
return False
@property
def latest_iteration_state(self):
latest_iteration_state = super(OptionCritic, self).latest_iteration_state
if hasattr(self, "_option_v_function_trainer_state"):
latest_iteration_state["scalar"].update(
{"option_v_loss": float(self._option_v_function_trainer_state["option_v_loss"])}
)
if hasattr(self, "_termination_function_trainer_state"):
latest_iteration_state["scalar"].update(
{"termination_loss": float(self._termination_function_trainer_state["termination_loss"])}
)
if hasattr(self, "_intra_policy_trainer_state"):
latest_iteration_state["scalar"].update(
{"intra_pi_loss": float(self._intra_policy_trainer_state["intra_pi_loss"])}
)
return latest_iteration_state
@property
def trainers(self):
return {
"option_v_function": self._option_v_function_trainer,
"intra_policy": self._intra_policy_trainer,
"termination_function": self._termination_function_trainer,
}