-
Notifications
You must be signed in to change notification settings - Fork 0
/
replay_buffer.py
31 lines (24 loc) · 1019 Bytes
/
replay_buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
class CircularReplayBuffer:
    """Fixed-capacity ring buffer of transitions with uniform random sampling.

    Transitions live in a preallocated NumPy object array. Once ``capacity``
    transitions have been added, every further ``add`` overwrites the oldest
    stored transition.
    """

    def __init__(self, capacity: int) -> None:
        """Create an empty buffer that holds at most ``capacity`` transitions.

        Args:
            capacity: Maximum number of transitions kept at any time.

        Raises:
            ValueError: If ``capacity`` is less than 1. A zero-sized buffer
                could never accept a transition, so it is rejected up front
                instead of failing later inside ``add``.
        """
        if capacity < 1:
            raise ValueError(
                "capacity must be a positive integer, got {!r}".format(capacity)
            )
        self._capacity = capacity
        self._buffer = np.empty(capacity, dtype=object)
        self._write_ptr = 0  # slot the next transition is written to
        self._size = 0  # number of slots currently holding a transition

    def add(self, ob, ob_nx, action, reward, reward_nx, done) -> None:
        """Store one transition, overwriting the oldest one when full.

        The six fields are stored together and handed back, in the same
        order, by ``sample``.
        NOTE(review): ``ob``/``ob_nx`` are presumably the current and next
        observation and ``reward_nx`` the following reward -- confirm
        against the caller.
        """
        self._buffer[self._write_ptr] = [ob, ob_nx, action, reward, reward_nx, done]
        # Advance the write slot and wrap around at the end of the array.
        self._write_ptr = (self._write_ptr + 1) % self._capacity
        # The fill level grows until the buffer is full, then stays put.
        self._size = min(self._size + 1, self._capacity)

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Args:
            batch_size: Number of transitions to draw (without replacement).

        Returns:
            A 6-tuple ``(ob, ob_nx, action, reward, reward_nx, done)``.
            ``ob`` and ``ob_nx`` are NumPy arrays built from the sampled
            values; the remaining four fields are tuples of length
            ``batch_size``.

        Raises:
            ValueError: If ``batch_size`` is less than 1 or exceeds the
                number of transitions currently stored.
        """
        if batch_size < 1:
            raise ValueError(
                "batch_size must be at least 1, got {!r}".format(batch_size)
            )
        if batch_size > self._size:
            raise ValueError(
                "cannot sample {} transitions from a buffer holding {}".format(
                    batch_size, self._size
                )
            )
        # Draw slot indices rather than the stored objects themselves, so
        # NumPy never has to treat the stored Python lists as array data.
        picked = np.random.choice(self._size, batch_size, replace=False)
        ob, ob_nx, action, reward, reward_nx, done = zip(*self._buffer[picked])
        return np.array(ob), np.array(ob_nx), action, reward, reward_nx, done

    def size(self) -> int:
        """Return the number of stored transitions (same as ``len(self)``)."""
        return self.__len__()

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return self._size

    def clear(self) -> None:
        """Drop every stored transition and reset the buffer to empty."""
        # A fresh array releases the references held by the old slots.
        self._buffer = np.empty(self._capacity, dtype=object)
        self._write_ptr = 0
        self._size = 0