Skip to content

Commit

Permalink
save encoder and decoder models independently
Browse files Browse the repository at this point in the history
  • Loading branch information
thayeral committed Jan 7, 2025
1 parent 76ca38d commit 914cce3
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,8 @@ def trace_handler(prof):
accelerator.save_state(config['checkpointdir'] / f'best_state')
unwrapped_model = accelerator.unwrap_model(model)
torch.save(unwrapped_model.state_dict(), config['checkpointdir'] / f"best_model.bin")
torch.save(unwrapped_model.get_encoder().state_dict(), config['checkpointdir'] / f"best_encoder.bin")
torch.save(unwrapped_model.get_decoder().state_dict(), config['checkpointdir'] / f"best_decoder.bin")

checkpoint = Checkpoint.from_directory(checkpointdir)

Expand Down

0 comments on commit 914cce3

Please sign in to comment.