# Source code for nnabla_rl.replay_buffers.prioritized_replay_buffer

# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import sys
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, ClassVar, Generic, List, Optional, Sequence, Tuple, TypeVar, Union, cast

import numpy as np

import nnabla_rl as rl
from nnabla_rl.replay_buffer import ReplayBuffer
from nnabla_rl.typing import Experience
from nnabla_rl.utils.data import DataHolder, RingBuffer

# Generic value type stored in Node / BinaryTree (the name passed to TypeVar must match the variable name).
T = TypeVar("T")


# NOTE: index naming convention used in this module
# relative index: 0 is the oldest item's index and len(buffer) - 1 is the newest item's index.
# absolute index: the actual data index in the underlying list. 0: list's head. capacity - 1: list's tail.
# tree index: 0 is the root of the tree and 2 * capacity - 2 is the rightmost leaf (the tree has 2 * capacity - 1 nodes).
# heap index: 0 is the head of the heap (for a max heap, the maximum value is stored there). capacity - 1: tail of the heap.


@dataclass
class Node(Generic[T]):
    """A single node of the array-backed BinaryTree.

    ``parent``, ``left`` and ``right`` are tree indices into the owning tree's node list.
    BinaryTree sets ``left``/``right`` to -1 for leaves and ``parent`` to -1 for the root.
    """

    value: T
    # NOTE: these defaults describe the root of a tree with capacity > 1;
    # BinaryTree._make_init_node always passes every field explicitly.
    parent: int = -1
    left: int = 1
    right: int = 2


