# Source code for nnabla_rl.environment_explorer

# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union, cast

import gym
import numpy as np

from nnabla_rl.configuration import Configuration
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.typing import Action, Experience, State


@dataclass
class EnvironmentExplorerConfig(Configuration):
    """Settings shared by every environment explorer.

    Attributes:
        warmup_random_steps (int): while the explorer's step counter (which is incremented right
            before each action is chosen) is below this value, actions are sampled at random from
            the environment's action space instead of being computed by the explorer. Defaults to 0.
        reward_scalar (float): multiplier applied to every reward before it is stored in an
            experience. Defaults to 1.0.
        timelimit_as_terminal (bool): if True, an episode that ended because of a time limit
            (reported through the step info key "TimeLimit.truncated") is stored as terminal;
            otherwise it is stored as non-terminal. Defaults to True.
        initial_step_num (int): starting value of the explorer's step counter. Defaults to 0.
    """

    # NOTE: field order is part of the positional constructor signature; keep it stable.
    warmup_random_steps: int = 0
    reward_scalar: float = 1.0
    timelimit_as_terminal: bool = True
    initial_step_num: int = 0
class EnvironmentExplorer(metaclass=ABCMeta):
    """Base class for environment exploration methods.

    Subclasses implement :meth:`action`. This base class drives the environment (reset/step),
    keeps track of the current state and of a step counter, and packs every transition into an
    experience tuple ``(state, action, scaled reward, non_terminal, next_state, extra_info)``.
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _env_info: EnvironmentInfo
    _config: EnvironmentExplorerConfig
    _state: Union[State, None]
    _action: Union[Action, None]
    _next_state: Union[State, None]
    _begin_of_episode: bool
    _steps: int

    def __init__(self, env_info: EnvironmentInfo, config: Union[EnvironmentExplorerConfig, None] = None):
        """Create the explorer.

        Args:
            env_info(EnvironmentInfo): information about the environment to explore.
            config(EnvironmentExplorerConfig, optional): explorer configuration. When omitted or None,
                a fresh default EnvironmentExplorerConfig is created for this instance.
        """
        self._env_info = env_info
        # Build the default per instance: a config instance used as a default argument would be
        # created once at class-definition time and shared (mutably) by every explorer.
        self._config = EnvironmentExplorerConfig() if config is None else config
        self._state = None
        self._action = None
        self._next_state = None
        self._begin_of_episode = True
        self._steps = self._config.initial_step_num

    @abstractmethod
    def action(self, steps: int, state: np.ndarray, *, begin_of_episode: bool = False) -> Tuple[np.ndarray, Dict]:
        """Compute the action for given state at given timestep.

        Args:
            steps(int): the explorer's step counter (starts at config.initial_step_num and is
                incremented before each action is chosen)
            state(np.ndarray): current state to compute the action
            begin_of_episode(bool): Informs the beginning of episode. Used for rnn state reset.

        Returns:
            Tuple[np.ndarray, Dict]: action for current state at given timestep, and a dictionary of
                extra information about the action that is merged into the experience's extra info.
        """
        raise NotImplementedError

    def step(self, env: gym.Env, n: int = 1, break_if_done: bool = False) -> List[Experience]:
        """Step n timesteps in given env.

        The environment is reset only on the very first call; afterwards stepping continues from
        where the previous call stopped (episodes that end are reset automatically).

        Args:
            env(gym.Env): Environment
            n(int): Number of timesteps to act in the environment. Must be positive.
            break_if_done(bool): If True, stop as soon as an episode ends, even if fewer than
                n timesteps were taken.

        Returns:
            List[Experience]: List of experience.
                Experience consists of (state, action, reward, terminal flag, next state and extra info).
        """
        assert 0 < n
        experiences = []
        if self._state is None:
            # No episode in progress yet: start one.
            self._state = cast(State, env.reset())
            self._begin_of_episode = True
        for _ in range(n):
            experience, done = self._step_once(env, begin_of_episode=self._begin_of_episode)
            experiences.append(experience)
            # The action after an episode end is the first action of the next episode.
            self._begin_of_episode = done
            if done and break_if_done:
                break
        return experiences

    def rollout(self, env: gym.Env) -> List[Experience]:
        """Rollout the episode in current env.

        The environment is always reset first, so a full episode is collected regardless of any
        episode left in progress by earlier step() calls.

        Args:
            env(gym.Env): Environment

        Returns:
            List[Experience]: List of experience.
                Experience consists of (state, action, reward, terminal flag, next state and extra info).
        """
        self._state = cast(State, env.reset())
        # env.reset() starts a new episode, so its first action must be flagged as such even if
        # a previous step() call stopped mid-episode (otherwise e.g. rnn states are not reset).
        self._begin_of_episode = True
        done = False
        experiences = []
        while not done:
            experience, done = self._step_once(env, begin_of_episode=self._begin_of_episode)
            experiences.append(experience)
            self._begin_of_episode = done
        return experiences

    def _step_once(self, env, *, begin_of_episode=False) -> Tuple[Experience, bool]:
        """Act once in env and return the resulting experience and the done flag.

        Advances the step counter. When the episode ends, env is reset so that self._state always
        holds the state from which the next action is computed.
        """
        self._steps += 1
        # NOTE(review): the counter is incremented before this comparison, so with initial_step_num == 0
        # only warmup_random_steps - 1 random actions are taken. Confirm this is intended.
        if self._steps < self._config.warmup_random_steps:
            self._action, action_info = self._warmup_action(env, begin_of_episode=begin_of_episode)
        else:
            self._action, action_info = self.action(
                self._steps, cast(np.ndarray, self._state), begin_of_episode=begin_of_episode
            )

        # Unpacks the 4-tuple (obs, reward, done, info) step API; presumably the pre-0.26 gym API.
        # TODO confirm against the pinned gym version.
        self._next_state, r, done, step_info = env.step(self._action)
        timelimit = step_info.get("TimeLimit.truncated", False)
        # Time-limit truncations are recorded as terminal only if config.timelimit_as_terminal is set.
        if _is_end_of_episode(done, timelimit, self._config.timelimit_as_terminal):
            non_terminal = 0.0
        else:
            non_terminal = 1.0

        extra_info: Dict[str, Any] = {}
        extra_info.update(action_info)
        # Step info is merged last, so it wins on key collisions with the action info.
        extra_info.update(step_info)
        experience = (
            cast(np.ndarray, self._state),
            cast(np.ndarray, self._action),
            r * self._config.reward_scalar,
            non_terminal,
            cast(np.ndarray, self._next_state),
            extra_info,
        )
        if done:
            self._state = env.reset()
        else:
            self._state = self._next_state
        return experience, done

    def _warmup_action(self, env, *, begin_of_episode=False):
        """Sample a random action from env's action space.

        begin_of_episode is accepted only for signature parity with action() and is ignored.
        """
        return _sample_action(env, self._env_info)
def _is_end_of_episode(done, timelimit, timelimit_as_terminal):
    """Tell whether a finished step should be recorded as the terminal step of an episode.

    A step that is not done is never terminal. A done step without a time limit is always
    terminal. A done step caused by a time limit is terminal only when timelimit_as_terminal
    says so.
    """
    if not done:
        return False
    if not timelimit:
        return True
    # timelimit is truthy here, so this equals `timelimit and timelimit_as_terminal`.
    return timelimit_as_terminal


def _sample_action(env, env_info):
    """Draw one random action from env's action space.

    Discrete actions (or discrete components of a tuple action) are wrapped into arrays of
    shape (1,). Other actions are returned exactly as sampled.

    Returns:
        tuple: (action, action_info) where action_info is always an empty dict.
    """
    action_info: Dict[str, Any] = {}

    if not env_info.is_tuple_action_env():
        discrete = env_info.is_discrete_action_env()
        sampled = env.action_space.sample()
        if discrete:
            sampled = np.asarray(sampled).reshape((1,))
        return sampled, action_info

    # Tuple action: reshape only the components whose sub-space is Discrete.
    sampled_tuple = tuple(
        np.asarray(component).reshape((1,)) if isinstance(sub_space, gym.spaces.Discrete) else component
        for component, sub_space in zip(env.action_space.sample(), env_info.action_space)
    )
    return sampled_tuple, action_info