# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union, cast

import gym
import numpy as np

from nnabla_rl.configuration import Configuration
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.typing import Action, Experience, State


@dataclass
class EnvironmentExplorerConfig(Configuration):
    """Configuration for EnvironmentExplorer."""

    warmup_random_steps: int = 0  # number of initial steps in which uniform-random actions are taken
    reward_scalar: float = 1.0  # scalar multiplied to every reward received from the environment
    timelimit_as_terminal: bool = True  # whether a TimeLimit truncation counts as a true episode end
    initial_step_num: int = 0  # initial value of the explorer's internal step counter
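
# A minimal usage sketch (illustrative; the values below are arbitrary):
#
#     config = EnvironmentExplorerConfig(
#         warmup_random_steps=1000,     # act randomly for the first 1000 steps
#         reward_scalar=0.1,            # scale every reward down by 10x
#         timelimit_as_terminal=False,  # bootstrap through TimeLimit truncations
#     )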


class EnvironmentExplorer(metaclass=ABCMeta):
"""Base class for environment exploration methods."""
# type declarations to type check with mypy
# NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
# See https://mypy.readthedocs.io/en/stable/class_basics.html for details
_env_info: EnvironmentInfo
_config: EnvironmentExplorerConfig
_state: Union[State, None]
_action: Union[Action, None]
_next_state: Union[State, None]
    _begin_of_episode: bool
    _steps: int

def __init__(self, env_info: EnvironmentInfo, config: EnvironmentExplorerConfig = EnvironmentExplorerConfig()):
self._env_info = env_info
self._config = config
self._state = None
self._action = None
self._next_state = None
self._begin_of_episode = True
self._steps = self._config.initial_step_num

    @abstractmethod
    def action(self, steps: int, state: np.ndarray, *, begin_of_episode: bool = False) -> Tuple[np.ndarray, Dict]:
        """Compute the action for given state at given timestep.

        Args:
            steps (int): timesteps since the beginning of exploration
            state (np.ndarray): current state to compute the action
            begin_of_episode (bool): informs the beginning of episode. Used for rnn state reset.

        Returns:
            Tuple[np.ndarray, Dict]: action for the current state at the given timestep and extra info
        """
raise NotImplementedError
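
    # An implementation sketch (hypothetical names, for illustration only): an
    # epsilon-greedy explorer's action() could look like
    #
    #     def action(self, steps, state, *, begin_of_episode=False):
    #         if np.random.rand() < self._epsilon:
    #             return self._random_action(), {"greedy": False}
    #         return self._greedy_action(state), {"greedy": True}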

    def step(self, env: gym.Env, n: int = 1, break_if_done: bool = False) -> List[Experience]:
        """Step n timesteps in given env.

        Args:
            env (gym.Env): Environment
            n (int): Number of timesteps to act in the environment
            break_if_done (bool): If True, stop stepping as soon as an episode ends,
                even if fewer than n timesteps have elapsed

        Returns:
            List[Experience]: List of experiences.
                Each experience consists of (state, action, reward, terminal flag, next state and extra info).
        """
assert 0 < n
experiences = []
if self._state is None:
self._state = cast(State, env.reset())
for _ in range(n):
experience, done = self._step_once(env, begin_of_episode=self._begin_of_episode)
experiences.append(experience)
self._begin_of_episode = done
if done and break_if_done:
break
return experiences
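
    # Usage sketch (illustrative): collect a small batch of transitions.
    #
    #     experiences = explorer.step(env, n=10)
    #     states, actions, rewards, non_terminals, next_states, infos = zip(*experiences)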

    def rollout(self, env: gym.Env) -> List[Experience]:
        """Rollout, i.e. explore until the end of an episode, in the current env.

        Args:
            env (gym.Env): Environment

        Returns:
            List[Experience]: List of experiences.
                Each experience consists of (state, action, reward, terminal flag, next state and extra info).
        """
self._state = cast(State, env.reset())
done = False
experiences = []
while not done:
experience, done = self._step_once(env, begin_of_episode=self._begin_of_episode)
experiences.append(experience)
self._begin_of_episode = done
return experiences
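
    # Usage sketch (illustrative): collect one episode and sum its (already
    # scaled) rewards.
    #
    #     episode = explorer.rollout(env)
    #     episode_return = sum(r for (_, _, r, _, _, _) in episode)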

    def _step_once(self, env, *, begin_of_episode=False) -> Tuple[Experience, bool]:
        # Act once in the environment and package the transition as an Experience.
self._steps += 1
if self._steps < self._config.warmup_random_steps:
self._action, action_info = self._warmup_action(env, begin_of_episode=begin_of_episode)
else:
self._action, action_info = self.action(
self._steps, cast(np.ndarray, self._state), begin_of_episode=begin_of_episode
)
self._next_state, r, done, step_info = env.step(self._action)
        timelimit = step_info.get("TimeLimit.truncated", False)
        # non_terminal is 0.0 only when the episode truly terminated; a pure
        # TimeLimit truncation keeps it at 1.0 unless timelimit_as_terminal is set.
        if _is_end_of_episode(done, timelimit, self._config.timelimit_as_terminal):
            non_terminal = 0.0
        else:
            non_terminal = 1.0
extra_info: Dict[str, Any] = {}
extra_info.update(action_info)
extra_info.update(step_info)
experience = (
cast(np.ndarray, self._state),
cast(np.ndarray, self._action),
r * self._config.reward_scalar,
non_terminal,
cast(np.ndarray, self._next_state),
extra_info,
)
if done:
self._state = env.reset()
else:
self._state = self._next_state
return experience, done

    def _warmup_action(self, env, *, begin_of_episode=False):
        # During warmup, actions are sampled uniformly from the action space.
return _sample_action(env, self._env_info)
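

# A minimal concrete explorer (a sketch for illustration; nnabla_rl ships real
# explorers, e.g. epsilon-greedy, under nnabla_rl.environment_explorers). It
# ignores the state and samples uniformly from the action space.
class _UniformRandomExplorerExample(EnvironmentExplorer):
    def __init__(
        self, env: gym.Env, env_info: EnvironmentInfo, config: EnvironmentExplorerConfig = EnvironmentExplorerConfig()
    ):
        super().__init__(env_info, config)
        self._env = env

    def action(self, steps: int, state: np.ndarray, *, begin_of_episode: bool = False) -> Tuple[np.ndarray, Dict]:
        # Reuses the module-level helper so that discrete actions get the same
        # (1,)-shaped numpy representation that _step_once expects.
        return _sample_action(self._env, self._env_info)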


def _is_end_of_episode(done, timelimit, timelimit_as_terminal):
    # An episode counts as a true termination unless it ended solely because of
    # a TimeLimit truncation and truncations are not configured as terminal.
    if not done:
        return False
    else:
        return (not timelimit) or (timelimit and timelimit_as_terminal)


def _sample_action(env, env_info):
action_info: Dict[str, Any] = {}
    if env_info.is_tuple_action_env():
        # For tuple action spaces, sample each subspace and give the discrete
        # components a (1,)-shaped numpy representation.
        action = []
        for a, action_space in zip(env.action_space.sample(), env_info.action_space):
            if isinstance(action_space, gym.spaces.Discrete):
                a = np.asarray(a).reshape((1,))
            action.append(a)
        action = tuple(action)
else:
if env_info.is_discrete_action_env():
action = env.action_space.sample()
action = np.asarray(action).reshape((1,))
else:
action = env.action_space.sample()
return action, action_info
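

# A self-contained usage sketch (assumptions: gym's CartPole-v1 is available and
# the installed gym version uses the pre-0.26 reset/step API this module is
# written against).
if __name__ == "__main__":
    cartpole = gym.make("CartPole-v1")
    cartpole_info = EnvironmentInfo.from_env(cartpole)
    explorer = _UniformRandomExplorerExample(cartpole, cartpole_info)
    # Collect 10 random transitions, then one full random episode.
    batch = explorer.step(cartpole, n=10)
    episode = explorer.rollout(cartpole)
    print(f"collected {len(batch)} transitions and an episode of {len(episode)} steps")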