Inverse PDE in jax #760
-
Good question. I haven't looked into this yet, and am not sure what the best solution is. @ZongrenZou Any ideas?
-
@outspeckle We just added support for inverse problems in JAX. You may want to check the "Lorenz_inverse.py" example in the "pinn_inverse" folder. Let us know if you have any other questions. As for the design, we didn't use flax.core.variables.Variable; we simply treat the unknowns in inverse problems as JAX arrays. The reason is that, unlike the other backends, JAX traces (trainable) variables only through functions and their arguments, so the unknowns have to be passed explicitly. In TensorFlow, for example, all variables are traced automatically once created, so users don't need to pass them around explicitly.
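To illustrate the mechanism described above (this is not DeepXDE's internal code), here is a minimal plain-JAX sketch in which the unknown coefficient is an ordinary jnp array passed as its own argument, and jax.value_and_grad differentiates with respect to both the network parameters and the unknowns via argnums=(0, 1). The toy ODE du/dt = -k u, the tiny tanh network, and the names init_params, net, loss_fn and train_step are all hypothetical stand-ins chosen for the example.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy inverse problem: du/dt = -k * u on t in [0, 1], with
# observations from u(t) = exp(-t) (true k = 1). The unknown k is a plain
# JAX array, not a flax.core.variables.Variable.

def init_params(key, width=16):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (width,)),
        "b1": jnp.zeros(width),
        "w2": jax.random.normal(k2, (width,)) / jnp.sqrt(width),
        "b2": jnp.array(0.0),
    }

def net(params, t):
    # Scalar-in, scalar-out network u(t); stands in for a Flax model's apply().
    h = jnp.tanh(params["w1"] * t + params["b1"])
    return jnp.dot(h, params["w2"]) + params["b2"]

def loss_fn(params, unknowns, t_col, t_obs, u_obs):
    # PDE residual at collocation points plus a data misfit at observation points.
    u = jax.vmap(net, in_axes=(None, 0))(params, t_col)
    du_dt = jax.vmap(jax.grad(net, argnums=1), in_axes=(None, 0))(params, t_col)
    residual = du_dt + unknowns["k"] * u
    u_pred = jax.vmap(net, in_axes=(None, 0))(params, t_obs)
    return jnp.mean(residual ** 2) + jnp.mean((u_pred - u_obs) ** 2)

LR = 0.01

@jax.jit
def train_step(params, unknowns, t_col, t_obs, u_obs):
    # The unknowns are an explicit argument, so JAX can trace and differentiate them.
    loss, (g_params, g_unknowns) = jax.value_and_grad(loss_fn, argnums=(0, 1))(
        params, unknowns, t_col, t_obs, u_obs)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, g_params)
    unknowns = jax.tree_util.tree_map(lambda p, g: p - LR * g, unknowns, g_unknowns)
    return params, unknowns, loss

params = init_params(jax.random.PRNGKey(0))
unknowns = {"k": jnp.array(0.1)}  # initial guess for the unknown coefficient
t_col = jnp.linspace(0.0, 1.0, 32)
t_obs = jnp.linspace(0.0, 1.0, 8)
u_obs = jnp.exp(-t_obs)

losses = []
for _ in range(1000):
    params, unknowns, loss = train_step(params, unknowns, t_col, t_obs, u_obs)
    losses.append(float(loss))
print("estimated k:", float(unknowns["k"]))
```

Plain gradient descent is used only to keep the sketch dependency-free; in practice an optimizer such as Adam would be applied to both pytrees.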
-
This is really great, thank you very much! The approach via implicit tracing works very well, and the unknowns are easily accessible via model.external_trainable_variables.
-
Good morning @ZongrenZou @lululxvi,
Here is the output.
And with JAX (no variable update):
Do you have the same issue?
-
Hello Dr. Lu Lu and DeepXDE community,
What do you think would be the best way to use dde.Variable in JAX? I tried adding a flax.core.variables.Variable for the free parameters and including it in self.net.params in model.py (line 328), but I have problems initializing it correctly. Or would it be better to handle it separately from Flax and include it in the loss later?
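For the first option in the question, a minimal sketch under the assumption that the network parameters are an ordinary pytree (Flax params are a dict or FrozenDict depending on the version): the unknown can be stored as an extra plain-array leaf next to the weights, with no flax.core.variables.Variable involved, and a single jax.grad call then covers both. The linear stand-in network, the toy equation du/dt = k, and the names net, combined and "inverse" are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: observations of u(t) = 2t + 1 and the equation
# du/dt = k, so the unknown k should converge to 2.

def net(net_params, t):
    # Stand-in for a Flax apply(): a linear "network" u(t) = a * t + b.
    return net_params["a"] * t + net_params["b"]

def loss_fn(combined, t, u_obs):
    # The unknown lives in the same pytree as the network weights.
    k = combined["inverse"]["k"]
    du_dt = jax.vmap(jax.grad(net, argnums=1), in_axes=(None, 0))(combined["net"], t)
    residual = du_dt - k
    u_pred = net(combined["net"], t)
    return jnp.mean(residual ** 2) + jnp.mean((u_pred - u_obs) ** 2)

LR = 0.1

@jax.jit
def train_step(combined, t, u_obs):
    # One gradient over the whole pytree updates weights and unknown together.
    loss, grads = jax.value_and_grad(loss_fn)(combined, t, u_obs)
    combined = jax.tree_util.tree_map(lambda p, g: p - LR * g, combined, grads)
    return combined, loss

t = jnp.linspace(0.0, 1.0, 11)
u_obs = 2.0 * t + 1.0
combined = {
    "net": {"a": jnp.array(0.0), "b": jnp.array(0.0)},
    "inverse": {"k": jnp.array(0.0)},  # initial guess for the unknown
}
for _ in range(3000):
    combined, loss = train_step(combined, t, u_obs)
print("estimated k:", float(combined["inverse"]["k"]))  # should be close to 2.0
```

Both layouts give the same gradients; keeping the unknowns in a separate argument, as in the approach described in the reply above, avoids touching the network's own parameter structure.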