Fix - loading model on machine with different accelerator (#1509)
* Fix - loading model on machine with different accelerator

* Fix - loading model on machine with different accelerator - code style fixes

* add Optional type

---------

Co-authored-by: Oskar Triebe <ourownstory@users.noreply.github.com>
McOffsky and ourownstory authored Jan 17, 2024
1 parent 3007765 commit 203840f
Showing 2 changed files with 12 additions and 5 deletions.
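Before the diffs, a minimal usage sketch of the scenario this commit fixes, under stated assumptions: the file name "gpu_trained_model.np" and the DataFrame df are hypothetical placeholders, and the model is assumed to have been saved on a machine with a GPU accelerator. Previously, torch.load could remap tensors via map_location, but the restored trainer still used the accelerator stored in the model; after this fix, load() forwards map_location to restore_trainer() as well.

    # Hypothetical usage sketch; "gpu_trained_model.np" and df are placeholders.
    from neuralprophet import load

    # On a CPU-only machine, remap the tensors AND rebuild the trainer for "cpu".
    model = load("gpu_trained_model.np", map_location="cpu")
    forecast = model.predict(df)  # df: a DataFrame like the one used for training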
9 changes: 7 additions & 2 deletions neuralprophet/forecaster.py
@@ -2734,7 +2734,12 @@ def _train(
         metrics_df = pd.DataFrame(self.metrics_logger.history)
         return metrics_df
 
-    def restore_trainer(self):
+    def restore_trainer(self, accelerator: Optional[str] = None):
+        """
+        If no accelerator was provided, use accelerator stored in model.
+        """
+        if accelerator is None:
+            accelerator = self.accelerator
         """
         Restore the trainer based on the forecaster configuration.
         """
@@ -2743,7 +2748,7 @@ def restore_trainer(self):
             config=self.trainer_config,
             metrics_logger=self.metrics_logger,
             early_stopping=self.early_stopping,
-            accelerator=self.accelerator,
+            accelerator=accelerator,
             metrics_enabled=bool(self.metrics),
         )

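The change above makes the trainer's accelerator overridable at load time, falling back to the value stored with the model. A standalone sketch of just that fallback (the class and its attributes are simplified stand-ins, not NeuralProphet's actual API surface):

    from typing import Optional

    class _ForecasterSketch:
        def __init__(self) -> None:
            # Accelerator persisted when the model was saved, e.g. on a GPU box.
            self.accelerator: Optional[str] = "gpu"

        def restore_trainer(self, accelerator: Optional[str] = None) -> Optional[str]:
            # Prefer the accelerator requested at load time; otherwise fall
            # back to the one stored in the model.
            if accelerator is None:
                accelerator = self.accelerator
            return accelerator

    m = _ForecasterSketch()
    assert m.restore_trainer() == "gpu"       # no override: stored value wins
    assert m.restore_trainer("cpu") == "cpu"  # load-time override wins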
8 changes: 5 additions & 3 deletions neuralprophet/utils.py
@@ -92,10 +92,12 @@ def load(path: str, map_location=None):
     >>> from neuralprophet import load
     >>> model = load("test_save_model.np")
     """
+    torch_map_location = None
     if map_location is not None:
-        map_location = torch.device(map_location)
-    m = torch.load(path, map_location=map_location)
-    m.restore_trainer()
+        torch_map_location = torch.device(map_location)
+
+    m = torch.load(path, map_location=torch_map_location)
+    m.restore_trainer(accelerator=map_location)
     return m


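The map_location half of the fix relies on standard PyTorch storage remapping; a minimal self-contained illustration of that behavior (the in-memory buffer and tensor are placeholders, unrelated to NeuralProphet):

    import io
    import torch

    device = torch.device("cpu")  # what the patched load() builds from map_location
    buf = io.BytesIO()
    torch.save({"w": torch.zeros(2, 2)}, buf)  # stands in for a saved model file
    buf.seek(0)
    restored = torch.load(buf, map_location=device)
    assert restored["w"].device.type == "cpu"  # tensors land on the requested device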
