Skip to content

Commit

Permalink
[minor] Torch Load (#1473)
Browse files Browse the repository at this point in the history
* added map_location

* pytest

---------

Co-authored-by: Oskar Triebe <ourownstory@users.noreply.github.com>
  • Loading branch information
SimonWittner and ourownstory authored Nov 7, 2023
1 parent 07724b1 commit 7aac3be
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
12 changes: 9 additions & 3 deletions neuralprophet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,18 @@ def save(forecaster, path: str):
setattr(forecaster.model, attr, value)


def load(path: str):
def load(path: str, map_location=None):
"""retrieve a fitted model from a .np file that was saved by save.
Parameters
----------
path : str
path and filename to be saved. filename could be any but suggested to have extension .np.
map_location : str, optional
specifying the location where the model should be loaded.
If you are running on a CPU-only machine, set map_location='cpu' to map your storages to the CPU.
If you are running on CUDA, set map_location='cuda:device_id' (e.g. 'cuda:2').
Default is None, which means the model is loaded to the same device as it was saved on.
Returns
-------
np.forecaster.NeuralProphet
Expand All @@ -88,7 +92,9 @@ def load(path: str):
>>> from neuralprophet import load
>>> model = load("test_save_model.np")
"""
m = torch.load(path)
if map_location is not None:
map_location = torch.device(map_location)
m = torch.load(path, map_location=map_location)
m.restore_trainer()
return m

Expand Down
4 changes: 4 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,12 @@ def test_save_load():
m2 = load("test_model.pt")
forecast2 = m2.predict(df=future)

m3 = load("test_model.pt", map_location="cpu")
forecast3 = m3.predict(df=future)

# Check that the forecasts are the same
pd.testing.assert_frame_equal(forecast, forecast2)
pd.testing.assert_frame_equal(forecast, forecast3)


# TODO: add functionality to continue training
Expand Down

0 comments on commit 7aac3be

Please sign in to comment.