# Copyright 2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import gym
import numpy as np
import nnabla as nn
import nnabla.functions as NF
import nnabla.solvers as NS
import nnabla_rl.functions as RF
import nnabla_rl.model_trainers as MT
from nnabla_rl.algorithms.common_utils import _ActionSelector
from nnabla_rl.algorithms.td3 import TD3, DefaultSolverBuilder, TD3Config
from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, ReplayBufferBuilder, SolverBuilder
from nnabla_rl.environment_explorer import EnvironmentExplorer, EnvironmentExplorerConfig
from nnabla_rl.environment_explorers import RawPolicyExplorer, RawPolicyExplorerConfig
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.model_trainers.model_trainer import TrainingBatch
from nnabla_rl.models import DeterministicPolicy, HyARPolicy, HyARQFunction, HyARVAE, QFunction
from nnabla_rl.replay_buffer import ReplayBuffer
from nnabla_rl.replay_buffers import ReplacementSamplingReplayBuffer
from nnabla_rl.utils import context
from nnabla_rl.utils.data import marshal_experiences, set_data_to_variable
from nnabla_rl.utils.misc import sync_model
from nnabla_rl.utils.solver_wrappers import AutoClipGradByNorm
@dataclass
class HyARConfig(TD3Config):
"""HyARConfig List of configurations for HyAR algorithm.
Args:
gamma (float): discount factor of rewards. Defaults to 0.99.
learning_rate (float): learning rate which is set to all solvers. \
You can customize/override the learning rate for each solver by implementing the \
(:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`) by yourself. \
Defaults to 0.003.
        batch_size (int): training batch size. Defaults to 100.
tau (float): target network's parameter update coefficient. Defaults to 0.005.
start_timesteps (int): the timestep when training starts.\
The algorithm will collect experiences from the environment by acting randomly until this timestep.\
Defaults to 10000.
replay_buffer_size (int): capacity of the replay buffer. Defaults to 1000000.
d (int): Interval of the policy update. The policy will be updated every d q-function updates. Defaults to 2.
exploration_noise_sigma (float): Standard deviation of the gaussian exploration noise. Defaults to 0.1.
        train_action_noise_sigma (float): Standard deviation of the gaussian action noise used in the training.\
            Defaults to 0.1.
train_action_noise_abs (float): Absolute limit value of action noise used in the training. Defaults to 0.5.
        noisy_action_max (float): Maximum value of the training action after adding the noise. Defaults to 1.0.
        noisy_action_min (float): Minimum value of the training action after adding the noise. Defaults to -1.0.
num_steps (int): number of steps for N-step Q targets. Defaults to 1.
        actor_unroll_steps (int): Number of steps to unroll actor's training network.\
            The network will be unrolled even if the provided model doesn't have RNN layers.\
Defaults to 1.
        actor_burn_in_steps (int): Number of burn-in steps to initialize actor's recurrent layer states during training.\
This flag does not take effect if given model is not an RNN model.\
Defaults to 0.
actor_reset_rnn_on_terminal (bool): Reset actor's recurrent internal states to zero during training\
if episode ends. This flag does not take effect if given model is not an RNN model.\
Defaults to False.
        critic_unroll_steps (int): Number of steps to unroll critic's training network.\
            The network will be unrolled even if the provided model doesn't have RNN layers.\
Defaults to 1.
        critic_burn_in_steps (int): Number of burn-in steps to initialize critic's recurrent layer states\
during training. This flag does not take effect if given model is not an RNN model.\
Defaults to 0.
critic_reset_rnn_on_terminal (bool): Reset critic's recurrent internal states to zero during training\
if episode ends. This flag does not take effect if given model is not an RNN model.\
Defaults to False.
        latent_dim (int): Latent action's dimension. Defaults to 6.
        embed_dim (int): Discrete action embedding's dimension. Defaults to 6.
        T (int): VAE training interval. VAE is trained every T episodes. Defaults to 10.
        vae_pretrain_episodes (int): Number of data collection episodes for VAE pretraining.\
            Defaults to 20000.
        vae_pretrain_batch_size (int): Batch size used in VAE pretraining. Defaults to 64.
        vae_pretrain_times (int): VAE is updated for this number of iterations during the pretrain stage.\
            Defaults to 5000.
        vae_training_batch_size (int): Batch size used in VAE training. Defaults to 64.
        vae_training_times (int): VAE is updated for this number of iterations every T episodes. Defaults to 1.
        vae_learning_rate (float): VAE learning rate. Defaults to 1e-4.
        vae_buffer_size (int): Replay buffer size for the VAE model. Defaults to 2000000.
        latent_select_batch_size (int): Batch size for computing the latent space constraint (LSC). Defaults to 5000.
        latent_select_range (float): Percentage of the latent variables kept in the central range. Defaults to 96.
        noise_decay_steps (int): Exploration noise decay steps. The noise decays over this number of\
            experienced episodes. Defaults to 1000.
initial_exploration_noise (float): Initial standard deviation of exploration noise. Defaults to 1.0.
final_exploration_noise (float): Final standard deviation of exploration noise. Defaults to 0.1.
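
    Example:
        A minimal sketch of overriding a few of the defaults listed above
        (all names below are actual fields of this config)::

            config = HyARConfig(batch_size=128, T=5, noise_decay_steps=500)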
"""
train_action_noise_sigma: float = 0.1
train_action_noise_abs: float = 0.5
noisy_action_min: float = -1.0
    noisy_action_max: float = 1.0
latent_dim: int = 6
embed_dim: int = 6
T: int = 10
vae_pretrain_episodes: int = 20000
vae_pretrain_batch_size: int = 64
vae_pretrain_times: int = 5000
vae_training_batch_size: int = 64
vae_training_times: int = 1
vae_learning_rate: float = 1e-4
vae_buffer_size: int = int(2e6)
latent_select_batch_size: int = 5000
latent_select_range: float = 96.0
noise_decay_steps: int = 1000
initial_exploration_noise: float = 1.0
final_exploration_noise: float = 0.1
def __post_init__(self):
self._assert_positive(self.latent_dim, "latent_dim")
self._assert_positive(self.embed_dim, "embed_dim")
self._assert_positive(self.T, "T")
self._assert_positive_or_zero(self.vae_pretrain_episodes, "vae_pretrain_episodes")
self._assert_positive(self.vae_pretrain_batch_size, "vae_pretrain_batch_size")
self._assert_positive_or_zero(self.vae_pretrain_times, "vae_pretrain_times")
self._assert_positive(self.vae_training_batch_size, "vae_training_batch_size")
self._assert_positive_or_zero(self.vae_training_times, "vae_training_times")
self._assert_positive_or_zero(self.vae_learning_rate, "vae_learning_rate")
self._assert_positive(self.vae_buffer_size, "vae_buffer_size")
self._assert_positive(self.latent_select_batch_size, "latent_select_batch_size")
self._assert_between(self.latent_select_range, 0, 100, "latent_select_range")
self._assert_positive(self.noise_decay_steps, "noise_decay_steps")
self._assert_positive(self.initial_exploration_noise, "initial_exploration_noise")
self._assert_positive(self.final_exploration_noise, "final_exploration_noise")
return super().__post_init__()
class DefaultCriticBuilder(ModelBuilder[QFunction]):
def build_model( # type: ignore[override]
self,
scope_name: str,
env_info: EnvironmentInfo,
algorithm_config: HyARConfig,
**kwargs,
) -> QFunction:
return HyARQFunction(scope_name)
class DefaultActorBuilder(ModelBuilder[DeterministicPolicy]):
def build_model( # type: ignore[override]
self,
scope_name: str,
env_info: EnvironmentInfo,
algorithm_config: HyARConfig,
**kwargs,
) -> DeterministicPolicy:
max_action_value = 1.0
action_dim = algorithm_config.latent_dim + algorithm_config.embed_dim
return HyARPolicy(scope_name, action_dim, max_action_value=max_action_value)
class DefaultVAEBuilder(ModelBuilder[HyARVAE]):
def build_model( # type: ignore[override]
self,
scope_name: str,
env_info: EnvironmentInfo,
algorithm_config: HyARConfig,
**kwargs,
) -> HyARVAE:
return HyARVAE(
scope_name,
state_dim=env_info.state_dim,
action_dim=env_info.action_dim,
encode_dim=algorithm_config.latent_dim,
embed_dim=algorithm_config.embed_dim,
)
class DefaultActorSolverBuilder(SolverBuilder):
def build_solver( # type: ignore[override]
self, env_info: EnvironmentInfo, algorithm_config: HyARConfig, **kwargs
) -> nn.solver.Solver:
solver = NS.Adam(alpha=algorithm_config.learning_rate)
return AutoClipGradByNorm(solver, 10.0)
class DefaultVAESolverBuilder(SolverBuilder):
def build_solver( # type: ignore[override]
self, env_info: EnvironmentInfo, algorithm_config: HyARConfig, **kwargs
) -> nn.solver.Solver:
return NS.Adam(alpha=algorithm_config.vae_learning_rate)
class DefaultExplorerBuilder(ExplorerBuilder):
def build_explorer( # type: ignore[override]
self,
env_info: EnvironmentInfo,
algorithm_config: HyARConfig,
algorithm: "HyAR",
**kwargs,
) -> EnvironmentExplorer:
explorer_config = HyARPolicyExplorerConfig(
warmup_random_steps=0, initial_step_num=algorithm.iteration_num, timelimit_as_terminal=False
)
explorer = HyARPolicyExplorer(
policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config
)
return explorer
class DefaultPretrainExplorerBuilder(ExplorerBuilder):
def build_explorer( # type: ignore[override]
self,
env_info: EnvironmentInfo,
algorithm_config: HyARConfig,
algorithm: "HyAR",
**kwargs,
) -> EnvironmentExplorer:
explorer_config = HyARPretrainExplorerConfig(
warmup_random_steps=0, initial_step_num=algorithm.iteration_num, timelimit_as_terminal=False
)
explorer = HyARPretrainExplorer(env_info=env_info, config=explorer_config)
return explorer
class DefaultReplayBufferBuilder(ReplayBufferBuilder):
def build_replay_buffer( # type: ignore[override]
self, env_info: EnvironmentInfo, algorithm_config: HyARConfig, **kwargs
) -> ReplayBuffer:
return ReplacementSamplingReplayBuffer(capacity=algorithm_config.replay_buffer_size)
class DefaultVAEBufferBuilder(ReplayBufferBuilder):
def build_replay_buffer( # type: ignore[override]
self, env_info: EnvironmentInfo, algorithm_config: HyARConfig, **kwargs
) -> ReplayBuffer:
return ReplacementSamplingReplayBuffer(capacity=algorithm_config.vae_buffer_size)
class HyAR(TD3):
"""HyAR algorithm.
This class implements the Hybrid Action Representation (HyAR) algorithm
proposed by Boyan Li, et al.
in the paper: "HyAR: Addressing Discrete-Continuous Action Reinforcement Learning via Hybrid Action Representation"
For details see: https://openreview.net/pdf?id=64trBbOhdGU
Args:
env_or_env_info\
(gym.Env or :py:class:`EnvironmentInfo <nnabla_rl.environments.environment_info.EnvironmentInfo>`):
the environment to train or environment info
        config (:py:class:`HyARConfig <nnabla_rl.algorithms.hyar.HyARConfig>`):
            the parameter for HyAR training
        critic_builder (:py:class:`ModelBuilder <nnabla_rl.builders.ModelBuilder>`): builder of q function model
critic_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`):
builder of q function solver
        actor_builder (:py:class:`ModelBuilder <nnabla_rl.builders.ModelBuilder>`): builder of policy model
actor_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`): builder of policy solver
vae_builder (:py:class:`ModelBuilder <nnabla_rl.builders.ModelBuilder>`): builder of vae model
vae_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`): builder of vae solver
replay_buffer_builder (:py:class:`ReplayBufferBuilder <nnabla_rl.builders.ReplayBufferBuilder>`):
builder of q-function and policy replay_buffer
vae_buffer_builder (:py:class:`ReplayBufferBuilder <nnabla_rl.builders.ReplayBufferBuilder>`):
builder of vae's replay_buffer
explorer_builder (:py:class:`ExplorerBuilder <nnabla_rl.builders.ExplorerBuilder>`):
builder of environment explorer for main training stage
pretrain_explorer_builder (:py:class:`ExplorerBuilder <nnabla_rl.builders.ExplorerBuilder>`):
builder of environment explorer for pretraining stage
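
    Example:
        A minimal usage sketch, assuming this module is importable as
        ``nnabla_rl.algorithms.hyar`` and that ``"Goal-v0"`` (a placeholder id) is a
        registered environment with a tuple (discrete, continuous) action space::

            import gym

            from nnabla_rl.algorithms.hyar import HyAR, HyARConfig

            env = gym.make("Goal-v0")  # placeholder: use your own hybrid action environment
            config = HyARConfig(start_timesteps=1000, vae_pretrain_episodes=100)
            hyar = HyAR(env, config=config)
            hyar.train_online(env, total_iterations=10000)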
"""
# type declarations to type check with mypy
    # NOTE: declared variables are instance variables and NOT class variables, unless marked with ClassVar
# See https://mypy.readthedocs.io/en/stable/class_basics.html for details
_config: HyARConfig
_evaluation_actor: "_HyARPolicyActionSelector" # type: ignore
_exploration_actor: "_HyARPolicyActionSelector" # type: ignore
def __init__(
self,
env_or_env_info,
config: HyARConfig = HyARConfig(),
critic_builder=DefaultCriticBuilder(),
critic_solver_builder=DefaultSolverBuilder(),
actor_builder=DefaultActorBuilder(),
actor_solver_builder=DefaultActorSolverBuilder(),
vae_builder=DefaultVAEBuilder(),
        vae_solver_builder=DefaultVAESolverBuilder(),
replay_buffer_builder=DefaultReplayBufferBuilder(),
vae_buffer_builder=DefaultVAEBufferBuilder(),
explorer_builder=DefaultExplorerBuilder(),
pretrain_explorer_builder=DefaultPretrainExplorerBuilder(),
):
super().__init__(
env_or_env_info,
config,
critic_builder,
critic_solver_builder,
actor_builder,
actor_solver_builder,
replay_buffer_builder,
explorer_builder,
)
with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)):
self._vae = vae_builder("vae", self._env_info, self._config)
            self._vae_solver = vae_solver_builder(self._env_info, self._config)
            # We use a separate replay buffer for the VAE
self._vae_replay_buffer = vae_buffer_builder(env_info=self._env_info, algorithm_config=self._config)
self._pretrain_explorer_builder = pretrain_explorer_builder
self._evaluation_actor = _HyARPolicyActionSelector(
self._env_info,
self._pi.shallowcopy(),
self._vae.shallowcopy(),
embed_dim=self._config.embed_dim,
latent_dim=self._config.latent_dim,
)
self._exploration_actor = _HyARPolicyActionSelector(
self._env_info,
self._pi.shallowcopy(),
self._vae.shallowcopy(),
embed_dim=self._config.embed_dim,
latent_dim=self._config.latent_dim,
append_noise=True,
sigma=self._config.exploration_noise_sigma,
action_clip_low=-1.0,
action_clip_high=1.0,
)
self._episode_number = 1
self._experienced_episodes = 0
def _before_training_start(self, env_or_buffer):
super()._before_training_start(env_or_buffer)
self._vae_trainer = self._setup_vae_training(env_or_buffer)
self._pretrain_explorer = self._setup_pretrain_explorer(env_or_buffer)
if isinstance(env_or_buffer, gym.Env):
self._pretrain_vae(env_or_buffer)
def _setup_q_function_training(self, env_or_buffer):
# training input/loss variables
q_function_trainer_config = MT.q_value_trainers.HyARQTrainerConfig(
reduction_method="mean",
q_loss_scalar=1.0,
grad_clip=None,
train_action_noise_sigma=self._config.train_action_noise_sigma,
train_action_noise_abs=self._config.train_action_noise_abs,
noisy_action_max=self._config.noisy_action_max,
noisy_action_min=self._config.noisy_action_min,
num_steps=self._config.num_steps,
unroll_steps=self._config.critic_unroll_steps,
burn_in_steps=self._config.critic_burn_in_steps,
reset_on_terminal=self._config.critic_reset_rnn_on_terminal,
embed_dim=self._config.embed_dim,
latent_dim=self._config.latent_dim,
)
q_function_trainer = MT.q_value_trainers.HyARQTrainer(
train_functions=self._train_q_functions,
solvers=self._train_q_solvers,
target_functions=self._target_q_functions,
target_policy=self._target_pi,
vae=self._vae,
env_info=self._env_info,
config=q_function_trainer_config,
)
for q, target_q in zip(self._train_q_functions, self._target_q_functions):
sync_model(q, target_q)
return q_function_trainer
def _setup_policy_training(self, env_or_buffer):
action_dim = self._config.latent_dim + self._config.embed_dim
policy_trainer_config = MT.policy_trainers.HyARPolicyTrainerConfig(
unroll_steps=self._config.actor_unroll_steps,
burn_in_steps=self._config.actor_burn_in_steps,
reset_on_terminal=self._config.actor_reset_rnn_on_terminal,
p_max=np.ones(shape=(1, action_dim)),
p_min=-np.ones(shape=(1, action_dim)),
)
policy_trainer = MT.policy_trainers.HyARPolicyTrainer(
models=self._pi,
solvers={self._pi.scope_name: self._pi_solver},
q_function=self._q1,
env_info=self._env_info,
config=policy_trainer_config,
)
sync_model(self._pi, self._target_pi, 1.0)
return policy_trainer
def _setup_vae_training(self, env_or_buffer):
vae_trainer_config = MT.encoder_trainers.HyARVAETrainerConfig(
unroll_steps=self._config.critic_unroll_steps,
burn_in_steps=self._config.critic_burn_in_steps,
reset_on_terminal=self._config.critic_reset_rnn_on_terminal,
)
return MT.encoder_trainers.HyARVAETrainer(
self._vae, {self._vae.scope_name: self._vae_solver}, self._env_info, vae_trainer_config
)
def _setup_pretrain_explorer(self, env_or_buffer):
return (
None
if self._is_buffer(env_or_buffer)
else self._pretrain_explorer_builder(self._env_info, self._config, self)
)
def _pretrain_vae(self, env: gym.Env):
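        # Collect `vae_pretrain_episodes` episodes with the random pretrain explorer, train the
        # VAE on them, then compute the initial latent space constraint (c_rate), which the
        # action selectors use to rescale latent actions, and the dynamics error (ds_rate).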
for _ in range(self._config.vae_pretrain_episodes):
experiences = self._pretrain_explorer.rollout(env)
self._vae_replay_buffer.append_all(experiences)
for _ in range(self._config.vae_pretrain_times):
self._vae_training(self._vae_replay_buffer, self._config.vae_pretrain_batch_size)
c_rate, ds_rate = self._compute_reconstruction_rate(self._vae_replay_buffer)
self._c_rate = c_rate
self._ds_rate = ds_rate
self._exploration_actor.update_c_rate(c_rate)
self._evaluation_actor.update_c_rate(c_rate)
def _run_online_training_iteration(self, env):
experiences = self._environment_explorer.step(env)
self._replay_buffer.append_all(experiences)
self._vae_replay_buffer.append_all(experiences)
(_, _, _, non_terminal, *_) = experiences[-1]
end_of_episode = non_terminal == 0.0
if end_of_episode:
self._experienced_episodes += 1
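            # Linearly decay the exploration noise sigma from initial_exploration_noise to
            # final_exploration_noise over the first noise_decay_steps episodes.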
if self._experienced_episodes < self._config.noise_decay_steps:
ratio = self._experienced_episodes / self._config.noise_decay_steps
new_sigma = (
self._config.initial_exploration_noise * (1.0 - ratio)
+ self._config.final_exploration_noise * ratio
)
self._exploration_actor.update_sigma(sigma=new_sigma)
else:
self._exploration_actor.update_sigma(sigma=self._config.final_exploration_noise)
if self._config.start_timesteps < self.iteration_num:
self._hyar_training(self._replay_buffer, self._vae_replay_buffer, end_of_episode)
def _run_offline_training_iteration(self, buffer):
raise NotImplementedError
def _hyar_training(self, replay_buffer, vae_replay_buffer, end_of_episode=False):
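        # TD3-style RL update every training step. The VAE is additionally retrained (and the
        # latent space constraint recomputed) at the end of every T-th episode, once more than
        # 1000 iterations have elapsed.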
self._rl_training(replay_buffer)
if (self._experienced_episodes % self._config.T) == 0 and self._iteration_num > 1000 and end_of_episode:
for _ in range(self._config.vae_training_times):
self._vae_training(vae_replay_buffer, self._config.vae_training_batch_size)
c_rate, ds_rate = self._compute_reconstruction_rate(self._vae_replay_buffer)
self._c_rate = c_rate
self._ds_rate = ds_rate
self._exploration_actor.update_c_rate(c_rate)
self._evaluation_actor.update_c_rate(c_rate)
def _rl_training(self, replay_buffer):
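        # Sample a chain of consecutive experiences long enough to cover both the actor's and the
        # critic's burn-in/unroll steps (and N-step targets), then build the TrainingBatch chain
        # from the last step backwards.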
actor_steps = self._config.actor_burn_in_steps + self._config.actor_unroll_steps
critic_steps = self._config.num_steps + self._config.critic_burn_in_steps + self._config.critic_unroll_steps - 1
num_steps = max(actor_steps, critic_steps)
experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps)
if num_steps == 1:
experiences_tuple = (experiences_tuple,)
assert len(experiences_tuple) == num_steps
batch = None
for experiences in reversed(experiences_tuple):
(s, a, r, non_terminal, s_next, extra, *_) = marshal_experiences(experiences)
rnn_states = extra["rnn_states"] if "rnn_states" in extra else {}
extra.update({"c_rate": self._c_rate, "ds_rate": self._ds_rate})
batch = TrainingBatch(
batch_size=self._config.batch_size,
s_current=s,
a_current=a,
gamma=self._config.gamma,
reward=r,
non_terminal=non_terminal,
s_next=s_next,
extra=extra,
weight=info["weights"],
next_step_batch=batch,
rnn_states=rnn_states,
)
self._q_function_trainer_state = self._q_function_trainer.train(batch)
td_errors = self._q_function_trainer_state["td_errors"]
replay_buffer.update_priorities(td_errors)
if self.iteration_num % self._config.d == 0:
# Optimize actor
self._policy_trainer_state = self._policy_trainer.train(batch)
# parameter update
for q, target_q in zip(self._train_q_functions, self._target_q_functions):
sync_model(q, target_q, tau=self._config.tau)
sync_model(self._pi, self._target_pi, tau=self._config.tau)
def _vae_training(self, replay_buffer, batch_size):
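        # Same multi-step batch construction as in _rl_training, but the batch is used only to
        # update the VAE.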
actor_steps = self._config.actor_burn_in_steps + self._config.actor_unroll_steps
critic_steps = self._config.num_steps + self._config.critic_burn_in_steps + self._config.critic_unroll_steps - 1
num_steps = max(actor_steps, critic_steps)
experiences_tuple, info = replay_buffer.sample(batch_size, num_steps=num_steps)
if num_steps == 1:
experiences_tuple = (experiences_tuple,)
assert len(experiences_tuple) == num_steps
batch = None
for experiences in reversed(experiences_tuple):
(s, a, r, non_terminal, s_next, extra, *_) = marshal_experiences(experiences)
rnn_states = extra["rnn_states"] if "rnn_states" in extra else {}
batch = TrainingBatch(
batch_size=batch_size,
s_current=s,
a_current=a,
gamma=self._config.gamma,
reward=r,
non_terminal=non_terminal,
s_next=s_next,
extra=extra,
weight=info["weights"],
next_step_batch=batch,
rnn_states=rnn_states,
)
self._vae_trainer_state = self._vae_trainer.train(batch)
def _models(self):
models = super()._models()
models.update({self._vae.scope_name: self._vae})
return models
def _solvers(self):
solvers = super()._solvers()
solvers.update({self._vae.scope_name: self._vae_solver})
return solvers
def _compute_reconstruction_rate(self, replay_buffer):
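        # Latent space constraint (LSC): encode a batch of experiences into the latent space,
        # sort each latent dimension, and take the values at roughly the (100 - latent_select_range)
        # and latent_select_range percentiles as the per-dimension bounds (z_down, z_up).
        # ds_rate is the mean squared error between the VAE's predicted and the actual state delta.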
range_rate = 100 - self._config.latent_select_range
batch_size = self._config.latent_select_batch_size
border = int(range_rate * (batch_size / 100))
experiences, _ = replay_buffer.sample(num_samples=batch_size)
(s, a, _, _, s_next, *_) = marshal_experiences(experiences)
if not hasattr(self, "_rate_state_var"):
from nnabla_rl.utils.misc import create_variable
self._rate_state_var = create_variable(batch_size, self._env_info.state_shape)
self._rate_action_var = create_variable(batch_size, self._env_info.action_shape)
self._rate_next_state_var = create_variable(batch_size, self._env_info.state_shape)
action1, action2 = self._rate_action_var
x = action1 if isinstance(self._env_info.action_space[0], gym.spaces.Box) else action2
latent_distribution, (_, predicted_ds) = self._vae.encode_and_decode(
x=x, state=self._rate_state_var, action=self._rate_action_var
)
z = latent_distribution.sample()
# NOTE: ascending order
z_sorted = NF.sort(z, axis=0)
z_up = z_sorted[batch_size - border - 1, :]
z_down = z_sorted[border, :]
z_up.persistent = True
z_down.persistent = True
ds = self._rate_next_state_var - self._rate_state_var
ds_rate = RF.mean_squared_error(ds, predicted_ds)
ds_rate.persistent = True
self._ds_rate_var = ds_rate
self._z_up_var = z_up
self._z_down_var = z_down
set_data_to_variable(self._rate_state_var, s)
set_data_to_variable(self._rate_action_var, a)
set_data_to_variable(self._rate_next_state_var, s_next)
nn.forward_all((self._z_up_var, self._z_down_var, self._ds_rate_var), clear_no_need_grad=True)
return (self._z_up_var.d, self._z_down_var.d), self._ds_rate_var.d
    @classmethod
def is_supported_env(cls, env_or_env_info):
env_info = (
EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info
)
return env_info.is_tuple_action_env() and not env_info.is_tuple_state_env()
    @classmethod
    def is_rnn_supported(cls):
return False
@property
def latest_iteration_state(self):
latest_iteration_state = super().latest_iteration_state
if hasattr(self, "_vae_trainer_state"):
latest_iteration_state["scalar"].update(
{
"encoder_loss": float(self._vae_trainer_state["encoder_loss"]),
"kl_loss": float(self._vae_trainer_state["kl_loss"]),
"reconstruction_loss": float(self._vae_trainer_state["reconstruction_loss"]),
"dyn_loss": float(self._vae_trainer_state["dyn_loss"]),
}
)
return latest_iteration_state
class _HyARPolicyActionSelector(_ActionSelector[DeterministicPolicy]):
_vae: HyARVAE
def __init__(
self,
env_info: EnvironmentInfo,
model: DeterministicPolicy,
vae: HyARVAE,
embed_dim: int,
latent_dim: int,
append_noise: bool = False,
action_clip_low: float = np.finfo(np.float32).min, # type: ignore
action_clip_high: float = np.finfo(np.float32).max, # type: ignore
sigma: float = 1.0,
):
super().__init__(env_info, model)
self._vae = vae
self._embed_dim = embed_dim
self._latent_dim = latent_dim
self._e: nn.Variable = None
self._z: nn.Variable = None
self._append_noise = append_noise
self._action_clip_low = action_clip_low
self._action_clip_high = action_clip_high
self._sigma = nn.Variable.from_numpy_array(sigma * np.ones(shape=(1, 1)))
        # These initial values follow the author's code and are used to rescale the latent action
z_up = nn.Variable.from_numpy_array(np.ones(shape=(1, self._latent_dim)))
z_down = nn.Variable.from_numpy_array(-np.ones(shape=(1, self._latent_dim)))
self._c_rate = (z_up, z_down)
def __call__(self, s, *, begin_of_episode=False, extra_info={}):
action, info = super().__call__(s, begin_of_episode=begin_of_episode, extra_info=extra_info)
# Use only the first item in the batch
# self._e.d[0] and self._z.d[0]
e = self._e.d[0]
z = self._z.d[0]
info.update({"e": e, "z": z})
(d_action, c_action) = action
return (d_action, c_action), info
def update_sigma(self, sigma):
self._sigma.d = sigma
def update_c_rate(self, c_rate):
self._c_rate[0].d = c_rate[0]
self._c_rate[1].d = c_rate[1]
def _compute_action(self, state_var: nn.Variable) -> nn.Variable:
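        # The policy outputs a latent action: the first embed_dim entries are the discrete-action
        # embedding (e) and the remaining latent_dim entries are the continuous latent (z).
        # Both parts are decoded by the VAE into the environment's (discrete, continuous) action.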
latent_action = self._model.pi(state_var)
if self._append_noise:
noise = NF.randn(shape=latent_action.shape)
latent_action = latent_action + noise * self._sigma
latent_action = NF.clip_by_value(latent_action, min=self._action_clip_low, max=self._action_clip_high)
self._e = latent_action[:, : self._embed_dim]
self._e.persistent = True
self._z = latent_action[:, self._embed_dim :]
self._z.persistent = True
assert latent_action.shape[-1] == self._embed_dim + self._latent_dim
d_action = self._vae.decode_discrete_action(self._e)
c_action, _ = self._vae.decode(self._apply_c_rate(self._z), state=state_var, action=(d_action, None))
return d_action, c_action
def _apply_c_rate(self, z):
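        # Rescale z from roughly [-1, 1] into the selected latent range [z_down, z_up]:
        # `median` is the half-width of that range and `offset` is its center.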
median = 0.5 * (self._c_rate[0] - self._c_rate[1])
offset = self._c_rate[0] - median
median = NF.reshape(median, shape=(1, -1))
offset = NF.reshape(offset, shape=(1, -1))
z = z * median + offset
return z
class HyARPolicyExplorerConfig(RawPolicyExplorerConfig):
pass
class HyARPolicyExplorer(RawPolicyExplorer):
def _warmup_action(self, env, *, begin_of_episode=False):
return self.action(self._steps, self._state, begin_of_episode=begin_of_episode)
class HyARPretrainExplorerConfig(EnvironmentExplorerConfig):
pass
class HyARPretrainExplorer(EnvironmentExplorer):
def __init__(self, env_info: EnvironmentInfo, config: HyARPretrainExplorerConfig = HyARPretrainExplorerConfig()):
super().__init__(env_info, config)
def action(self, step: int, state, *, begin_of_episode: bool = False):
(d_action, c_action), action_info = self._sample_action(self._env_info)
return (d_action, c_action), action_info
def _warmup_action(self, env, *, begin_of_episode=False):
(d_action, c_action), action_info = self._sample_action(self._env_info)
return (d_action, c_action), action_info
def _sample_action(self, env_info):
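        # Sample a uniformly random action from the (possibly tuple) action space, reshaping
        # each discrete component into an array of shape (1,).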
action_info = {}
if env_info.is_tuple_action_env():
action = []
for a, action_space in zip(env_info.action_space.sample(), env_info.action_space):
if isinstance(action_space, gym.spaces.Discrete):
a = np.asarray(a).reshape((1,))
action.append(a)
action = tuple(action)
else:
if env_info.is_discrete_action_env():
action = env_info.action_space.sample()
action = np.asarray(action).reshape((1,))
else:
action = env_info.action_space.sample()
return action, action_info