class BinaryTree(Generic[T]):
    """Array-backed complete binary tree shared by SumTree and MinTree.

    The tree stores ``2 * capacity - 1`` nodes in a flat list. Tree indices
    ``[0, capacity - 1)`` are internal nodes and ``[capacity - 1, 2 * capacity - 1)`` are
    the leaves; the leaf for absolute index ``i`` lives at tree index ``i + capacity - 1``.
    Subclasses define how a parent's value is derived from its children in ``_update_parent``.

    Args:
        capacity (int): the maximum number of saved data.
        init_node_value (T): the initial value of node.
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _tree: List[Node[T]]
    _tail_index: int

    def __init__(self, capacity: int, init_node_value: T):
        # _make_init_node reads these two attributes, so they must be set before building the tree.
        self._capacity = capacity
        self._init_node_value = init_node_value
        self._tail_index = 0
        self._length = 0
        node_count = 2 * capacity - 1
        self._tree = list(map(self._make_init_node, range(node_count)))

    def __len__(self):
        return self._length

    def __getitem__(self, tree_index: int):
        node = self._tree[tree_index]
        return node.value

    def append(self, value: T):
        """Write ``value`` into the tail slot; the tail wraps to slot 0 after ``capacity`` appends."""
        slot = self._tail_index
        self.update(slot, value)
        self._tail_index = (slot + 1) % self._capacity
        # The length saturates at capacity: once full, appends overwrite existing slots.
        self._length = min(self._length + 1, self._capacity)

    def update(self, absolute_index: int, value: T):
        """Overwrite the leaf for ``absolute_index`` and refresh its ancestors."""
        leaf = self.absolute_to_tree_index(absolute_index)
        self._tree[leaf].value = value
        self._update_parent(leaf)

    @abstractmethod
    def _update_parent(self, tree_index: int):
        raise NotImplementedError

    def tree_to_absolute_index(self, tree_index: int):
        return tree_index - self._capacity + 1

    def absolute_to_tree_index(self, absolute_index: int):
        return absolute_index + (self._capacity - 1)

    def _make_init_node(self, index: int):
        """Build the node stored at tree index ``index`` with its parent/child links."""
        if index < self._capacity - 1:
            # internal node: children follow the standard heap layout
            left_child, right_child = 2 * index + 1, 2 * index + 2
        else:
            # leaf node: no children
            left_child = right_child = -1
        return Node(value=self._init_node_value, parent=(index - 1) // 2, left=left_child, right=right_child)


class MinTree(BinaryTree[float]):
    """Binary tree whose internal nodes hold the minimum of their two children.

    Leaves start at ``math.inf``, so slots that were never written do not affect the minimum.

    Args:
        capacity (int): the maximum number of saved data.
    """

    def __init__(self, capacity: int):
        super().__init__(capacity, init_node_value=math.inf)

    def min(self):
        """Return the minimum over all leaves (stored at the root)."""
        root = self._tree[0]
        return root.value

    def _update_parent(self, tree_index: int):
        # Walk from the modified node up to the root, recomputing each ancestor.
        child = tree_index
        while child > 0:
            parent = self._tree[child].parent
            parent_node = self._tree[parent]
            parent_node.value = min(self._tree[parent_node.left].value, self._tree[parent_node.right].value)
            child = parent


class SumTree(BinaryTree[float]):
    """Binary tree whose internal nodes hold the sum of their two children.

    The root stores the total of all leaf values. Descending the tree with a query in
    ``[0, sum()]`` selects a leaf with probability proportional to its value when the
    query is drawn uniformly.

    Args:
        capacity (int): the maximum number of saved data.
    """

    def __init__(self, capacity: int):
        super(SumTree, self).__init__(capacity, init_node_value=0.0)

    def get_absolute_index_from_query(self, query: float):
        """Sample absolute index from query value.

        At every internal node, descend left when ``query`` is smaller than the left subtree's
        sum. Otherwise subtract that sum and descend right.

        Args:
            query (float): value in ``[0, sum()]``.

        Returns:
            int: the absolute index of the selected leaf.

        Raises:
            ValueError: if ``query`` lies outside ``[0, sum()]``.
        """
        total = self.sum()
        if query < 0 or query > total:
            raise ValueError(f"You must use value between [0, {total}] as query")
        # Start at the root. When capacity == 1 the root is itself the only leaf and the loop
        # below never runs, so tree_index must already be defined here.
        tree_index = 0
        node = self._tree[tree_index]
        while node.left >= 0:
            left_value = self._tree[node.left].value
            if query < left_value:
                tree_index = node.left
            else:
                tree_index = node.right
                query -= left_value
            node = self._tree[tree_index]
        return self.tree_to_absolute_index(tree_index)

    def sum(self):
        """Return the sum of all leaf values (stored at the root)."""
        return self._tree[0].value

    def _update_parent(self, tree_index: int):
        # Recompute each ancestor's sum from its two children, up to the root.
        if tree_index > 0:
            parent_index = self._tree[tree_index].parent
            left_index = self._tree[parent_index].left
            left_value = self._tree[left_index].value
            right_index = self._tree[parent_index].right
            right_value = self._tree[right_index].value
            self._tree[parent_index].value = left_value + right_value
            self._update_parent(parent_index)


class MaxHeap(object):
    """Array-based binary max-heap of ``(absolute_index, value)`` pairs with O(1) index lookup.

    Two maps are kept in sync with every move, so an entry can be located either by its heap
    position or by its absolute index (a slot number in ``[0, capacity)``). Until the heap is
    full, appends push a new entry. Once full, an append overwrites the value of the entry
    with the oldest absolute index; slots are reused in round-robin order.

    Args:
        capacity (int): the maximum number of entries.
    """

    def __init__(self, capacity):
        self._capacity = capacity
        self._heap = [None] * capacity
        self._heap_to_absolute_index_map = [None] * capacity
        self._absolute_to_heap_index_map = [None] * capacity

        self._tail_index = 0
        self._oldest_index = 0
        self._length = 0

    def __len__(self):
        return self._length

    def __getitem__(self, heap_index: int):
        return self._heap[heap_index]

    def append(self, value: float):
        """Insert ``value``; when full, replace the value of the oldest slot instead."""
        if len(self) != self._capacity:
            self._heappush(self._tail_index, value)
            # The tail stops at the last slot; after that, the buffer is full and uses the branch below.
            if self._tail_index < self._capacity - 1:
                self._tail_index += 1
            self._length += 1
            return
        # Full: the oldest slot is reused for the new data, so just reprioritize it.
        self.update(self._oldest_index, value)
        self._oldest_index = (self._oldest_index + 1) % self._capacity

    def sort_data(self):
        """Fully sort the heap in decreasing value order (a valid max-heap) and rebuild the maps."""

        def _sort_key(entry):
            # Empty slots sort last.
            return -math.inf if entry is None else entry[1]

        self._heap = sorted(self._heap, key=_sort_key, reverse=True)

        for position, entry in enumerate(self._heap):
            if entry is None:
                self._heap_to_absolute_index_map[position] = None
                continue
            absolute_index = entry[0]
            self._heap_to_absolute_index_map[position] = absolute_index
            self._absolute_to_heap_index_map[absolute_index] = position

    def get_absolute_index_from_heap_index(self, heap_index: int):
        return self.heap_to_absolute_index(heap_index)

    def update(self, absolute_index: int, value: float):
        """Set a new value for the entry stored at ``absolute_index`` and restore heap order."""
        position = self.absolute_to_heap_index(absolute_index)
        stored_absolute_index = self._heap[position][0]
        self._heap[position] = (stored_absolute_index, value)
        self._heapup(position)
        self._heapdown(position)

    def _parent_index(self, child_index):
        return (child_index - 1) // 2

    def _heappush(self, absolute_index, error):
        position = self._tail_index
        self._heap_to_absolute_index_map[position] = absolute_index
        self._absolute_to_heap_index_map[absolute_index] = position
        self._heap[position] = (absolute_index, error)
        self._heapup(position)

    def _heapup(self, heap_index):
        # Swap with the parent while the parent's value is strictly smaller.
        current = heap_index
        while current != 0:
            parent = self._parent_index(current)
            if not (self._heap[parent][1] < self._heap[current][1]):
                break
            self._swap_item(current, parent)
            current = parent

    def _heapdown(self, heap_index):
        # Swap with the larger child while that child's value is strictly larger.
        length = len(self)
        current = heap_index
        while current < length:
            left = current * 2 + 1
            right = current * 2 + 2
            left_entry = self._heap[left] if left < self._capacity else None
            right_entry = self._heap[right] if right < self._capacity else None

            largest = current
            if left_entry is not None and left < length and left_entry[1] > self._heap[current][1]:
                largest = left
            if right_entry is not None and right < length and right_entry[1] > self._heap[largest][1]:
                largest = right
            if largest == current:
                break
            self._swap_item(current, largest)
            current = largest

    def _swap_item(self, heap_index1, heap_index2):
        first, second = self._heap[heap_index1], self._heap[heap_index2]
        self._heap[heap_index1] = second
        self._heap[heap_index2] = first
        # Keep both index maps consistent with the new positions.
        for position, entry in ((heap_index1, second), (heap_index2, first)):
            self._heap_to_absolute_index_map[position] = entry[0]
            self._absolute_to_heap_index_map[entry[0]] = position

    def absolute_to_heap_index(self, absolute_index):
        return self._absolute_to_heap_index_map[absolute_index]

    def heap_to_absolute_index(self, heap_index):
        return self._heap_to_absolute_index_map[heap_index]


class PrioritizedDataHolder(DataHolder[Any]):
    """Ring-buffer backed data holder that additionally tracks a priority per stored item.

    Items are addressed with relative indices (0 = oldest). Subclasses translate them to
    absolute indices (slot positions in ``[0, capacity)``) to look up their priorities.

    Args:
        capacity (int): the maximum number of stored items.
    """

    def __init__(self, capacity: int):
        self._capacity = capacity
        self._data = RingBuffer(maxlen=capacity)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, relative_index: int):
        return self._data[relative_index]

    def append(self, data):
        """Append ``data``, discarding the value returned by ``append_with_removed_item_check``."""
        # ignore returned value
        self.append_with_removed_item_check(data)

    def append_with_removed_item_check(self, data):
        """Append ``data`` and return whatever the underlying storage reports (subclass hook)."""
        raise NotImplementedError

    def update_priority(self, relative_index: int, priority: float):
        """Set the priority of the item at ``relative_index`` (subclass hook)."""
        raise NotImplementedError

    def get_priority(self, relative_index: int):
        """Return the priority of the item at ``relative_index`` (subclass hook)."""
        raise NotImplementedError

    def _relative_to_absolute_index(self, relative_index: int):
        # This mapping treats RingBuffer's private ``_head`` as the absolute slot of relative index 0.
        return (relative_index + self._data._head) % self._capacity

    def _absolute_to_relative_index(self, absolute_index: int):
        return (absolute_index - self._data._head) % self._capacity


class SumTreeDataHolder(PrioritizedDataHolder):
    """Prioritized data holder backed by a SumTree and, optionally, a MinTree.

    Newly appended items get the largest priority seen so far. That starts at
    ``initial_max_priority`` and only grows as larger priorities are passed to ``update_priority``.

    Args:
        capacity (int): the maximum number of stored items.
        initial_max_priority (float): priority given to new items before any larger priority is set.
        keep_min (bool): whether to maintain a MinTree so that ``min_priority()`` is available.
    """

    def __init__(self, capacity, initial_max_priority, keep_min=True):
        super().__init__(capacity=capacity)
        self._sum_tree = SumTree(capacity=capacity)
        self._keep_min = keep_min
        if self._keep_min:
            self._min_tree = MinTree(capacity=capacity)
        self._max_priority = initial_max_priority

    def append_with_removed_item_check(self, data):
        removed = self._data.append_with_removed_item_check(data)
        # Both trees use the same tail-wrapping slot order as the data ring buffer.
        self._sum_tree.append(self._max_priority)
        if self._keep_min:
            self._min_tree.append(self._max_priority)
        return removed

    def get_priority(self, relative_index: int):
        absolute_index = self._relative_to_absolute_index(relative_index)
        tree_index = self._sum_tree.absolute_to_tree_index(absolute_index)
        return self._sum_tree[tree_index]

    def sum_priority(self):
        """Return the sum of all stored priorities."""
        return self._sum_tree.sum()

    def min_priority(self):
        """Return the minimum stored priority.

        Raises:
            RuntimeError: if the holder was created with ``keep_min=False`` (no MinTree exists).
        """
        if not self._keep_min:
            raise RuntimeError("min_priority() is unavailable because this data holder was created with keep_min=False")
        return self._min_tree.min()

    def update_priority(self, relative_index: int, priority: float):
        absolute_index = self._relative_to_absolute_index(relative_index)
        self._sum_tree.update(absolute_index, priority)
        if self._keep_min:
            self._min_tree.update(absolute_index, priority)
        self._max_priority = max(self._max_priority, priority)

    def get_index_from_query(self, query: float):
        """Map a query in ``[0, sum_priority()]`` to a relative index.

        The result may be ``>= len(self)`` when the query lands on a slot that has not been
        filled yet; callers are expected to reject such indices.
        """
        absolute_index = self._sum_tree.get_absolute_index_from_query(query)
        return self._absolute_to_relative_index(absolute_index)


class MaxHeapDataHolder(PrioritizedDataHolder):
    """Prioritized data holder that orders items with a MaxHeap and derives rank-based priorities.

    An item's priority is ``(1 / rank) ** alpha``, where ``rank`` is its heap position plus one.
    The heap is only partially ordered between ``sort_data`` calls, so this rank is approximate.

    Args:
        capacity (int): the maximum number of stored items.
        alpha (float): prioritization exponent.
    """

    def __init__(self, capacity: int, alpha: float):
        super().__init__(capacity=capacity)
        self._max_heap = MaxHeap(capacity)
        self._alpha = alpha

    def append_with_removed_item_check(self, data):
        removed = self._data.append_with_removed_item_check(data)
        # New items enter with an infinite error so they move to the top of the heap.
        self._max_heap.append(math.inf)
        return removed

    def get_priority(self, relative_index: int):
        heap_position = self._max_heap.absolute_to_heap_index(self._relative_to_absolute_index(relative_index))
        return self._compute_priority(heap_position + 1)

    def get_relative_index_from_heap_index(self, heap_index: int):
        return self._absolute_to_relative_index(self._max_heap.get_absolute_index_from_heap_index(heap_index))

    def update_priority(self, relative_index: int, priority: float):
        self._max_heap.update(self._relative_to_absolute_index(relative_index), priority)

    def sort_data(self):
        """Fully sort the underlying heap so heap positions match exact ranks."""
        self._max_heap.sort_data()

    def _compute_priority(self, rank: int):
        # Left unnormalized to save computation; the normalization constant cancels out
        # when the importance weights are divided by their maximum.
        return (1 / rank) ** self._alpha


class _PrioritizedReplayBuffer(ReplayBuffer):
    """Shared implementation of the proportional and rank-based prioritized replay buffers.

    Subclasses set ``self._buffer`` (a PrioritizedDataHolder) and implement ``sample``,
    ``update_priorities`` and ``_get_weights``. This base class validates the arguments, attaches
    importance-sampling weights to every sample and anneals ``beta`` linearly towards 1.0.

    Args:
        capacity (int): the maximum number of stored experiences. Must be positive.
        alpha (float): prioritization exponent used by subclasses.
        beta (float): initial importance-sampling exponent.
        betasteps (int): number of ``sample_indices`` calls over which ``beta`` is annealed
            from its initial value to 1.0. Must be positive.
        error_clip (Optional[Tuple[float, float]]): (min, max) range that errors are clipped to
            before taking their absolute value, or None to disable clipping.

    Raises:
        ValueError: if ``capacity`` or ``betasteps`` is not positive.
    """

    def __init__(
        self, capacity: int, alpha: float, beta: float, betasteps: int, error_clip: Optional[Tuple[float, float]]
    ):
        # Do not call super class' constructor
        self._capacity_check(capacity)
        self._capacity = capacity

        # A zero step count would divide by zero below, and a negative one would anneal beta away from 1.0.
        if betasteps <= 0:
            raise ValueError(f"betasteps must be greater than 0, but got {betasteps}")

        self._alpha = alpha
        self._beta = beta
        self._beta_diff = (1.0 - beta) / betasteps

        self._error_clip = error_clip

        # Relative indices (0: oldest) of the experiences returned by the last sample_indices call.
        # Subclasses reset it to None in update_priorities; while it is set, sampling is refused.
        self._last_sampled_indices: Union[Sequence[int], None] = None

    def __getitem__(self, relative_index: int):
        # NOTE: relative index 0 means the oldest entry and len(self) - 1 the latest entry
        return self._buffer[relative_index]

    def __len__(self):
        return len(self._buffer)

    def sample(self, num_samples: int = 1, num_steps: int = 1):
        raise NotImplementedError

    def sample_indices(self, indices: Sequence[int], num_steps: int = 1):
        """Return the experiences at ``indices`` plus their importance-sampling weights.

        With ``num_steps > 1`` a tuple of ``num_steps`` lists is returned. The i-th list holds the
        experiences at ``index + i``.

        Raises:
            ValueError: if ``indices`` is empty.
            RuntimeError: if priorities of the previous sample were not updated yet.
        """
        if len(indices) == 0:
            raise ValueError("Indices are empty")
        if self._last_sampled_indices is not None:
            raise RuntimeError(
                "Trying to sample data from buffer without updating priority. "
                "Check that the algorithm supports prioritized replay buffer."
            )
        experiences: Union[Sequence[Experience], Tuple[Sequence[Experience], ...]]
        if num_steps == 1:
            experiences = [self.__getitem__(index) for index in indices]
        else:
            experiences = tuple([self.__getitem__(index + i) for index in indices] for i in range(num_steps))

        weights = self._get_weights(indices, self._alpha, self._beta)
        info = dict(weights=weights)

        self._beta = min(self._beta + self._beta_diff, 1.0)
        self._last_sampled_indices = indices
        return experiences, info

    def update_priorities(self, errors: np.ndarray):
        raise NotImplementedError

    def _preprocess_errors(self, errors: np.ndarray):
        # Clip first (if configured), then take magnitudes.
        if self._error_clip is not None:
            errors = np.clip(errors, self._error_clip[0], self._error_clip[1])
        return np.abs(errors)

    def _get_weights(self, indices: Sequence[int], alpha: float, beta: float):
        raise NotImplementedError

    def _capacity_check(self, capacity: int):
        if capacity is None or capacity <= 0:
            error_msg = "buffer size must be greater than 0"
            raise ValueError(error_msg)


class ProportionalPrioritizedReplayBuffer(_PrioritizedReplayBuffer):
    """Prioritized replay buffer that samples experiences proportionally to their priority.

    After ``update_priorities``, an experience's priority is
    ``(|clip(error)| + epsilon) ** alpha``. New experiences receive the largest priority seen so
    far, starting at ``init_max_error``. Importance-sampling weights are divided by the
    largest possible weight. That weight is derived from the minimum priority in the sampled batch
    (``"batch_max"``) or in the whole buffer (``"buffer_max"``).

    Args:
        capacity (int): the maximum number of stored experiences.
        alpha (float): prioritization exponent.
        beta (float): initial importance-sampling exponent.
        betasteps (int): number of samplings over which beta is annealed to 1.0.
        error_clip (Optional[Tuple[float, float]]): clipping range applied to errors.
        epsilon (float): small constant added to errors so that no priority is zero.
        init_max_error (float): initial maximum priority given to new experiences.
        normalization_method (str): ``"batch_max"`` or ``"buffer_max"``.
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _buffer: SumTreeDataHolder
    _epsilon: float
    # Draws attempted inside one stratified segment before falling back to the whole priority range.
    _MAX_SEGMENT_RETRIES: ClassVar[int] = 100

    def __init__(
        self,
        capacity: int,
        alpha: float = 0.6,
        beta: float = 0.4,
        betasteps: int = 10000,
        error_clip: Optional[Tuple[float, float]] = (-1, 1),
        epsilon: float = 1e-8,
        init_max_error: float = 1.0,
        normalization_method: str = "buffer_max",
    ):
        super(ProportionalPrioritizedReplayBuffer, self).__init__(capacity, alpha, beta, betasteps, error_clip)
        assert normalization_method in ("batch_max", "buffer_max")
        self._normalization_method = normalization_method
        # The MinTree is only needed to normalize by the buffer-wide minimum priority.
        keep_min = self._normalization_method == "buffer_max"
        self._buffer = SumTreeDataHolder(capacity=capacity, initial_max_priority=init_max_error, keep_min=keep_min)
        self._epsilon = epsilon

    def append(self, experience):
        self._buffer.append(experience)

    def sample(self, num_samples: int = 1, num_steps: int = 1):
        """Sample ``num_samples`` experiences (or ``num_steps``-long sequences) by priority.

        Raises:
            ValueError: if ``num_samples`` exceeds the buffer length.
            RuntimeError: if the buffer holds fewer than ``num_steps`` experiences.
        """
        buffer_length = len(self)
        if num_samples > buffer_length:
            error_msg = f"num_samples: {num_samples} is greater than the size of buffer: {buffer_length}"
            raise ValueError(error_msg)
        if buffer_length - num_steps < 0:
            raise RuntimeError(f"Insufficient buffer length. buffer: {buffer_length} < steps: {num_steps}")

        # In paper,
        # "To sample a minibatch of size k, the range [0, ptotal] is divided equally into k ranges.
        # Next, a value is uniformly sampled from each range"
        # Only indices below this limit have num_steps consecutive experiences available.
        index_limit = buffer_length - num_steps + 1
        total_priority = self._buffer.sum_priority()
        interval = total_priority / num_samples
        indices = []
        for i in range(num_samples):
            # Clamp the upper bound so float rounding cannot push a query above the tree's sum.
            high = min(interval * (i + 1), total_priority)
            indices.append(self._sample_valid_index(interval * i, high, total_priority, index_limit))
        return self.sample_indices(indices, num_steps)

    def _sample_valid_index(self, low: float, high: float, total_priority: float, index_limit: int) -> int:
        """Draw a relative index below ``index_limit`` using query values between ``low`` and ``high``.

        If ``_MAX_SEGMENT_RETRIES`` draws inside the segment all land on invalid indices, queries are
        drawn from the whole priority range instead. Invalid indices are unfilled slots, or the newest
        entries that cannot start a ``num_steps`` window. Without the fallback, a segment whose entire
        mass lies on such entries would loop forever.
        """
        for _ in range(self._MAX_SEGMENT_RETRIES):
            index = self._buffer.get_index_from_query(rl.random.drng.uniform(low, high))
            if index < index_limit:
                return index
        while True:
            index = self._buffer.get_index_from_query(rl.random.drng.uniform(0.0, total_priority))
            if index < index_limit:
                return index

    def update_priorities(self, errors: np.ndarray):
        """Set the priorities of the experiences returned by the last sample."""
        errors = self._preprocess_errors(errors)
        errors = ((errors + self._epsilon) ** self._alpha).flatten()
        indices = cast(Sequence[int], self._last_sampled_indices)
        for index, error in zip(indices, errors):
            self._buffer.update_priority(index, error)
        self._last_sampled_indices = None

    def _get_weights(self, indices: Sequence[int], alpha: float, beta: float):
        priorities = np.asarray([self._buffer.get_priority(i) for i in indices])[:, np.newaxis]
        if self._normalization_method == "batch_max":
            # Use min priority. This is same as max of weight.
            min_priority = priorities.min()
        elif self._normalization_method == "buffer_max":
            # Use min priority. This is same as max of weight.
            min_priority = self._buffer.min_priority()
        else:
            raise RuntimeError(f"Unknown normalization method {self._normalization_method}")
        return (priorities / min_priority) ** (-beta)


