diff --git a/domainlab/algos/trainers/train_ema.py b/domainlab/algos/trainers/train_ema.py index a935b6570..b2ccede40 100644 --- a/domainlab/algos/trainers/train_ema.py +++ b/domainlab/algos/trainers/train_ema.py @@ -63,4 +63,4 @@ def after_epoch(self, epoch, flag_info=None): new_dict_para = self.move_average(dict_para, epoch) # without deepcopy, this seems to work torch_model.load_state_dict(new_dict_para) - super().after_epoch(epoch) + super().after_epoch(epoch, flag_info)