# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque
from typing import Optional, Sequence, Tuple, cast

import numpy as np

from nnabla_rl.replay_buffer import ReplayBuffer
from nnabla_rl.replay_buffers.prioritized_replay_buffer import (
ProportionalPrioritizedReplayBuffer,
RankBasedPrioritizedReplayBuffer,
)
from nnabla_rl.replay_buffers.trajectory_replay_buffer import TrajectoryReplayBuffer
from nnabla_rl.typing import Trajectory
from nnabla_rl.utils.data import RingBuffer


class MemoryEfficientAtariBuffer(ReplayBuffer):
    """Buffer designed to compactly save experiences of Atari environments used
    in DQN.

    DQN (and other similar algorithms) requires a large replay buffer when
    training on Atari games. Naively saving 1M experiences takes more than
    100GB of memory, which usually does not fit in a single machine's memory.
    This replay buffer reduces the size of each experience by casting the
    frames to uint8 and by dropping the old frames that were concatenated to
    the observation, so that 1M experiences fit in roughly 20GB of memory.
    Note that this class is designed only for DQN style training on Atari
    environments, i.e. a state consists of "stacked_frames" concatenated
    grayscale frames whose values are normalized between 0 and 1.
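
    Example:
        A minimal usage sketch (illustrative only; shapes assume the standard
        Atari preprocessing of 4 stacked 84x84 grayscale frames normalized
        to [0, 1])::

            import numpy as np

            buffer = MemoryEfficientAtariBuffer(capacity=1000000)
            s = np.random.rand(4, 84, 84).astype(np.float32)
            s_next = np.random.rand(4, 84, 84).astype(np.float32)
            # experience format: (state, action, reward, non_terminal, next_state, info)
            buffer.append((s, 1, 0.0, 1.0, s_next, {}))
            s, a, r, non_terminal, s_next, info = buffer[0]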
"""

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variables and NOT class variables, unless marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _buffer: RingBuffer
    _sub_buffer: deque

    def __init__(self, capacity: int, stacked_frames: int = 4):
super(MemoryEfficientAtariBuffer, self).__init__(capacity=capacity)
self._reset = True
self._buffer = RingBuffer(maxlen=capacity)
self._sub_buffer = deque(maxlen=stacked_frames - 1)
self._stacked_frames = stacked_frames

    def append(self, experience):
self._reset = _append_to_buffer(experience, self._buffer, self._sub_buffer, self._reset)

    def __getitem__(self, index: int):
        return _getitem_from_buffer(index, self._buffer, self._sub_buffer, self._stacked_frames)


class _LazyAtariTrajectory(object):
def __init__(self, buffer: MemoryEfficientAtariBuffer):
self._buffer = buffer

    def __len__(self):
        return len(self._buffer)

    def __getitem__(self, key):
        if isinstance(key, slice):
            # Resolve the slice against the trajectory length so that negative and
            # open-ended slices behave like ordinary sequence slicing
            return [self._buffer[i] for i in range(*key.indices(len(self)))]
        elif isinstance(key, int):
            return self._buffer[key]
        else:
            raise TypeError("Invalid key type")


class MemoryEfficientAtariTrajectoryBuffer(TrajectoryReplayBuffer):
    """Trajectory replay buffer that stores each trajectory in a MemoryEfficientAtariBuffer."""

    def __init__(self, num_trajectories=None):
        super(MemoryEfficientAtariTrajectoryBuffer, self).__init__(num_trajectories)
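
    # A minimal usage sketch (illustrative; each experience in the trajectory follows
    # the (state, action, reward, non_terminal, next_state, info) format used by
    # MemoryEfficientAtariBuffer):
    #
    #   buffer = MemoryEfficientAtariTrajectoryBuffer()
    #   buffer.append_trajectory(trajectory)
    #   recovered = buffer.get_trajectory(0)  # frames are reconstructed lazily on access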

    def append_trajectory(self, trajectory: Trajectory):
# Use memory efficient atari buffer to save the trajectory efficiently
atari_buffer = MemoryEfficientAtariBuffer(capacity=len(trajectory))
atari_buffer.append_all(trajectory)
self._buffer.append(atari_buffer)
        # Below is the same as the superclass's code
self._samples_per_trajectory.append(len(trajectory))
num_experiences = 0
cumsum_experiences = []
for i in range(self.trajectory_num):
num_experiences += self._samples_per_trajectory[i]
cumsum_experiences.append(num_experiences)
self._num_experiences = num_experiences
self._cumsum_experiences = cumsum_experiences

    def get_trajectory(self, trajectory_index: int) -> Trajectory:
return self._buffer_to_trajectory(self._get_atari_buffer(trajectory_index))

    def sample(self, num_samples: int = 1, num_steps: int = 1):
raise NotImplementedError

    def sample_indices(self, indices: Sequence[int], num_steps: int = 1):
raise NotImplementedError

    def _buffer_to_trajectory(self, buffer: MemoryEfficientAtariBuffer) -> Trajectory:
        return cast(Trajectory, _LazyAtariTrajectory(buffer))

    def _get_atari_buffer(self, trajectory_index: int) -> MemoryEfficientAtariBuffer:
        return cast(MemoryEfficientAtariBuffer, self._buffer[trajectory_index])


class ProportionalPrioritizedAtariBuffer(ProportionalPrioritizedReplayBuffer):
    """Prioritized buffer designed to compactly save experiences of Atari
    environments used in DQN.

    Proportional prioritized version of the memory efficient Atari buffer.
    Note that this class is designed only for DQN style training on Atari
    environments, i.e. a state consists of "stacked_frames" concatenated
    grayscale frames whose values are normalized between 0 and 1.
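
    Example:
        A construction sketch (values are illustrative; ``s``, ``a``, ``r``,
        ``non_terminal`` and ``s_next`` follow the same experience format as
        MemoryEfficientAtariBuffer)::

            buffer = ProportionalPrioritizedAtariBuffer(capacity=1000000, alpha=0.6, beta=0.4)
            buffer.append((s, a, r, non_terminal, s_next, {}))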
"""

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variables and NOT class variables, unless marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _sub_buffer: deque

    def __init__(
self,
capacity: int,
alpha: float = 0.6,
beta: float = 0.4,
betasteps: int = 50000000,
error_clip: Optional[Tuple[float, float]] = (-1, 1),
epsilon: float = 1e-8,
normalization_method: str = "buffer_max",
stacked_frames: int = 4,
):
super(ProportionalPrioritizedAtariBuffer, self).__init__(
capacity=capacity,
alpha=alpha,
beta=beta,
betasteps=betasteps,
error_clip=error_clip,
epsilon=epsilon,
normalization_method=normalization_method,
)
self._reset = True
self._sub_buffer = deque(maxlen=stacked_frames - 1)
self._stacked_frames = stacked_frames

    def append(self, experience):
        self._reset = _append_to_buffer(experience, self._buffer, self._sub_buffer, self._reset)

    def __getitem__(self, index: int):
        return _getitem_from_buffer(index, self._buffer, self._sub_buffer, self._stacked_frames)


class RankBasedPrioritizedAtariBuffer(RankBasedPrioritizedReplayBuffer):
    """Prioritized buffer designed to compactly save experiences of Atari
    environments used in DQN.

    Rank based prioritized version of the memory efficient Atari buffer.
    Note that this class is designed only for DQN style training on Atari
    environments, i.e. a state consists of "stacked_frames" concatenated
    grayscale frames whose values are normalized between 0 and 1.
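
    Example:
        A construction sketch (values are illustrative; experiences follow the
        same format as MemoryEfficientAtariBuffer)::

            buffer = RankBasedPrioritizedAtariBuffer(
                capacity=1000000, reset_segment_interval=1000, sort_interval=1000000
            )
            buffer.append((s, a, r, non_terminal, s_next, {}))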
"""

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variables and NOT class variables, unless marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _sub_buffer: deque

    def __init__(
self,
capacity: int,
alpha: float = 0.7,
beta: float = 0.5,
betasteps: int = 50000000,
error_clip: Optional[Tuple[float, float]] = (-1, 1),
reset_segment_interval: int = 1000,
sort_interval: int = 1000000,
stacked_frames: int = 4,
):
super(RankBasedPrioritizedAtariBuffer, self).__init__(
capacity=capacity,
alpha=alpha,
beta=beta,
betasteps=betasteps,
error_clip=error_clip,
reset_segment_interval=reset_segment_interval,
sort_interval=sort_interval,
)
self._reset = True
self._sub_buffer = deque(maxlen=stacked_frames - 1)
self._stacked_frames = stacked_frames

    def append(self, experience):
        self._reset = _append_to_buffer(experience, self._buffer, self._sub_buffer, self._reset)

    def __getitem__(self, index: int):
        return _getitem_from_buffer(index, self._buffer, self._sub_buffer, self._stacked_frames)


def _denormalize_state(state, scalar=255.0):
    return (state * scalar).astype(np.uint8)


def _normalize_state(state, scalar=255.0):
    return state.astype(np.float32) / scalar


def _is_float(state):
    return np.issubdtype(state.dtype, np.floating)


def _append_to_buffer(experience, buffer, sub_buffer, reset_flag):
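    # Store the experience compactly:
    #  - keep only the newest (84, 84) frame of the stacked state and next state
    #  - convert normalized float frames back to uint8 to save memory
    #  - remember `reset_flag`, which marks the first experience of an episode so that
    #    frame stacking does not cross episode boundaries on retrieval
    #  - items evicted from the ring buffer are kept in `sub_buffer` so that states
    #    near the head of the buffer can still be reconstructed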
s, a, r, non_terminal, s_next, info, *_ = experience
if s.shape != (84, 84):
# Use only the last image to reduce memory
s = s[-1]
s_next = s_next[-1]
if _is_float(s):
s = _denormalize_state(s)
if _is_float(s_next):
s_next = _denormalize_state(s_next)
assert s.shape == (84, 84)
assert s.shape == s_next.shape
experience = (s, a, r, non_terminal, s_next, info, reset_flag)
removed = buffer.append_with_removed_item_check(experience)
if removed is not None:
sub_buffer.append(removed)
return 0 == non_terminal


def _getitem_from_buffer(index, buffer, sub_buffer, stacked_frames):
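    # Rebuild the full stacked state for `index` from the single frames stored above:
    #  - walk backwards from `index`, copying up to `stacked_frames` frames
    #  - indices below 0 refer to frames evicted from the ring buffer; those are kept
    #    in `sub_buffer` and are reached here via negative indexing from its right end
    #  - when a frame marks an episode reset, it is copied into its own slot and every
    #    earlier slot, and the walk stops, so frames of the previous episode never leak in
    #  - frames are stored as uint8 and converted back to float32 in [0, 1] on access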
(_, a, r, non_terminal, s_next, info, _) = buffer[index]
states = np.zeros(shape=(stacked_frames, 84, 84), dtype=np.uint8)
for i in range(0, stacked_frames):
buffer_index = index - i
if 0 <= buffer_index:
(s, _, _, _, _, _, reset) = buffer[buffer_index]
else:
(s, _, _, _, _, _, reset) = sub_buffer[buffer_index]
assert s.shape == (84, 84)
tail_index = stacked_frames - i
if reset:
states[0:tail_index] = s
break
else:
states[tail_index - 1] = s
s = _normalize_state(states)
assert s.shape == (stacked_frames, 84, 84)
s_next = np.expand_dims(s_next, axis=0)
s_next = _normalize_state(s_next)
if 1 < stacked_frames:
s_next = np.concatenate((s[1:], s_next), axis=0)
assert s.shape == s_next.shape
return (s, a, r, non_terminal, s_next, info)