Skip to content

Commit

Permalink
Pass number of threads in task sampler process as nthread
Browse files Browse the repository at this point in the history
  • Loading branch information
jordis-ai2 committed May 8, 2024
1 parent 19740c1 commit 59d2378
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,13 +1644,14 @@ def _task_sampling_loop_worker(
auto_resample_when_done: bool,
should_log: bool,
thread_barrier: threading.Barrier,
thread_num: int,
) -> None:
"""process worker for creating and interacting with the
Tasks/TaskSampler."""
assert len(sampler_fn_args_list) == 1

sampler_fn_args_list = [
{**cur_kwargs, "thread_id": worker_id, "thread_barrier": thread_barrier}
{**cur_kwargs, "thread_id": worker_id, "thread_barrier": thread_barrier, "nthread": thread_num}
for cur_kwargs in sampler_fn_args_list
]

Expand Down Expand Up @@ -1751,6 +1752,7 @@ def _start_workers(
auto_resample_when_done=self._auto_resample_when_done,
should_log=self.should_log,
thread_barrier=barrier,
thread_num=self._num_task_samplers,
),
)

Expand Down

0 comments on commit 59d2378

Please sign in to comment.