# Source code for nnabla_rl.replay_buffers.prioritized_replay_buffer

# Copyright 2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass

import numpy as np

from nnabla_rl.replay_buffer import ReplayBuffer


@dataclass()
class Node:
    """One node of ``SumTree``'s array-backed binary tree.

    Links are indices into the tree's node list, with -1 meaning "no such
    node". The default values describe the root of a tree: no parent and
    children at indices 1 and 2.
    """

    # Index of the parent node (-1 for the root).
    parent: int = -1
    # Index of the left child (-1 for a leaf).
    left: int = 1
    # Index of the right child (-1 for a leaf).
    right: int = 2
    # Leaf: the item's priority. Internal node: sum of its children's values.
    value: float = 0.0


class SumTree(object):
    """Binary sum tree storing data items together with their priorities.

    The tree has ``2 * capacity - 1`` nodes in heap layout: node ``i`` has
    children ``2i + 1`` and ``2i + 2``. The ``capacity`` leaves (tree indices
    ``capacity - 1`` .. ``2 * capacity - 2``) hold the priority of each data
    slot. Every internal node holds the sum of its children, so sampling in
    proportion to priority is a single root-to-leaf walk.

    Args:
        capacity (int): maximum number of stored items. When full, new items
            overwrite slots in ring-buffer order.
        init_max_p (float): priority given to the first appended item. Every
            later append uses the largest priority set so far.
    """

    def __init__(self, capacity, init_max_p=1.0):
        self._capacity = capacity

        self._data = np.zeros(capacity, dtype=object)
        self._tree = [self._make_init_node(i) for i in range(2*capacity-1)]
        # Next data slot to write; wraps around once the buffer is full.
        self._index = 0
        self._data_num = 0
        # Running extrema of every priority ever set. They are never recomputed
        # when a stored priority is later changed.
        self._min_p = math.inf
        self._max_p = init_max_p
        # Data indices returned by the last sample, awaiting a priority update.
        self._latest_indices = None

    def _make_init_node(self, index):
        """Create the zero-valued node at tree ``index`` with heap-style links.

        Leaves (``index >= capacity - 1``) get ``left = right = -1``. The root
        gets ``parent = -1``.
        """
        parent = (index - 1) // 2
        is_internal = index < self._capacity - 1
        left = 2 * index + 1 if is_internal else -1
        right = left + 1 if is_internal else -1
        return Node(parent, left, right, 0.)

    def append(self, data):
        """Store ``data`` in the next ring-buffer slot with the max priority so far."""
        self._data[self._index] = data
        self.update(self._index, self._max_p)

        self._index = (self._index + 1) % self._capacity
        if self._data_num < self._capacity:
            self._data_num += 1

    def update(self, index, p):
        """Set the priority of data slot ``index`` to ``p`` and refresh all ancestor sums."""
        tree_index = index + self._capacity - 1
        change_p = p - self._tree[tree_index].value
        self._tree[tree_index].value = float(p)
        self._update_parent(tree_index, change_p)

        self._min_p = min(self._min_p, p)
        self._max_p = max(self._max_p, p)

    def _update_parent(self, index, change_p):
        """Add ``change_p`` to every ancestor of tree node ``index`` (up to the root)."""
        while index > 0:
            index = self._tree[index].parent
            self._tree[index].value += change_p

    def sample(self, num_samples=1, beta=0.6):
        """Draw ``num_samples`` items with probability proportional to their priority.

        Sampling is with replacement: each item is located by a uniform query
        over the total priority mass.

        Returns:
            tuple: ``(data, weights)`` where ``data`` is a list of items and
            ``weights`` is an array of shape ``(num_samples, 1)`` holding the
            importance-sampling weights.
        """
        random_values = np.random.uniform(0.0, self.total, size=num_samples)
        indices = [self._get_data_index_from_query(v) for v in random_values]
        return self.sample_indices(indices, beta)

    def sample_indices(self, indices, beta=0.6):
        """Return the data at ``indices`` and their importance-sampling weights.

        The indices are remembered so that ``update_latest_priorities`` can
        assign new priorities to them.

        Raises:
            RuntimeError: if the previously sampled indices have not had their
                priorities updated yet.
        """
        if self._latest_indices is not None:
            raise RuntimeError('Trying to sample data from buffer without updating priority. '
                               'Check that the algorithm supports prioritized replay buffer.')
        data = [self._data[i] for i in indices]
        priorities = np.array([self._get_priority(i)
                               for i in indices])[:, np.newaxis]
        weights = self._weights_from_priorities(priorities, beta)
        self._latest_indices = indices
        return data, weights

    def _get_data_index_from_query(self, query):
        """Return the data slot whose cumulative-priority interval contains ``query``.

        The walk goes left while ``query`` is below the left subtree's sum.
        Otherwise it goes right and subtracts that sum from ``query``.
        """
        # Start at the root. This is also the answer when the tree is a single
        # leaf (capacity == 1) and the loop below never runs.
        index = 0
        node = self._tree[index]
        while node.left >= 0:
            left_value = self._tree[node.left].value
            right_value = self._tree[node.right].value
            # Never descend into an all-zero (empty) right subtree. Drift between
            # the incrementally maintained sums, or a query equal to the total,
            # could otherwise land on a slot that holds no data.
            if query < left_value or right_value <= 0.0:
                index = node.left
            else:
                index = node.right
                query -= left_value
            node = self._tree[index]
        return index - (self._capacity - 1)

    def _get_priority(self, index):
        """Return the stored priority of data slot ``index``."""
        tree_index = index + self._capacity - 1
        return self._tree[tree_index].value

    def _weights_from_priorities(self, priorities, beta):
        """Compute weights ``(p / min_p) ** -beta``.

        ``min_p`` is the smallest priority ever set. It is not necessarily the
        smallest priority currently stored.
        """
        weights = (priorities / self._min_p) ** (-beta)
        return weights

    def update_latest_priorities(self, priorities):
        """Assign ``priorities`` (in sample order) to the most recently sampled slots."""
        for index, priority in zip(self._latest_indices, priorities):
            self.update(index, priority)
        self._latest_indices = None

    def __len__(self):
        return self._data_num

    def __getitem__(self, index):
        return self._data[index]

    @property
    def total(self):
        """Sum of all stored priorities (the root node's value)."""
        return self._tree[0].value


