-
Notifications
You must be signed in to change notification settings - Fork 4
/
memory_buffers.py
100 lines (79 loc) · 2.39 KB
/
memory_buffers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import random
from abc import ABC, abstractmethod
from collections import deque
class MemoryTemplate(ABC):
    """
    Memory abstract class.

    Subclasses must implement ``__len__``, ``append`` and ``sample``, and
    should call ``_inc_counter`` whenever an element is appended so that
    ``counter`` tracks how many elements have been appended in total.
    """

    # Number of elements appended so far (incremented via _inc_counter).
    # This is NOT the number of elements currently stored: bounded buffers
    # may evict or reject elements.
    _counter = 0

    def __init__(self, seed):
        """
        :param seed: if not None, passed to ``random.seed``, which reseeds the
            module-level random generator (shared by the whole process)
        """
        if seed is not None:
            random.seed(seed)

    @property
    def counter(self):
        """Number of elements appended so far (as counted by _inc_counter)."""
        return self._counter

    @abstractmethod
    def __len__(self):
        """Return the number of elements currently stored."""

    @abstractmethod
    def append(self, element):
        """Store ``element``; implementations must call ``_inc_counter``."""

    @abstractmethod
    def sample(self, n, or_less=False):
        """Return ``n`` randomly chosen stored elements (fewer if ``or_less``)."""

    def _inc_counter(self, inc_by=1):
        """Add ``inc_by`` to the appended-elements counter."""
        self._counter += inc_by

    def _get_n_or_less(self, n, or_less):
        """
        Compute the sample size to draw.

        :param n: requested sample size
        :param or_less: if True, clamp ``n`` to the number of elements
            currently stored instead of requesting more than exist
        :return: ``min(n, len(self))`` when ``or_less`` is set, else ``n``
        """
        # Clamp against len(self), not the append counter: once a bounded
        # buffer is full the counter exceeds what is actually stored, and
        # random.sample would raise ValueError.
        if or_less:
            return min(n, len(self))
        return n
class ExperienceReplayMemory(MemoryTemplate):
    """
    Cyclic-buffer Experience Replay memory.

    Elements live in a deque bounded to ``size`` entries, so once the buffer
    is full each newly appended element pushes out the oldest one.
    """

    _memory = None

    def __init__(self, size, seed=None):
        """
        Build an empty Experience Replay Memory.

        :param size: maximum number of elements kept (deque maxlen)
        :param seed: random seed to be used (will override random.seed)
        """
        super(ExperienceReplayMemory, self).__init__(seed)
        self._memory = deque(maxlen=size)

    def __len__(self):
        """Number of elements currently held in the buffer."""
        buffer = self._memory
        return len(buffer)

    def append(self, element):
        """Push ``element`` (evicting the oldest when full) and bump the counter."""
        buffer = self._memory
        buffer.append(element)
        self._inc_counter()

    def sample(self, n, or_less=False):
        """Draw a random sample, without replacement, of the requested size."""
        k = self._get_n_or_less(n, or_less)
        picked = random.sample(self._memory, k)
        return picked
class ReservoirSamplingMemory(MemoryTemplate):
    """
    Reservoir Sampling based memory buffer.

    Keeps at most ``size`` elements forming a uniform random sample of every
    element ever appended (Algorithm R). The first ``size`` elements are
    stored directly. After that, the t-th appended element (1-based)
    overwrites a uniformly chosen slot with probability ``size / t``.
    """

    _memory = None
    _max_size = 0

    def __init__(self, size, seed=None):
        """
        Create a new Reservoir Sampling Memory.

        :param size: memory size (maximum number of stored elements)
        :param seed: random seed to be used (will override random.seed)
        """
        super(ReservoirSamplingMemory, self).__init__(seed)
        self._max_size = size
        # Per-instance storage: a class-level list would be shared (and
        # mutated) by every instance of this class.
        self._memory = []

    def __len__(self):
        """Return the number of elements currently stored."""
        return len(self._memory)

    def append(self, element):
        """Offer ``element`` to the reservoir and count it."""
        # Count first, so self._counter is the 1-based index t of this element.
        self._inc_counter()
        if len(self._memory) < self._max_size:
            self._memory.append(element)
        else:
            # Uniform over [0, t): lands inside the reservoir with
            # probability max_size / t, as Algorithm R requires.
            i = random.randrange(self._counter)
            if i < self._max_size:
                self._memory[i] = element

    def sample(self, n, or_less=False):
        """
        Return ``n`` stored elements drawn without replacement.

        :param n: requested sample size
        :param or_less: if True, return fewer elements when fewer are stored
        """
        # Clamp against what is actually stored: the append counter keeps
        # growing after the reservoir is full.
        if or_less:
            n = min(n, len(self._memory))
        return random.sample(self._memory, n)