Skip to content

Commit

Permalink
add example and update faq to include mixed precision
Browse files Browse the repository at this point in the history
  • Loading branch information
g-w1 committed Feb 11, 2024
1 parent 4fa1e80 commit 3e3ecbc
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/user/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ General usage
| **A**: `#5`_
- | **Q**: By default, DeepXDE uses ``float32``. How can I use ``float64``?
| **A**: `#28`_
| **Q**: How can I use mixed precision training?
| **A**: Use ``dde.config.set_default_float("mixed")`` with the ``tensorflow`` or ``pytorch`` backends. See `https://arxiv.org/abs/2401.16645` for more information.
- | **Q**: I want to set the global random seeds.
| **A**: `#353`_
- | **Q**: GPU.
Expand Down
53 changes: 53 additions & 0 deletions examples/pinn_forward/Burgers_mixed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Backend supported: tensorflow, pytorch
The exact same as Burgers.py, but using mixed precision instead of float32.
This preserves accuracy while speeding up training (especially with larger training runs).
"""

import deepxde as dde
import numpy as np

dde.config.set_default_float("mixed")

def gen_testdata():
data = np.load("../dataset/Burgers.npz")
t, x, exact = data["t"], data["x"], data["usol"].T
xx, tt = np.meshgrid(x, t)
X = np.vstack((np.ravel(xx), np.ravel(tt))).T
y = exact.flatten()[:, None]
return X, y


def pde(x, y):
dy_x = dde.grad.jacobian(y, x, i=0, j=0)
dy_t = dde.grad.jacobian(y, x, i=0, j=1)
dy_xx = dde.grad.hessian(y, x, i=0, j=0)
return dy_t + y * dy_x - 0.01 / np.pi * dy_xx


geom = dde.geometry.Interval(-1, 1)
timedomain = dde.geometry.TimeDomain(0, 0.99)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary)
ic = dde.icbc.IC(
geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial
)

data = dde.data.TimePDE(
geomtime, pde, [bc, ic], num_domain=2540, num_boundary=80, num_initial=160
)
net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal")
model = dde.Model(data, net)

model.compile("adam", lr=1e-3)
model.train(iterations=15000)
model.compile("L-BFGS")
losshistory, train_state = model.train()
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

X, y_true = gen_testdata()
y_pred = model.predict(X)
f = model.predict(X, operator=pde)
print("Mean residual:", np.mean(np.absolute(f)))
print("L2 relative error:", dde.metrics.l2_relative_error(y_true, y_pred))
np.savetxt("test.dat", np.hstack((X, y_true, y_pred)))

0 comments on commit 3e3ecbc

Please sign in to comment.