From a92ec201e40f2fe17b900b8dd38af274a3da321d Mon Sep 17 00:00:00 2001 From: bonneted Date: Mon, 17 Jun 2024 17:11:02 +0200 Subject: [PATCH 1/5] fix external variable initialization --- deepxde/data/pde.py | 5 +++-- deepxde/model.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/deepxde/data/pde.py b/deepxde/data/pde.py index a81c5df82..6af9a93d2 100644 --- a/deepxde/data/pde.py +++ b/deepxde/data/pde.py @@ -147,8 +147,9 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None): elif get_num_args(self.pde) == 3: if self.auxiliary_var_fn is None: if aux is None or len(aux) == 1: - raise ValueError("Auxiliary variable function not defined.") - f = self.pde(inputs, outputs_pde, unknowns=aux[1]) + f = self.pde(inputs, outputs_pde) + else: + f = self.pde(inputs, outputs_pde, unknowns=aux[1]) else: f = self.pde(inputs, outputs_pde, model.net.auxiliary_vars) if not isinstance(f, (list, tuple)): diff --git a/deepxde/model.py b/deepxde/model.py index 844d6ec22..032df9c10 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -375,10 +375,10 @@ def _compile_jax(self, lr, loss_fn, decay): if self.params is None: key = jax.random.PRNGKey(config.jax_random_seed) self.net.params = self.net.init(key, self.data.test()[0]) - external_trainable_variables_arr = [ - var.value for var in self.external_trainable_variables - ] - self.params = [self.net.params, external_trainable_variables_arr] + external_trainable_variables_val = [ + var.value for var in self.external_trainable_variables + ] + self.params = [self.net.params, external_trainable_variables_val] # TODO: learning rate decay self.opt = optimizers.get(self.opt_name, learning_rate=lr) self.opt_state = self.opt.init(self.params) From 9570b874619c711a55a993897c409c5755bc309c Mon Sep 17 00:00:00 2001 From: bonneted Date: Tue, 18 Jun 2024 16:50:34 +0200 Subject: [PATCH 2/5] no auxiliary variable function case --- deepxde/data/pde.py | 4 +++- deepxde/utils/internal.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/deepxde/data/pde.py b/deepxde/data/pde.py index 6af9a93d2..613476e08 100644 --- a/deepxde/data/pde.py +++ b/deepxde/data/pde.py @@ -4,7 +4,7 @@ from .. import backend as bkd from .. import config from ..backend import backend_name -from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0 +from ..utils import get_num_args, has_default_values, run_if_all_none, mpi_scatter_from_rank0 class PDE(Data): @@ -147,6 +147,8 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None): elif get_num_args(self.pde) == 3: if self.auxiliary_var_fn is None: if aux is None or len(aux) == 1: + if not has_default_values(self.pde)[-1]: + raise ValueError("Auxiliary variable function not defined.") f = self.pde(inputs, outputs_pde) else: f = self.pde(inputs, outputs_pde, unknowns=aux[1]) diff --git a/deepxde/utils/internal.py b/deepxde/utils/internal.py index c2b162f11..cdacb7ff5 100644 --- a/deepxde/utils/internal.py +++ b/deepxde/utils/internal.py @@ -201,6 +201,9 @@ def get_num_args(func): params = inspect.signature(func).parameters return len(params) - ("self" in params) +def has_default_values(func): + params = inspect.signature(func).parameters.values() + return [param.default is not inspect._empty for param in params] def mpi_scatter_from_rank0(array, drop_last=True): """Scatter the given array into continuous subarrays of equal size from rank 0 to all ranks. From c119852916cefe7ac7a8a565626ebf3d8320d87d Mon Sep 17 00:00:00 2001 From: bonneted Date: Tue, 18 Jun 2024 17:53:11 +0200 Subject: [PATCH 3/5] fix inspect empty --- deepxde/utils/internal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepxde/utils/internal.py b/deepxde/utils/internal.py index cdacb7ff5..c633d9479 100644 --- a/deepxde/utils/internal.py +++ b/deepxde/utils/internal.py @@ -203,7 +203,7 @@ def get_num_args(func): def has_default_values(func): params = inspect.signature(func).parameters.values() - return [param.default is not inspect._empty for param in params] + return [param.default is not inspect.Parameter.empty for param in params] def mpi_scatter_from_rank0(array, drop_last=True): """Scatter the given array into continuous subarrays of equal size from rank 0 to all ranks. From de9e71b8566f9f3d75c12646c5baae34bb92baa3 Mon Sep 17 00:00:00 2001 From: bonneted Date: Mon, 8 Jul 2024 13:03:37 +0200 Subject: [PATCH 4/5] better unknowns logic for JAX inverse --- deepxde/data/pde.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/deepxde/data/pde.py b/deepxde/data/pde.py index 613476e08..97ff9b5c8 100644 --- a/deepxde/data/pde.py +++ b/deepxde/data/pde.py @@ -4,7 +4,12 @@ from .. import backend as bkd from .. import config from ..backend import backend_name -from ..utils import get_num_args, has_default_values, run_if_all_none, mpi_scatter_from_rank0 +from ..utils import ( + get_num_args, + has_default_values, + run_if_all_none, + mpi_scatter_from_rank0, +) class PDE(Data): @@ -145,15 +150,18 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None): if get_num_args(self.pde) == 2: f = self.pde(inputs, outputs_pde) elif get_num_args(self.pde) == 3: - if self.auxiliary_var_fn is None: - if aux is None or len(aux) == 1: - if not has_default_values(self.pde)[-1]: - raise ValueError("Auxiliary variable function not defined.") - f = self.pde(inputs, outputs_pde) - else: + if self.auxiliary_var_fn is not None: + f = self.pde(inputs, outputs_pde, model.net.auxiliary_vars) + elif backend_name == "jax": + # JAX inverse problem requires unknowns as the input. + if len(aux) == 2: + # External trainable variables in aux[1] are used for unknowns f = self.pde(inputs, outputs_pde, unknowns=aux[1]) + if len(aux) == 1 and has_default_values(self.pde)[-1]: + # No external trainable variables, default values are used for unknowns + f = self.pde(inputs, outputs_pde) else: - f = self.pde(inputs, outputs_pde, model.net.auxiliary_vars) + raise ValueError("Auxiliary variable function not defined.") if not isinstance(f, (list, tuple)): f = [f] From 1362634b8e92cc833677a0ce8bba78cf1f38b6f4 Mon Sep 17 00:00:00 2001 From: bonneted Date: Tue, 17 Dec 2024 13:47:52 +0100 Subject: [PATCH 5/5] no default unknowns error case --- deepxde/data/pde.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepxde/data/pde.py b/deepxde/data/pde.py index b2f1dbf94..72be31514 100644 --- a/deepxde/data/pde.py +++ b/deepxde/data/pde.py @@ -160,9 +160,13 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None): if len(aux) == 2: # External trainable variables in aux[1] are used for unknowns f = self.pde(inputs, outputs_pde, unknowns=aux[1]) - if len(aux) == 1 and has_default_values(self.pde)[-1]: + elif len(aux) == 1 and has_default_values(self.pde)[-1]: # No external trainable variables, default values are used for unknowns f = self.pde(inputs, outputs_pde) + else: + raise ValueError( + "Default unknowns are required if no trainable variables are provided." + ) else: raise ValueError("Auxiliary variable function not defined.") if not isinstance(f, (list, tuple)):