class RankBasedPrioritizedReplayBuffer(_PrioritizedReplayBuffer):
    """Prioritized replay buffer with rank-based priorities ``(1 / rank) ** alpha``.

    Ranks are approximated by positions in a max heap of errors, which is fully sorted every
    ``sort_interval`` appends. Sampling is stratified into ``num_samples`` segments of roughly
    equal probability mass. Segment boundaries are recomputed when the batch configuration
    changes, whenever the buffer length reaches a multiple of ``reset_segment_interval``, and once
    when the buffer becomes full.

    Args:
        capacity (int): the maximum number of stored experiences.
        alpha (float): prioritization exponent.
        beta (float): initial importance-sampling exponent.
        betasteps (int): number of samplings over which beta is annealed to 1.0.
        error_clip (Optional[Tuple[float, float]]): clipping range applied to errors.
        reset_segment_interval (int): buffer-length granularity at which boundaries are recomputed.
        sort_interval (int): number of appends between full sorts of the heap.
    """

    # type declarations to type check with mypy
    # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar
    # See https://mypy.readthedocs.io/en/stable/class_basics.html for details
    _buffer: MaxHeapDataHolder
    _reset_segment_interval: int
    _sort_interval: int
    _boundaries: List[int]
    _prev_num_samples: int
    _prev_num_steps: int
    _prev_buffer_length: int
    _appends_since_prev_sort: int

    def __init__(
        self,
        capacity: int,
        alpha: float = 0.7,
        beta: float = 0.5,
        betasteps: int = 10000,
        error_clip: Optional[Tuple[float, float]] = (-1, 1),
        reset_segment_interval: int = 1000,
        sort_interval: int = 1000000,
    ):
        super(RankBasedPrioritizedReplayBuffer, self).__init__(capacity, alpha, beta, betasteps, error_clip)
        self._buffer = MaxHeapDataHolder(capacity, alpha)

        self._reset_segment_interval = reset_segment_interval
        self._sort_interval = sort_interval

        self._boundaries = []
        self._prev_num_samples = 0
        self._prev_num_steps = 0
        # Buffer length the current boundaries were computed for.
        self._prev_buffer_length = 0
        self._appends_since_prev_sort = 0
        # Cumulative (unnormalized) rank priorities; entry i covers ranks 1..i+1.
        self._ps_cumsum = np.cumsum(np.asarray([(1 / (i + 1)) ** alpha for i in range(capacity)]))

    def append(self, experience):
        self._buffer.append(experience)

        self._appends_since_prev_sort += 1
        if self._appends_since_prev_sort % self._sort_interval == 0:
            self._buffer.sort_data()
            self._appends_since_prev_sort = 0

    def sample(self, num_samples: int = 1, num_steps: int = 1):
        """Sample ``num_samples`` experiences (or ``num_steps``-long sequences) by rank.

        Raises:
            ValueError: if ``num_samples`` exceeds the buffer length.
            RuntimeError: if the buffer holds fewer than ``num_steps`` experiences.
        """
        buffer_length = len(self)
        if num_samples > buffer_length:
            error_msg = f"num_samples: {num_samples} is greater than the size of buffer: {buffer_length}"
            raise ValueError(error_msg)
        if buffer_length - num_steps < 0:
            raise RuntimeError(f"Insufficient buffer length. buffer: {buffer_length} < steps: {num_steps}")
        if self._needs_boundary_reset(buffer_length, num_samples, num_steps):
            self._boundaries = self._compute_segment_boundaries(N=buffer_length, k=num_samples)
            self._prev_num_samples = num_samples
            self._prev_num_steps = num_steps
            self._prev_buffer_length = buffer_length

        indices = []
        prev_boundary = 0
        for boundary in self._boundaries:
            heap_index = rl.random.drng.integers(low=prev_boundary, high=boundary)
            index = self._buffer.get_relative_index_from_heap_index(heap_index)
            prev_boundary = boundary
            if index < buffer_length - num_steps + 1:
                indices.append(index)
        while len(indices) < num_samples:
            # Enters here only when 1 < num_steps and (one or more than one) sampled indices exceeded buffer length
            boundary_index = rl.random.drng.choice(len(self._boundaries))
            if boundary_index != 0:
                boundary_low = self._boundaries[boundary_index - 1]
            else:
                boundary_low = 0
            boundary_high = self._boundaries[boundary_index]
            heap_index = rl.random.drng.integers(low=boundary_low, high=boundary_high)
            index = self._buffer.get_relative_index_from_heap_index(heap_index)
            if index < buffer_length - num_steps + 1:
                indices.append(index)
        return self.sample_indices(indices, num_steps)

    def _needs_boundary_reset(self, buffer_length: int, num_samples: int, num_steps: int) -> bool:
        """Return True when the cached segment boundaries must be recomputed."""
        if num_samples != self._prev_num_samples or num_steps != self._prev_num_steps:
            return True
        if len(self._boundaries) == 0:
            return True
        if buffer_length == self._prev_buffer_length:
            # Boundaries depend only on (N, k); they are already up to date.
            return False
        # Refresh at every multiple of the interval and once more when the buffer becomes full,
        # so a full buffer never keeps boundaries computed for a smaller N.
        return buffer_length % self._reset_segment_interval == 0 or buffer_length == self._capacity

    def update_priorities(self, errors: np.ndarray):
        """Set the errors of the experiences returned by the last sample."""
        # Flatten so that e.g. (batch, 1)-shaped errors are stored as scalars, matching
        # ProportionalPrioritizedReplayBuffer.update_priorities.
        errors = self._preprocess_errors(errors).flatten()
        indices = cast(Sequence[int], self._last_sampled_indices)
        for index, error in zip(indices, errors):
            self._buffer.update_priority(index, error)
        self._last_sampled_indices = None

    def _compute_segment_boundaries(self, N: int, k: int):
        """Split heap positions [0, N) into k segments of roughly equal rank-priority mass.

        Returns the exclusive upper bound of each segment; at most one boundary is placed per position.
        """
        if N < k:
            raise ValueError(f"Batch size {k} is greater than buffer size {N}")
        boundaries: List[int] = []
        denominator = self._ps_cumsum[N - 1]
        for i in range(N):
            if (len(boundaries) + 1) / k <= self._ps_cumsum[i] / denominator:
                boundaries.append(i + 1)
        assert len(boundaries) == k
        return boundaries

    def _get_weights(self, indices: Sequence[int], alpha: float, beta: float):
        priorities = np.asarray([self._buffer.get_priority(i) for i in indices])[:, np.newaxis]
        # The worst possible rank has the smallest priority and therefore the largest weight.
        worst_rank = len(self._buffer)
        min_priority = (1 / worst_rank) ** alpha
        return (priorities / min_priority) ** (-beta)