class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer that samples experiences in proportion to their priority.

    Priorities are kept in a ``SumTree``. New experiences receive the largest
    priority seen so far. After every ``sample`` / ``sample_indices`` call,
    ``update_priorities`` must be called with the errors of the sampled
    experiences before sampling again; otherwise the tree raises
    ``RuntimeError``.

    Args:
        capacity (int): maximum number of experiences; must be greater than 0.
        alpha (float): exponent turning ``|error| + epsilon`` into a priority.
        beta (float): initial importance-sampling exponent. It grows linearly
            by ``(1.0 - beta) / betasteps`` per sampling call, capped at 1.0.
        betasteps (int): number of sampling calls for beta to reach 1.0.
        epsilon (float): constant added to errors so priorities stay positive.
    """

    def __init__(self, capacity, alpha=0.6, beta=0.4, betasteps=10000, epsilon=1e-8):
        # No need to call the super class constructor; storage is the SumTree below.
        self._capacity_check(capacity)
        self._capacity = capacity
        self._buffer = SumTree(capacity)
        self._alpha = alpha
        self._beta = beta
        self._beta_diff = (1.0 - beta) / betasteps
        self._epsilon = epsilon

    def _capacity_check(self, capacity):
        """Raise ``ValueError`` unless ``capacity`` is a positive number."""
        if capacity is None or capacity <= 0:
            error_msg = 'buffer size must be greater than 0'
            raise ValueError(error_msg)

    def append(self, experience):
        """Add ``experience`` with the maximum priority seen so far."""
        self._buffer.append(experience)

    def sample(self, num_samples=1):
        """Sample ``num_samples`` experiences proportionally to their priority.

        Returns:
            tuple: ``(experiences, info)`` where ``info['weights']`` holds the
            importance-sampling weights.

        Raises:
            ValueError: if ``num_samples`` exceeds the number of stored experiences.
        """
        buffer_length = len(self)
        if num_samples > buffer_length:
            error_msg = 'num_samples: {} is greater than the size of buffer: {}'.format(
                num_samples, buffer_length)
            raise ValueError(error_msg)
        experiences, weights = self._buffer.sample(num_samples, self._beta)
        info = dict(weights=weights)
        self._anneal_beta()
        return experiences, info

    def sample_indices(self, indices):
        """Return the experiences at ``indices`` with their importance-sampling weights.

        Raises:
            ValueError: if ``indices`` is empty.
        """
        if len(indices) == 0:
            raise ValueError('Indices are empty')
        experiences, weights = self._buffer.sample_indices(indices, self._beta)
        info = dict(weights=weights)
        self._anneal_beta()
        return experiences, info

    def _anneal_beta(self):
        """Move beta one step toward 1.0 (called once per sampling call)."""
        self._beta = min(self._beta + self._beta_diff, 1.0)

    def update_priorities(self, errors):
        """Set the priorities of the most recently sampled experiences.

        Args:
            errors: per-sample errors in the order of the last sample. Absolute
                values are used, so signed errors do not produce NaN priorities
                when ``alpha`` is fractional.
        """
        priorities = ((np.abs(errors) + self._epsilon) ** self._alpha).flatten()
        self._buffer.update_latest_priorities(priorities)

    def __len__(self):
        return len(self._buffer)

    def __getitem__(self, index):
        return self._buffer[index]