-
Notifications
You must be signed in to change notification settings - Fork 39
/
utils.py
27 lines (24 loc) · 1.12 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import tensorflow as tf
def detect_hardware(tpu_name):
    """Choose a ``tf.distribute`` strategy for the available hardware.

    A TPU cluster resolver is tried first. If constructing it raises
    ``ValueError``, the logical GPU devices are listed instead: more than
    one GPU selects a ``MirroredStrategy`` over those devices, otherwise the
    current default strategy is used. A short summary of the choice is
    printed to stdout.

    Args:
        tpu_name: Value forwarded as the ``tpu`` argument of
            ``TPUClusterResolver``.

    Returns:
        A ``(tpu, strategy)`` tuple. ``tpu`` is the cluster resolver, or
        ``None`` when creating it raised ``ValueError``; ``strategy`` is the
        distribution strategy that was selected.
    """
    # TPU detection: a ValueError from the resolver is treated as "no TPU".
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
    except ValueError:
        resolver = None

    # Select appropriate distribution strategy
    if resolver:
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
        workers = resolver.cluster_spec().as_dict()['worker']
        print('Running on TPU ', workers)
    else:
        # Only look at GPUs once the TPU path has been ruled out.
        gpu_names = [
            device.name
            for device in tf.config.experimental.list_logical_devices("GPU")
        ]
        if len(gpu_names) > 1:
            strategy = tf.distribute.MirroredStrategy(gpu_names)
            print('Running on multiple GPUs ', gpu_names)
        else:
            # default strategy that works on CPU and single GPU
            strategy = tf.distribute.get_strategy()
            if len(gpu_names) == 1:
                print('Running on single GPU ', gpu_names[0])
            else:
                print('Running on CPU')

    print("Number of accelerators: ", strategy.num_replicas_in_sync)
    return resolver, strategy