diff --git a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py index 58ec2ba9474e2c..4af2925ebd8650 100644 --- a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py +++ b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py @@ -458,7 +458,7 @@ def tpu_hardware_feature(self): @property def environment(self): """Returns the current environment which TensorFlow is running in.""" - return self._environment + return '' def _start_local_server(self): address = compat.as_text(self._cloud_tpu_client.get_local_ip()) diff --git a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver_test.py b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver_test.py index 7be3249fe4f458..e11984412f9e48 100644 --- a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver_test.py +++ b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver_test.py @@ -719,6 +719,10 @@ def testTpuTopology(self): self.assertIsInstance(cluster_resolver.tpu_hardware_feature, topology_pb2.TPUHardwareFeature) + def testEnvironment(self): + cluster_resolver = resolver.TPUClusterResolver(tpu='local') + self.assertEqual(cluster_resolver.environment, '') + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/remote.py b/tensorflow/python/eager/remote.py index 2f70bfc86dd56e..2808f5a056f00a 100644 --- a/tensorflow/python/eager/remote.py +++ b/tensorflow/python/eager/remote.py @@ -187,11 +187,17 @@ def connect_to_cluster(cluster_spec_or_resolver, # tpu_cluster_resolver.TPUClusterResolver if (isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver) and hasattr(cluster_spec_or_resolver, "tpu_hardware_feature")): - is_uptc_sess = ".uptc-worker." in cluster_spec_or_resolver.master() - service_type = remote_utils.coordination_service_type( - protocol, is_uptc_sess) service_leader = cluster_spec_or_resolver.get_coordination_service_leader( ) + # Maybe enable coordination service internally. + if cluster_spec_or_resolver.environment == "google": + is_uptc_sess = ".uptc-worker." in cluster_spec_or_resolver.master() + service_type = remote_utils.coordination_service_type( + protocol, is_uptc_sess) + # Enable coordination service for Cloud TPU. + else: + service_type = "standalone" + if service_type: # If `enable_health_check` is true, coordination service agent would # do connecting (and tasks would send heartbeat if connection is set up)