Skip to content

Commit

Permalink
add mixed precision support to deepxde
Browse files Browse the repository at this point in the history
  • Loading branch information
g-w1 committed Feb 12, 2024
1 parent 0643941 commit 04bcd47
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
15 changes: 14 additions & 1 deletion deepxde/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def set_default_float(value):
The default floating point type is 'float32'.
Args:
value (String): 'float16', 'float32', or 'float64'.
value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision in https://arxiv.org/abs/2401.16645).
"""
if value == "float16":
print("Set the default float type to float16")
Expand All @@ -85,6 +85,19 @@ def set_default_float(value):
elif value == "float64":
print("Set the default float type to float64")
real.set_float64()
elif value == "mixed":
print("Set the float type to mixed precision of float16 and float32")
real.set_mixed()
if backend_name == "tensorflow":
real.set_float16()
tf.keras.mixed_precision.set_global_policy("mixed_float16")
elif backend_name == "pytorch":
# Use float16 during the forward and backward passes, but store in float32
real.set_float32()
else:
raise ValueError(
f"{backend_name} backend does not currently support mixed precision."
)
else:
raise ValueError(f"{value} not supported in deepXDE")
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
Expand Down
18 changes: 14 additions & 4 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,20 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):

def train_step(inputs, targets, auxiliary_vars):
def closure():
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = torch.sum(losses)
self.opt.zero_grad()
total_loss.backward()
if config.real.mixed:
with torch.autocast(device_type="cuda", dtype=torch.float16):
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[
1
]
total_loss = torch.sum(losses)
# we do the backprop in float16
self.opt.zero_grad()
total_loss.backward()
else:
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = torch.sum(losses)
self.opt.zero_grad()
total_loss.backward()
return total_loss

self.opt.step(closure)
Expand Down
4 changes: 4 additions & 0 deletions deepxde/real.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class Real:
def __init__(self, precision):
self.precision = None
self.reals = None
self.mixed = False
if precision == 16:
self.set_float16()
elif precision == 32:
Expand All @@ -28,3 +29,6 @@ def set_float32(self):
def set_float64(self):
self.precision = 64
self.reals = {np: np.float64, bkd.lib: bkd.float64}

def set_mixed(self):
self.mixed = True

0 comments on commit 04bcd47

Please sign in to comment.