diff --git a/deepxde/data/pde.py b/deepxde/data/pde.py
index 67f900648..72be31514 100644
--- a/deepxde/data/pde.py
+++ b/deepxde/data/pde.py
@@ -4,7 +4,12 @@
 from .. import backend as bkd
 from .. import config
 from ..backend import backend_name
-from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0
+from ..utils import (
+    get_num_args,
+    has_default_values,
+    run_if_all_none,
+    mpi_scatter_from_rank0,
+)
 
 
 class PDE(Data):
@@ -150,9 +155,18 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
         elif get_num_args(self.pde) == 3:
             if self.auxiliary_var_fn is not None:
                 f = self.pde(inputs, outputs_pde, model.net.auxiliary_vars)
-            elif backend_name == "jax" and len(aux) == 2:
+            elif backend_name == "jax":
                 # JAX inverse problem requires unknowns as the input.
-                f = self.pde(inputs, outputs_pde, unknowns=aux[1])
+                if len(aux) == 2:
+                    # External trainable variables in aux[1] are used for unknowns
+                    f = self.pde(inputs, outputs_pde, unknowns=aux[1])
+                elif len(aux) == 1 and has_default_values(self.pde)[-1]:
+                    # No external trainable variables, default values are used for unknowns
+                    f = self.pde(inputs, outputs_pde)
+                else:
+                    raise ValueError(
+                        "Default unknowns are required if no trainable variables are provided."
+                    )
             else:
                 raise ValueError("Auxiliary variable function not defined.")
         if not isinstance(f, (list, tuple)):
diff --git a/deepxde/model.py b/deepxde/model.py
index 1644c2529..e1d898506 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -401,10 +401,10 @@ def _compile_jax(self, lr, loss_fn, decay):
         if self.params is None:
             key = jax.random.PRNGKey(config.jax_random_seed)
             self.net.params = self.net.init(key, self.data.test()[0])
-            external_trainable_variables_arr = [
-                var.value for var in self.external_trainable_variables
-            ]
-            self.params = [self.net.params, external_trainable_variables_arr]
+            external_trainable_variables_val = [
+                var.value for var in self.external_trainable_variables
+            ]
+            self.params = [self.net.params, external_trainable_variables_val]
         # TODO: learning rate decay
         self.opt = optimizers.get(self.opt_name, learning_rate=lr)
         self.opt_state = self.opt.init(self.params)
diff --git a/deepxde/utils/internal.py b/deepxde/utils/internal.py
index 1b57f9d0d..dfa1a55f0 100644
--- a/deepxde/utils/internal.py
+++ b/deepxde/utils/internal.py
@@ -202,6 +202,18 @@ def get_num_args(func):
     params = inspect.signature(func).parameters
     return len(params) - ("self" in params)
 
+def has_default_values(func):
+    """
+    Check if the given function has default values for its parameters.
+
+    Args:
+        func (function): The function to inspect.
+
+    Returns:
+        list: A list of boolean values indicating whether each parameter has a default value.
+    """
+    params = inspect.signature(func).parameters.values()
+    return [param.default is not inspect.Parameter.empty for param in params]
+
 def mpi_scatter_from_rank0(array, drop_last=True):
     """Scatter the given array into continuous subarrays of equal size from rank 0 to all ranks.