diff --git a/deepxde/utils/external.py b/deepxde/utils/external.py index bca1cc1d1..2e3a3bc5f 100644 --- a/deepxde/utils/external.py +++ b/deepxde/utils/external.py @@ -192,8 +192,10 @@ def plot_loss_history(loss_history, fname=None): fname (string): If `fname` is a string (e.g., 'loss_history.png'), then save the figure to the file of the file name `fname`. """ - loss_train = np.sum(loss_history.loss_train, axis=1) - loss_test = np.sum(loss_history.loss_test, axis=1) + # np.sum(loss_history.loss_train, axis=1) is error-prone for arrays of varying lengths. + # Handle irregular array sizes. + loss_train = np.array([np.sum(loss) for loss in loss_history.loss_train]) + loss_test = np.array([np.sum(loss) for loss in loss_history.loss_test]) plt.figure() plt.semilogy(loss_history.steps, loss_train, label="Train loss")