# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional
import gym
from nnabla_rl.environments.gym_utils import (
extract_max_episode_steps,
get_space_dim,
get_space_high,
get_space_low,
get_space_shape,
is_same_space_type,
)
from nnabla_rl.external.goal_env import GoalEnv


@dataclass
class EnvironmentInfo(object):
"""Environment Information class.
This class contains the basic information of the target training
environment.
"""
observation_space: gym.spaces.Space
action_space: gym.spaces.Space
max_episode_steps: int

    def __init__(
self,
observation_space,
action_space,
max_episode_steps,
unwrapped_env,
        reward_function: Optional[Callable[[Any, Any, Dict], float]] = None,
):
self.observation_space = observation_space
self.action_space = action_space
self.max_episode_steps = max_episode_steps
self.unwrapped_env = unwrapped_env
self.reward_function = reward_function
if not (self.is_discrete_state_env() or self.is_continuous_state_env() or self.is_tuple_state_env()):
raise ValueError(f"Unsupported state space: {observation_space}")
if not (self.is_discrete_action_env() or self.is_continuous_action_env() or self.is_tuple_action_env()):
raise ValueError(f"Unsupported action space: {action_space}")

    @staticmethod
def from_env(env):
"""Create env_info from environment.
Args:
env (gym.Env): the environment
Returns:
EnvironmentInfo\
(:py:class:`EnvironmentInfo <nnabla_rl.environments.environment_info.EnvironmentInfo>`)
Example:
>>> import gym
>>> from nnabla_rl.environments.environment_info import EnvironmentInfo
>>> env = gym.make("CartPole-v0")
>>> env_info = EnvironmentInfo.from_env(env)
>>> env_info.state_shape
(4,)
"""
reward_function = env.compute_reward if hasattr(env, "compute_reward") else None
unwrapped_env = env.unwrapped
return EnvironmentInfo(
observation_space=env.observation_space,
action_space=env.action_space,
max_episode_steps=extract_max_episode_steps(env),
unwrapped_env=unwrapped_env,
reward_function=reward_function,
)

    def is_discrete_action_env(self):
"""Check whether the action to execute in the environment is discrete
or not.
Returns:
bool: True if the action to execute in the environment is discrete. Otherwise False.
Note that if the action is gym.spaces.Tuple and all of the element are discrete, it returns True.
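
        Example:
            An illustrative check (assumes gym's classic control environments are
            available; ``CartPole-v0`` has a ``Discrete(2)`` action space):

            >>> import gym
            >>> env_info = EnvironmentInfo.from_env(gym.make("CartPole-v0"))
            >>> env_info.is_discrete_action_env()
            True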
"""
return is_same_space_type(self.action_space, gym.spaces.Discrete)

    def is_continuous_action_env(self):
"""Check whether the action to execute in the environment is continuous
or not.
Returns:
bool: True if the action to execute in the environment is continuous. Otherwise False.
Note that if the action is gym.spaces.Tuple and all of the element are continuous, it returns True.
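
        Example:
            An illustrative check (assumes gym's classic control environments are
            available; ``Pendulum-v0`` has a one-dimensional ``Box`` action space):

            >>> import gym
            >>> env_info = EnvironmentInfo.from_env(gym.make("Pendulum-v0"))
            >>> env_info.is_continuous_action_env()
            True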
"""
return is_same_space_type(self.action_space, gym.spaces.Box)

    def is_mixed_action_env(self):
"""Check whether the action of the environment consists of either
continuous or discrete action.
Returns:
bool: True if the action of the environment is either continuous or discrete. Otherwise False.
Note that if the action is not a gym.spaces.Tuple, then returns False.
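
        Example:
            A minimal sketch with hand-built spaces (the spaces below are
            illustrative, not taken from a real environment; ``unwrapped_env``
            is not needed for this check, so ``None`` is passed):

            >>> import gym
            >>> env_info = EnvironmentInfo(
            ...     observation_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(4,)),
            ...     action_space=gym.spaces.Tuple(
            ...         (gym.spaces.Discrete(2), gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)))),
            ...     max_episode_steps=200,
            ...     unwrapped_env=None)
            >>> env_info.is_mixed_action_env()
            True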
"""
if not self.is_tuple_action_env():
return False
        return all(isinstance(a, (gym.spaces.Discrete, gym.spaces.Box)) for a in self.action_space)

    def is_tuple_action_env(self):
"""Check whether the action of the environment is tuple or not.
Returns:
bool: True if the action of the environment is tuple. Otherwise False.
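
        Example:
            A minimal sketch with hand-built spaces (illustrative only;
            ``unwrapped_env`` is not needed for this check):

            >>> import gym
            >>> env_info = EnvironmentInfo(
            ...     observation_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(4,)),
            ...     action_space=gym.spaces.Tuple(
            ...         (gym.spaces.Discrete(2), gym.spaces.Discrete(3))),
            ...     max_episode_steps=200,
            ...     unwrapped_env=None)
            >>> env_info.is_tuple_action_env()
            True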
"""
return isinstance(self.action_space, gym.spaces.Tuple)

    def is_discrete_state_env(self):
"""Check whether the state of the environment is discrete or not.
Returns:
bool: True if the state of the environment is discrete. Otherwise False.
Note that if the state is gym.spaces.Tuple and all of the element are discrete, it returns True.
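
        Example:
            An illustrative check (assumes gym's toy-text environments are
            available; ``FrozenLake-v0`` has a ``Discrete(16)`` observation space):

            >>> import gym
            >>> env_info = EnvironmentInfo.from_env(gym.make("FrozenLake-v0"))
            >>> env_info.is_discrete_state_env()
            True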
"""
return is_same_space_type(self.observation_space, gym.spaces.Discrete)

    def is_continuous_state_env(self):
"""Check whether the state of the environment is continuous or not.
Returns:
bool: True if the state of the environment is continuous. Otherwise False.
Note that if the state is gym.spaces.Tuple and all of the element are continuous, it returns True.
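
        Example:
            An illustrative check (``CartPole-v0`` has a four-dimensional ``Box``
            observation space):

            >>> import gym
            >>> env_info = EnvironmentInfo.from_env(gym.make("CartPole-v0"))
            >>> env_info.is_continuous_state_env()
            True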
"""
return is_same_space_type(self.observation_space, gym.spaces.Box)

    def is_mixed_state_env(self):
"""Check whether the state of the environment consists of either
continuous or discrete state.
Returns:
bool: True if the state of the environment is either continuous or discrete. Otherwise False.
Note that if the state is not a gym.spaces.Tuple, then returns False.
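
        Example:
            A minimal sketch with hand-built spaces (illustrative only;
            ``unwrapped_env`` is not needed for this check):

            >>> import gym
            >>> env_info = EnvironmentInfo(
            ...     observation_space=gym.spaces.Tuple(
            ...         (gym.spaces.Discrete(5), gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)))),
            ...     action_space=gym.spaces.Discrete(2),
            ...     max_episode_steps=200,
            ...     unwrapped_env=None)
            >>> env_info.is_mixed_state_env()
            True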
"""
if not self.is_tuple_state_env():
return False
        return all(isinstance(s, (gym.spaces.Discrete, gym.spaces.Box)) for s in self.observation_space)

    def is_tuple_state_env(self):
"""Check whether the state of the environment is tuple or not.
Returns:
bool: True if the state of the environment is tuple. Otherwise False.
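
        Example:
            A minimal sketch with hand-built spaces (illustrative only;
            ``unwrapped_env`` is not needed for this check):

            >>> import gym
            >>> env_info = EnvironmentInfo(
            ...     observation_space=gym.spaces.Tuple(
            ...         (gym.spaces.Box(low=-1.0, high=1.0, shape=(2,)),
            ...          gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)))),
            ...     action_space=gym.spaces.Discrete(2),
            ...     max_episode_steps=200,
            ...     unwrapped_env=None)
            >>> env_info.is_tuple_state_env()
            True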
"""
return isinstance(self.observation_space, gym.spaces.Tuple)

    def is_goal_conditioned_env(self):
"""Check whether the environment is gym.GoalEnv or not.
Returns:
bool: True if the environment is GoalEnv. Otherwise False.
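
        Example:
            A minimal sketch with a dummy goal-conditioned environment (the
            class below is illustrative; real goal-conditioned environments,
            such as the robotics tasks, define these members themselves):

            >>> import gym
            >>> from nnabla_rl.external.goal_env import GoalEnv
            >>> class DummyGoalEnv(GoalEnv):
            ...     observation_space = gym.spaces.Tuple(
            ...         (gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)),) * 3)
            ...     action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))
            ...     def compute_reward(self, achieved_goal, desired_goal, info):
            ...         return 0.0
            >>> env_info = EnvironmentInfo.from_env(DummyGoalEnv())
            >>> env_info.is_goal_conditioned_env()
            True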
"""
        return isinstance(self.unwrapped_env, GoalEnv)

    @property
    def state_shape(self):
        """The shape of the observation space."""
if self.is_tuple_state_env():
return tuple(map(get_space_shape, self.observation_space))
else:
return get_space_shape(self.observation_space)

    @property
    def state_dim(self):
        """The dimension of the state, assuming that the state is flattened.
if self.is_tuple_state_env():
return tuple(map(get_space_dim, self.observation_space))
else:
return get_space_dim(self.observation_space)

    @property
    def state_high(self):
        """The upper limit of the observation space."""
if self.is_tuple_state_env():
return tuple(map(get_space_high, self.observation_space))
else:
return get_space_high(self.observation_space)

    @property
    def state_low(self):
        """The lower limit of the observation space."""
if self.is_tuple_state_env():
return tuple(map(get_space_low, self.observation_space))
else:
return get_space_low(self.observation_space)

    @property
    def action_high(self):
        """The upper limit of the action space."""
if self.is_tuple_action_env():
return tuple(map(get_space_high, self.action_space))
else:
return get_space_high(self.action_space)

    @property
    def action_low(self):
        """The lower limit of the action space."""
if self.is_tuple_action_env():
return tuple(map(get_space_low, self.action_space))
else:
return get_space_low(self.action_space)

    @property
    def action_shape(self):
        """The shape of the action space."""
if self.is_tuple_action_env():
return tuple(map(get_space_shape, self.action_space))
else:
return get_space_shape(self.action_space)

    @property
    def action_dim(self):
        """The dimension of the action, assuming that the action is flattened.
if self.is_tuple_action_env():
return tuple(map(get_space_dim, self.action_space))
else:
return get_space_dim(self.action_space)