Skip to content

Commit

Permalink
add mixed precision support to deepxde
Browse files Browse the repository at this point in the history
  • Loading branch information
g-w1 committed Feb 11, 2024
1 parent 0643941 commit b8db5be
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
13 changes: 13 additions & 0 deletions deepxde/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ def set_default_float(value):
if value == "float16":
print("Set the default float type to float16")
real.set_float16()
elif value == "mixed":
print("Set training policy to mixed")
real.set_mixed()
if backend_name == "tensorflow":
real.set_float16()
tf.keras.mixed_precision.set_global_policy("mixed_float16")
elif backend_name == "pytorch":
# we cast to float16 during the passes in the training loop, but store in float32
real.set_float32()
else:
raise ValueError(
f"{backend_name} backend does not currently support mixed precision in deepXDE"
)
elif value == "float32":
print("Set the default float type to float32")
real.set_float32()
Expand Down
18 changes: 14 additions & 4 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,20 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):

def train_step(inputs, targets, auxiliary_vars):
def closure():
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = torch.sum(losses)
self.opt.zero_grad()
total_loss.backward()
if config.real.mixed:
with torch.autocast(device_type="cuda", dtype=torch.float16):
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[
1
]
total_loss = torch.sum(losses)
# we do the backprop in float16
self.opt.zero_grad()
total_loss.backward()
else:
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = torch.sum(losses)
self.opt.zero_grad()
total_loss.backward()
return total_loss

self.opt.step(closure)
Expand Down
4 changes: 4 additions & 0 deletions deepxde/real.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class Real:
def __init__(self, precision):
self.precision = None
self.reals = None
self.mixed = False
if precision == 16:
self.set_float16()
elif precision == 32:
Expand All @@ -17,6 +18,9 @@ def __init__(self, precision):
def __call__(self, package):
return self.reals[package]

def set_mixed(self):
self.mixed = True

def set_float16(self):
self.precision = 16
self.reals = {np: np.float16, bkd.lib: bkd.float16}
Expand Down

0 comments on commit b8db5be

Please sign in to comment.