diff --git a/src/speckcn2/loss.py b/src/speckcn2/loss.py index 082a938..a3bef84 100644 --- a/src/speckcn2/loss.py +++ b/src/speckcn2/loss.py @@ -57,7 +57,11 @@ def __init__(self, validation: bool = False): super(ComposableLoss, self).__init__() if validation: - config['loss'] = config['val_loss'] + if 'val_loss' in config: + config['loss'] = config['val_loss'] + else: + print('[!] Warning: Validation loss not found in config.yaml,', + 'keeping track of training loss instead') self.device = device self.loss_functions: dict[str, Callable] = { 'MSE': torch.nn.MSELoss(),