Source code for nnabla_rl.algorithms.decision_transformer
# Copyright 2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union, cast
import gym
import numpy as np
import nnabla as nn
import nnabla.solvers as NS
import nnabla_rl.model_trainers as MT
from nnabla_rl.algorithm import Algorithm, AlgorithmConfig, eval_api
from nnabla_rl.algorithms.common_utils import _DecisionTransformerActionSelector
from nnabla_rl.builders import LearningRateSchedulerBuilder, ModelBuilder, SolverBuilder
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.model_trainers.model_trainer import ModelTrainer, TrainingBatch
from nnabla_rl.models import AtariDecisionTransformer, DeterministicDecisionTransformer, StochasticDecisionTransformer
from nnabla_rl.replay_buffers import TrajectoryReplayBuffer
from nnabla_rl.replay_buffers.buffer_iterator import BufferIterator
from nnabla_rl.utils import context
from nnabla_rl.utils.data import marshal_experiences
from nnabla_rl.utils.solver_wrappers import AutoClipGradByNorm
DecisionTransformerModel = Union[StochasticDecisionTransformer, DeterministicDecisionTransformer]
@dataclass
class DecisionTransformerConfig(AlgorithmConfig):
"""List of configurations for DecisionTransformer algorithm.
Args:
learning_rate (float): learning rate which is set to all solvers. \
You can customize/override the learning rate for each solver by implementing the \
(:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`) by yourself. \
Defaults to 0.0006.
batch_size (int): training batch size. Defaults to 128.
context_length (int): Context length of transformer model. Defaults to 30.
        max_timesteps (Optional[int]): Optional. Maximum timestep of the training environment. \
            If the value is not provided, the algorithm infers the maximum episode length from EnvironmentInfo. \
            Defaults to None.
grad_clip_norm (float): Gradient clipping threshold for default solver. Defaults to 1.0.
weight_decay (float): Weight decay parameter for default solver. Defaults to 0.1.
target_return (int): Initial target return used to compute the evaluation action. Defaults to 90.
        reward_scale (float): Reward scaling factor. The reward received during evaluation \
            will be multiplied by this value. Defaults to 1.0.
"""
learning_rate: float = 6.0e-4
batch_size: int = 128
context_length: int = 30
max_timesteps: Optional[int] = None
grad_clip_norm: float = 1.0
weight_decay: float = 0.1
target_return: int = 90
reward_scale: float = 1.0
def __post_init__(self):
"""__post_init__
Check set values are in valid range.
"""
self._assert_positive(self.learning_rate, "learning_rate")
self._assert_positive(self.batch_size, "batch_size")
self._assert_positive(self.context_length, "context_length")
self._assert_positive(self.grad_clip_norm, "grad_clip_norm")
if self.max_timesteps is not None:
self._assert_positive(self.max_timesteps, "max_timesteps")
self._assert_positive(self.weight_decay, "weight_decay")
self._assert_positive(self.target_return, "target_return")
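

# Usage sketch: the config above can be constructed with non-default values and passed
# to the DecisionTransformer algorithm defined below. The field names are the ones
# defined in DecisionTransformerConfig; the values are only an illustration.
#
#     config = DecisionTransformerConfig(
#         learning_rate=6.0e-4,
#         batch_size=64,
#         context_length=30,
#         max_timesteps=1000,
#         target_return=90,
#     )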
class DefaultTransformerBuilder(ModelBuilder[DecisionTransformerModel]):
def build_model( # type: ignore[override]
self,
scope_name: str,
env_info: EnvironmentInfo,
algorithm_config: DecisionTransformerConfig,
**kwargs,
) -> DecisionTransformerModel:
max_timesteps = cast(int, kwargs["max_timesteps"])
return AtariDecisionTransformer(
scope_name, env_info.action_dim, max_timestep=max_timesteps, context_length=algorithm_config.context_length
)
class DefaultSolverBuilder(SolverBuilder):
def build_solver( # type: ignore[override]
self, env_info: EnvironmentInfo, algorithm_config: DecisionTransformerConfig, **kwargs
) -> nn.solver.Solver:
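        # Adam with beta1=0.9 and beta2=0.95; AutoClipGradByNorm wraps the solver so that
        # gradients are clipped by norm using algorithm_config.grad_clip_norm as the threshold.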
solver = NS.Adam(alpha=algorithm_config.learning_rate, beta1=0.9, beta2=0.95)
return AutoClipGradByNorm(solver, algorithm_config.grad_clip_norm)
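

# Custom solver builder sketch (illustrative, following the note on learning_rate in
# DecisionTransformerConfig): subclassing SolverBuilder lets you replace the default
# solver or override its learning rate; pass the instance as transformer_solver_builder.
#
#     class ConstantLRSolverBuilder(SolverBuilder):
#         def build_solver(self, env_info, algorithm_config, **kwargs) -> nn.solver.Solver:
#             # plain Adam without gradient clipping, ignoring algorithm_config.learning_rate
#             return NS.Adam(alpha=1.0e-4)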
class DecisionTransformer(Algorithm):
"""DecisionTransformer algorithm.
This class implements the DecisionTransformer algorithm
proposed by L. Chen, et al. in the paper: "Decision Transformer: Reinforcement Learning via Sequence Modeling"
For details see: https://arxiv.org/abs/2106.01345
Args:
env_or_env_info\
(gym.Env or :py:class:`EnvironmentInfo <nnabla_rl.environments.environment_info.EnvironmentInfo>`):
the environment to train or environment info
        config (:py:class:`DecisionTransformerConfig <nnabla_rl.algorithms.decision_transformer.DecisionTransformerConfig>`):
            the parameter for DecisionTransformer training
        transformer_builder (:py:class:`ModelBuilder <nnabla_rl.builders.ModelBuilder>`): builder of transformer model
        transformer_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`): \
            builder of transformer solver
        transformer_wd_solver_builder (:py:class:`SolverBuilder <nnabla_rl.builders.SolverBuilder>`): \
            optional builder of a solver that applies weight decay. Defaults to None.
        lr_scheduler_builder \
            (:py:class:`LearningRateSchedulerBuilder <nnabla_rl.builders.LearningRateSchedulerBuilder>`): \
            optional builder of a learning rate scheduler used to update the learning rate during training. \
            Defaults to None.
"""
# type declarations to type check with mypy
# NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
# See https://mypy.readthedocs.io/en/stable/class_basics.html for details
_config: DecisionTransformerConfig
_decision_transformer: DecisionTransformerModel
    _decision_transformer_solver: nn.solver.Solver
_decision_transformer_trainer: ModelTrainer
_decision_transformer_trainer_state: Dict[str, Any]
_action_selector: _DecisionTransformerActionSelector
def __init__(
self,
env_or_env_info: Union[gym.Env, EnvironmentInfo],
config: DecisionTransformerConfig = DecisionTransformerConfig(),
transformer_builder: ModelBuilder[DecisionTransformerModel] = DefaultTransformerBuilder(),
transformer_solver_builder: SolverBuilder = DefaultSolverBuilder(),
transformer_wd_solver_builder: Optional[SolverBuilder] = None,
lr_scheduler_builder: Optional[LearningRateSchedulerBuilder] = None,
):
super(DecisionTransformer, self).__init__(env_or_env_info, config=config)
if config.max_timesteps is None:
assert not np.isposinf(self._env_info.max_episode_steps)
self._max_timesteps = self._env_info.max_episode_steps
else:
self._max_timesteps = config.max_timesteps
with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)):
self._decision_transformer = transformer_builder(
scope_name="decision_transformer",
env_info=self._env_info,
algorithm_config=self._config,
max_timesteps=self._max_timesteps,
)
self._decision_transformer_solver = transformer_solver_builder(
env_info=self._env_info, algorithm_config=self._config
)
self._decision_transformer_wd_solver = (
None
if transformer_wd_solver_builder is None
else transformer_wd_solver_builder(env_info=self._env_info, algorithm_config=self._config)
)
self._lr_scheduler = (
None
if lr_scheduler_builder is None
else lr_scheduler_builder(env_info=self._env_info, algorithm_config=self._config)
)
self._action_selector = _DecisionTransformerActionSelector(
self._env_info,
self._decision_transformer.shallowcopy(),
self._max_timesteps,
self._config.context_length,
self._config.target_return,
self._config.reward_scale,
)
@eval_api
def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}):
if "reward" not in extra_info:
raise ValueError(
f"{self.__name__} requires previous reward info in addition to state to compute action."
'use extra_info["reward"]=reward'
)
return self._action_selector(state, begin_of_episode=begin_of_episode, extra_info=extra_info)
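
    # Usage sketch for compute_eval_action (the surrounding variable names are
    # illustrative): the reward obtained at the previous environment step is passed
    # through extra_info, and begin_of_episode marks the first step of an episode.
    #
    #     action = algorithm.compute_eval_action(
    #         state, begin_of_episode=(t == 0), extra_info={"reward": prev_reward})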
def _before_training_start(self, env_or_buffer):
# set context globally to ensure that the training runs on configured gpu
context.set_nnabla_context(self._config.gpu_id)
self._decision_transformer_trainer = self._setup_decision_transformer_training(env_or_buffer)
def _setup_decision_transformer_training(self, env_or_buffer):
if isinstance(self._decision_transformer, DeterministicDecisionTransformer):
trainer_config = MT.dt_trainers.DeterministicDecisionTransformerTrainerConfig(
context_length=self._config.context_length
)
solvers = {self._decision_transformer.scope_name: self._decision_transformer_solver}
wd_solver = self._decision_transformer_wd_solver
wd_solvers = None if wd_solver is None else {self._decision_transformer.scope_name: wd_solver}
decision_transformer_trainer = MT.dt_trainers.DeterministicDecisionTransformerTrainer(
models=self._decision_transformer,
solvers=solvers,
wd_solvers=wd_solvers,
env_info=self._env_info,
config=trainer_config,
)
return decision_transformer_trainer
if isinstance(self._decision_transformer, StochasticDecisionTransformer):
trainer_config = MT.dt_trainers.StochasticDecisionTransformerTrainerConfig(
context_length=self._config.context_length
)
solvers = {self._decision_transformer.scope_name: self._decision_transformer_solver}
wd_solver = self._decision_transformer_wd_solver
wd_solvers = None if wd_solver is None else {self._decision_transformer.scope_name: wd_solver}
decision_transformer_trainer = MT.dt_trainers.StochasticDecisionTransformerTrainer(
models=self._decision_transformer,
solvers=solvers,
wd_solvers=wd_solvers,
env_info=self._env_info,
config=trainer_config,
)
return decision_transformer_trainer
raise NotImplementedError(
"Unknown model type. Model should be either Deterministic/StochasticDecisionTransformer"
)
def _run_online_training_iteration(self, env):
raise NotImplementedError(f"Online training is not supported for {self.__name__}")
def _run_offline_training_iteration(self, buffer):
assert isinstance(buffer, TrajectoryReplayBuffer)
self._decision_transformer_training(buffer, run_epoch=True)
def _decision_transformer_training(self, replay_buffer, run_epoch):
if run_epoch:
buffer_iterator = _TrajectoryBufferIterator(
replay_buffer, self._config.batch_size, self._config.context_length
)
# Run 1 epoch
for trajectories, info in buffer_iterator:
self._decision_transformer_iteration(trajectories, info)
else:
trajectories, info = replay_buffer.sample_trajectories_portion(
self._config.batch_size, self._config.context_length
)
self._decision_transformer_iteration(trajectories, info)
def _decision_transformer_iteration(self, trajectories, info):
batch = None
experiences = []
for trajectory in trajectories:
# trajectory: (s1, a1, r1, t1, s_next1, info1), ..., (sN, aN, rN, tN, s_nextN, infoN)
# -> ([s1...sN] , [a1...aN], [r1...rN] , [t1...tN], [s_next1...s_nextN], [info1...infoN])
marshaled = marshal_experiences(trajectory)
experiences.append(marshaled)
(s, a, _, _, _, extra, *_) = marshal_experiences(experiences)
extra["target"] = a
extra["rtg"] = extra["rtg"] # NOTE: insure that 'rtg' exists
extra["timesteps"] = extra["timesteps"][:, 0:1, :]
batch = TrainingBatch(
batch_size=self._config.batch_size,
s_current=s,
a_current=a,
weight=info["weights"],
next_step_batch=None,
extra=extra,
)
self._decision_transformer_trainer_state = self._decision_transformer_trainer.train(batch)
if self._lr_scheduler is not None:
# iteration_num is dummy
new_learning_rate = self._lr_scheduler.get_learning_rate(self._iteration_num)
self._decision_transformer_trainer.set_learning_rate(new_learning_rate)
def _models(self):
models = {}
models[self._decision_transformer.scope_name] = self._decision_transformer
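        # NOTE: the model is registered a second time under the '_wd' suffix so that the
        # entry lines up with the weight-decay solver registered in _solvers() below.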
if self._decision_transformer_wd_solver is not None:
models[f"{self._decision_transformer.scope_name}_wd"] = self._decision_transformer
return models
def _solvers(self):
solvers = {}
solvers[self._decision_transformer.scope_name] = self._decision_transformer_solver
if self._decision_transformer_wd_solver is not None:
solvers[f"{self._decision_transformer.scope_name}_wd"] = self._decision_transformer_wd_solver
return solvers
    @classmethod
def is_supported_env(cls, env_or_env_info):
env_info = (
EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info
)
return (
env_info.is_continuous_action_env() or env_info.is_discrete_action_env()
) and not env_info.is_tuple_action_env()
@property
def latest_iteration_state(self):
latest_iteration_state = super(DecisionTransformer, self).latest_iteration_state
if hasattr(self, "_decision_transformer_trainer_state"):
latest_iteration_state["scalar"].update({"loss": float(self._decision_transformer_trainer_state["loss"])})
return latest_iteration_state
@property
def trainers(self):
return {"decision_transformer": self._decision_transformer_trainer}
class _TrajectoryBufferIterator(BufferIterator):
_replay_buffer: TrajectoryReplayBuffer
def __init__(self, buffer, batch_size, portion_length, shuffle=True, repeat=True):
super().__init__(buffer, batch_size, shuffle, repeat)
self._portion_length = portion_length
def _sample(self, indices):
return self._replay_buffer.sample_indices_portion(indices, portion_length=self._portion_length)
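

# End-to-end sketch (illustrative): online training raises NotImplementedError, so the
# model is trained offline from a TrajectoryReplayBuffer of recorded trajectories.
# Assumes the Algorithm base class exposes train_offline; the environment name and the
# buffer construction below are placeholders, only names defined in this module are used as-is.
#
#     env = gym.make("...")  # environment matching the recorded trajectories
#     algorithm = DecisionTransformer(env, config=DecisionTransformerConfig())
#     buffer = TrajectoryReplayBuffer()  # assumed to be filled with trajectories beforehand
#     algorithm.train_offline(buffer, total_iterations=...)
#     action = algorithm.compute_eval_action(
#         state, begin_of_episode=True, extra_info={"reward": 0.0})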