diff --git a/easy_rec/python/core/sampler.py b/easy_rec/python/core/sampler.py index a3f8bf7fd..cb6d04e8c 100644 --- a/easy_rec/python/core/sampler.py +++ b/easy_rec/python/core/sampler.py @@ -79,7 +79,10 @@ def _init_graph(self): if 'ps' in tf_config['cluster']: # ps mode tf_config = json.loads(os.environ['TF_CONFIG']) - task_count = len(tf_config['cluster']['worker']) + 2 + if 'worker' in tf_config['cluster']: + task_count = len(tf_config['cluster']['worker']) + 2 + else: + task_count = 2 if self._is_on_ds: gl.set_tracker_mode(0) server_hosts = [ diff --git a/requirements/docs.txt b/requirements/docs.txt index a81d0986b..596bd527b 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -5,4 +5,4 @@ recommonmark==0.6.0 sphinx==5.1.1 sphinx_markdown_tables==0.0.17 sphinx_rtd_theme -tensorflow-probability \ No newline at end of file +tensorflow-probability==0.11.0 \ No newline at end of file