forked from keon/deep-q-learning
-
Notifications
You must be signed in to change notification settings - Fork 1
/
updates.py
29 lines (23 loc) · 853 Bytes
/
updates.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
class Rule:
    """Base interface for update rules.

    Both methods are placeholders that raise ``NotImplementedError``;
    concrete update rules are expected to override them.
    """

    def train(self, model, minibatch):
        """Placeholder for training ``model`` on ``minibatch``.

        Args:
            model: The model to update.
            minibatch: The batch of samples to learn from.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def __call__(self, model, minibatch):
        """Placeholder for applying the rule to ``model`` with ``minibatch``.

        Args:
            model: The model to update.
            minibatch: The batch of samples to learn from.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # `self` is required so that calling an instance reaches this body
        # instead of failing with a TypeError during argument binding.
        raise NotImplementedError
class Q_learning:
    """Q-learning update rule, applied to a model by calling the instance."""

    def __init__(self, config):
        """Store the discount factor read from ``config.gamma``."""
        self.gamma = config.gamma

    def __call__(self, model, minibatch):
        """Build Q-learning targets for ``minibatch`` and train ``model``.

        Args:
            model: Object providing ``predict(state)`` and
                ``train(states, targets)``.
            minibatch: Iterable of
                ``(state, action, reward, next_state, done)`` tuples.
                NOTE(review): the ``[0]`` indexing suggests each state and
                prediction carries a leading axis of size one -- confirm
                against the caller.

        Returns:
            Whatever ``model.train`` returns.
        """
        batch_states = []
        batch_targets = []
        for state, action, reward, next_state, done in minibatch:
            if done:
                # Terminal transition: no future value to bootstrap from.
                q_target = reward
            else:
                # Bellman backup: reward plus discounted best next value.
                best_next = np.amax(model.predict(next_state)[0])
                q_target = reward + self.gamma * best_next

            # Overwrite only the entry of the action that was taken; the
            # remaining entries keep the model's current predictions.
            q_values = model.predict(state)
            q_values[0][action] = q_target

            # Collect the first row of each for the single training call.
            batch_states.append(state[0])
            batch_targets.append(q_values[0])

        return model.train(batch_states, batch_targets)