From 4fa1e8058228bdad0a6dc9e9b1415aba24fb3c8d Mon Sep 17 00:00:00 2001
From: Jacob G-W
Date: Sun, 11 Feb 2024 17:11:35 -0500
Subject: [PATCH] add mixed precision support to deepxde

---
 deepxde/config.py | 15 +++++++++++++++
 deepxde/model.py  | 16 ++++++++++++----
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/deepxde/config.py b/deepxde/config.py
index 6d87e67b6..6c5406a9d 100644
--- a/deepxde/config.py
+++ b/deepxde/config.py
@@ -15,6 +15,7 @@
 comm = None
 world_size = 1
 rank = 0
+mixed = False
 if "OMPI_COMM_WORLD_SIZE" in os.environ:
     if backend_name == "tensorflow.compat.v1":
         import horovod.tensorflow as hvd
@@ -79,6 +80,20 @@ def set_default_float(value):
     if value == "float16":
         print("Set the default float type to float16")
         real.set_float16()
+    elif value == "mixed":
+        print("Set training policy to mixed")
+        global mixed
+        mixed = True
+        if backend_name == "tensorflow":
+            real.set_float16()
+            tf.keras.mixed_precision.set_global_policy("mixed_float16")
+        elif backend_name == "pytorch":
+            # we cast to float16 during the passes in the training loop, but store in float32
+            real.set_float32()
+        else:
+            raise ValueError(
+                f"{backend_name} backend does not currently support mixed precision in deepXDE"
+            )
     elif value == "float32":
         print("Set the default float type to float32")
         real.set_float32()
diff --git a/deepxde/model.py b/deepxde/model.py
index 4ebdf6859..1676c0915 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -353,10 +353,18 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):

         def train_step(inputs, targets, auxiliary_vars):
             def closure():
-                losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
-                total_loss = torch.sum(losses)
-                self.opt.zero_grad()
-                total_loss.backward()
+                if config.mixed:
+                    with torch.autocast(device_type="cuda", dtype=torch.float16):
+                        losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
+                        total_loss = torch.sum(losses)
+                    # we do the backprop in float16
+                    self.opt.zero_grad()
+                    total_loss.backward()
+                else:
+                    losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
+                    total_loss = torch.sum(losses)
+                    self.opt.zero_grad()
+                    total_loss.backward()
                 return total_loss

             self.opt.step(closure)
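
Below is a minimal usage sketch, not part of the patch, showing how the new "mixed" option would be exercised from user code on the PyTorch backend (selected e.g. via DDE_BACKEND=pytorch). The Poisson-1D problem, network sizes, and training settings are illustrative assumptions; only the set_default_float("mixed") call touches the code added here.

import numpy as np
import torch

import deepxde as dde

# Enable the policy added by this patch: on the PyTorch backend, parameters
# stay in float32 and the training passes are autocast to float16; on
# TensorFlow, the Keras "mixed_float16" global policy is set instead.
dde.config.set_default_float("mixed")

# Standard 1D Poisson problem: -u'' = pi^2 sin(pi x) on (-1, 1), u(+-1) = 0.
geom = dde.geometry.Interval(-1, 1)


def pde(x, y):
    dy_xx = dde.grad.hessian(y, x)
    return -dy_xx - np.pi ** 2 * torch.sin(np.pi * x)


bc = dde.icbc.DirichletBC(geom, lambda x: 0, lambda x, on_boundary: on_boundary)
data = dde.data.PDE(geom, pde, bc, num_domain=64, num_boundary=2)

net = dde.nn.FNN([1] + [32] * 3 + [1], "tanh", "Glorot uniform")
model = dde.Model(data, net)

# With config.mixed set, train_step wraps the forward/loss computation in
# torch.autocast(device_type="cuda", dtype=torch.float16), so the reduced
# precision only takes effect when training on a CUDA device.
model.compile("adam", lr=1e-3)
model.train(iterations=2000)

Design note: the PyTorch path relies on torch.autocast alone and does not use a torch.cuda.amp.GradScaler, so float16 gradient underflow is not compensated, and because device_type="cuda" is hardcoded the autocast region has no effect on CPU training.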