add mixed precision support to deepxde
g-w1 committed Feb 11, 2024
1 parent 0643941 commit 4fa1e80
Showing 2 changed files with 27 additions and 4 deletions.
15 changes: 15 additions & 0 deletions deepxde/config.py
@@ -15,6 +15,7 @@
 comm = None
 world_size = 1
 rank = 0
+mixed = False
 if "OMPI_COMM_WORLD_SIZE" in os.environ:
     if backend_name == "tensorflow.compat.v1":
         import horovod.tensorflow as hvd
@@ -79,6 +80,20 @@ def set_default_float(value):
     if value == "float16":
         print("Set the default float type to float16")
         real.set_float16()
+    elif value == "mixed":
+        print("Set training policy to mixed")
+        global mixed
+        mixed = True
+        if backend_name == "tensorflow":
+            real.set_float16()
+            tf.keras.mixed_precision.set_global_policy("mixed_float16")
+        elif backend_name == "pytorch":
+            # cast to float16 during the training-loop passes, but store weights in float32
+            real.set_float32()
+        else:
+            raise ValueError(
+                f"{backend_name} backend does not currently support mixed precision in deepXDE"
+            )
     elif value == "float32":
         print("Set the default float type to float32")
         real.set_float32()
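With this in place, mixed precision is selected the same way as the existing float types. A minimal usage sketch (set_default_float("mixed") is the entry point added by this commit; the rest of an ordinary DeepXDE script is assumed around it):

import deepxde as dde

# Call before building the network so the policy applies to all layers.
# On the tensorflow backend this activates Keras' "mixed_float16" policy;
# on the pytorch backend it sets config.mixed = True and keeps float32 storage.
dde.config.set_default_float("mixed")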
16 changes: 12 additions & 4 deletions deepxde/model.py
@@ -353,10 +353,18 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
 
         def train_step(inputs, targets, auxiliary_vars):
             def closure():
-                losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
-                total_loss = torch.sum(losses)
-                self.opt.zero_grad()
-                total_loss.backward()
+                if config.mixed:
+                    with torch.autocast(device_type="cuda", dtype=torch.float16):
+                        losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
+                        total_loss = torch.sum(losses)
+                    # backward ops run in the same dtype autocast chose for the forward ops
+                    self.opt.zero_grad()
+                    total_loss.backward()
+                else:
+                    losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
+                    total_loss = torch.sum(losses)
+                    self.opt.zero_grad()
+                    total_loss.backward()
                 return total_loss
 
             self.opt.step(closure)
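For context, this closure adapts the standard torch.autocast recipe. The canonical loop in the PyTorch AMP docs additionally wraps the loss in a GradScaler so small float16 gradients do not underflow to zero; the commit omits scaling, plausibly because GradScaler.step() does not support closure-based optimizers such as L-BFGS. A rough sketch of that canonical (non-closure) pattern, where model, optimizer, loss_fn, and loader are stand-in names:

import torch

scaler = torch.cuda.amp.GradScaler()
for inputs, targets in loader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)           # eligible forward ops run in float16
        loss = loss_fn(outputs, targets)
    scaler.scale(loss).backward()  # scale the loss so small gradients survive float16
    scaler.step(optimizer)         # unscales gradients; skips the step on inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration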
