diff --git a/src/training.py b/src/training.py index 0b601e4..d8ad72c 100644 --- a/src/training.py +++ b/src/training.py @@ -746,6 +746,8 @@ def trace_handler(prof): accelerator.save_state(config['checkpointdir'] / f'best_state') unwrapped_model = accelerator.unwrap_model(model) torch.save(unwrapped_model.state_dict(), config['checkpointdir'] / f"best_model.bin") + torch.save(unwrapped_model.get_encoder().state_dict(), config['checkpointdir'] / f"best_encoder.bin") + torch.save(unwrapped_model.get_decoder().state_dict(), config['checkpointdir'] / f"best_decoder.bin") checkpoint = Checkpoint.from_directory(checkpointdir)