Skip to content

Commit

Permalink
Make update_counter multiprocess safe (#30)
Browse files: browse the repository at this point in the history
  • Loading branch information
belldandyxtq authored Jul 22, 2020
1 parent 67ce265 commit d420891
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions pfrl/agents/dqn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import collections
import time
import ctypes
import multiprocessing as mp
from logging import getLogger

Expand Down Expand Up @@ -240,11 +241,13 @@ def cumulative_steps(self):
# cumulative_steps counts the overall steps during the training.
return self._cumulative_steps

def _setup_actor_learner_training(self, n_actors, actor_update_interval):
def _setup_actor_learner_training(
self, n_actors, actor_update_interval, update_counter
):
assert actor_update_interval > 0

self.actor_update_interval = actor_update_interval
self.update_counter = 0
self.update_counter = update_counter

# Make a copy on shared memory and share among actors and the poller
shared_model = copy.deepcopy(self.model).cpu()
Expand Down Expand Up @@ -603,8 +606,9 @@ def _learner_loop(
# intervals.
update_counter += 1
if update_counter % self.actor_update_interval == 0:
self.update_counter += 1
shared_model.load_state_dict(self.model.state_dict())
with self.update_counter.get_lock():
self.update_counter.value += 1
shared_model.load_state_dict(self.model.state_dict())

# To keep the ratio of target updates to model updates,
# here we calculate back the effective current timestep
Expand All @@ -630,10 +634,13 @@ def _poller_loop(
self._poll_pipe(i, pipe, replay_buffer_lock, exception_event)

def setup_actor_learner_training(
self, n_actors, n_updates=None, actor_update_interval=8
self, n_actors, update_counter=None, n_updates=None, actor_update_interval=8
):
if update_counter is None:
update_counter = mp.Value(ctypes.c_ulong)

(shared_model, learner_pipes, actor_pipes) = self._setup_actor_learner_training(
n_actors, actor_update_interval
n_actors, actor_update_interval, update_counter
)
exception_event = mp.Event()

Expand Down

0 comments on commit d420891

Please sign in to comment.