class PrioritizedReplayBuffer(ReplayBuffer):
    """Prioritized replay buffer that delegates to a proportional or rank-based implementation.

    Args:
        capacity (int): the maximum number of stored experiences.
        alpha (float): prioritization exponent.
        beta (float): initial importance-sampling exponent.
        betasteps (int): number of samplings over which beta is annealed to 1.0.
        error_clip (Optional[Tuple[float, float]]): clipping range applied to errors.
        epsilon (float): constant added to errors ("proportional" only).
        reset_segment_interval (int): segment-boundary refresh interval ("rank_based" only).
        sort_interval (int): number of appends between full heap sorts ("rank_based" only).
        variant (str): ``"proportional"`` or ``"rank_based"``.
        init_max_error (float): initial maximum priority given to new experiences ("proportional" only).
        normalization_method (str): ``"batch_max"`` or ``"buffer_max"`` ("proportional" only).

    Raises:
        ValueError: if ``variant`` is unknown.
    """

    _variants: ClassVar[Sequence[str]] = ["proportional", "rank_based"]
    _buffer_impl: _PrioritizedReplayBuffer

    def __init__(
        self,
        capacity: int,
        alpha: float = 0.6,
        beta: float = 0.4,
        betasteps: int = 10000,
        error_clip: Optional[Tuple[float, float]] = (-1, 1),
        epsilon: float = 1e-8,
        reset_segment_interval: int = 1000,
        sort_interval: int = 1000000,
        variant: str = "proportional",
        init_max_error: float = 1.0,
        normalization_method: str = "buffer_max",
    ):
        if variant not in PrioritizedReplayBuffer._variants:
            raise ValueError(f"Unknown prioritized replay buffer variant: {variant}")
        if variant == "proportional":
            self._buffer_impl = ProportionalPrioritizedReplayBuffer(
                capacity=capacity,
                alpha=alpha,
                beta=beta,
                betasteps=betasteps,
                error_clip=error_clip,
                epsilon=epsilon,
                init_max_error=init_max_error,
                normalization_method=normalization_method,
            )
        elif variant == "rank_based":
            self._buffer_impl = RankBasedPrioritizedReplayBuffer(
                capacity=capacity,
                alpha=alpha,
                beta=beta,
                betasteps=betasteps,
                error_clip=error_clip,
                reset_segment_interval=reset_segment_interval,
                sort_interval=sort_interval,
            )
        else:
            raise NotImplementedError

    @property
    def capacity(self):
        return self._buffer_impl.capacity

    def append(self, experience):
        self._buffer_impl.append(experience)

    def append_all(self, experiences):
        self._buffer_impl.append_all(experiences)

    def sample(self, num_samples: int = 1, num_steps: int = 1):
        return self._buffer_impl.sample(num_samples, num_steps)

    def sample_indices(self, indices: Sequence[int], num_steps: int = 1):
        return self._buffer_impl.sample_indices(indices, num_steps)

    def update_priorities(self, errors: np.ndarray):
        self._buffer_impl.update_priorities(errors)

    def __len__(self):
        return len(self._buffer_impl)

    def __getitem__(self, item: int) -> Experience:
        return cast(Experience, self._buffer_impl[item])