diff --git a/deepxde/data/pde_operator.py b/deepxde/data/pde_operator.py index 3b001cf9a..78c7d5393 100644 --- a/deepxde/data/pde_operator.py +++ b/deepxde/data/pde_operator.py @@ -263,18 +263,29 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None): # Use stack instead of as_tensor to keep the gradients. losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses] elif config.autodiff == "forward": # forward mode AD + shape0, shape1 = outputs.shape[0], outputs.shape[1] + shape2 = 1 if model.net.num_outputs == 1 else outputs.shape[2] def forward_call(trunk_input): - return aux[0]((inputs[0], trunk_input)) + output = aux[0]((inputs[0], trunk_input)) + return bkd.reshape(output, (shape0 * shape1, shape2)) f = [] if self.pde.pde is not None: # Each f has the shape (N1, N2) f = self.pde.pde( - inputs[1], (outputs, forward_call), model.net.auxiliary_vars + inputs[1], + (bkd.reshape(outputs, (shape0 * shape1, shape2)), forward_call), + bkd.reshape(model.net.auxiliary_vars, (shape0 * shape1, shape2)), ) if not isinstance(f, (list, tuple)): f = [f] + f = ( + [bkd.reshape(fi, (shape0, shape1)) for fi in f] + if model.net.num_outputs == 1 + else [bkd.reshape(fi, (shape0, shape1, shape2)) for fi in f] + ) + # Each error has the shape (N1, ~N2) error_f = [fi[:, bcs_start[-1] :] for fi in f] for error in error_f: @@ -307,7 +318,7 @@ def forward_call(trunk_input): losses_bc = zip(*losses_bc) losses_bc = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses_bc] - losses.append(losses_bc) + losses.extend(losses_bc) return losses def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):