Skip to content

Commit

Permalink
[minor] Support io.bytes output (#1583)
Browse files Browse the repository at this point in the history
* add io buffer support

* added tests

* remove typealias

---------

Co-authored-by: Maisa Ben Salah <maisabensalah@AminsMBP131.attlocal.net>
  • Loading branch information
MaiBe-ctrl and Maisa Ben Salah authored Jun 17, 2024
1 parent ac7773a commit d9b9cd5
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 8 deletions.
20 changes: 12 additions & 8 deletions neuralprophet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import sys
from collections import OrderedDict
from typing import TYPE_CHECKING, Iterable, Optional, Union
from typing import TYPE_CHECKING, Iterable, Optional, Union, BinaryIO, IO

import numpy as np
import pandas as pd
Expand All @@ -21,19 +21,23 @@

log = logging.getLogger("NP.utils")

FILE_LIKE = Union[str, os.PathLike, BinaryIO, IO[bytes]]

def save(forecaster, path: str):
def save(forecaster, path: FILE_LIKE):
"""Save a fitted Neural Prophet model to disk.
Parameters:
forecaster : np.forecaster.NeuralProphet
input forecaster that is fitted
path : str
path and filename to be saved. filename could be any but suggested to have extension .np.
path : FILE_LIKE
Path and filename to be saved, or an in-memory buffer. Filename could be any but suggested to have extension .np.
After you fitted a model, you may save the model to save_test_model.np
>>> from neuralprophet import save
>>> save(forecaster, "test_save_model.np")
>>> import io
>>> buffer = io.BytesIO()
>>> save(forecaster, buffer)
"""
# List of attributes to remove
attrs_to_remove_forecaster = ["trainer"]
Expand Down Expand Up @@ -69,13 +73,13 @@ def save(forecaster, path: str):
setattr(forecaster.model, attr, value)


def load(path: str, map_location=None):
"""retrieve a fitted model from a .np file that was saved by save.
def load(path: FILE_LIKE, map_location=None):
"""retrieve a fitted model from a .np file or buffer that was saved by save.
Parameters
----------
path : str
path and filename to be saved. filename could be any but suggested to have extension .np.
path : FILE_LIKE
Path and filename to be saved, or an in-memory buffer. Filename could be any but suggested to have extension .np.
map_location : str, optional
specifying the location where the model should be loaded.
If you are running on a CPU-only machine, set map_location='cpu' to map your storages to the CPU.
Expand Down
32 changes: 32 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import os
import pathlib
import io

import pandas as pd
import pytest
Expand Down Expand Up @@ -66,6 +67,37 @@ def test_save_load():
pd.testing.assert_frame_equal(forecast, forecast2)
pd.testing.assert_frame_equal(forecast, forecast3)

def test_save_load_io():
df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
m = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
n_lags=6,
n_forecasts=3,
n_changepoints=0,
)
_ = m.fit(df, freq="D")
future = m.make_future_dataframe(df, periods=3)
forecast = m.predict(df=future)

# Save the model to an in-memory buffer
log.info("testing: save to buffer")
buffer = io.BytesIO()
save(m, buffer)
buffer.seek(0) # Reset buffer position to the beginning

log.info("testing: load from buffer")
m2 = load(buffer)
forecast2 = m2.predict(df=future)

buffer.seek(0) # Reset buffer position to the beginning for another load
m3 = load(buffer, map_location="cpu")
forecast3 = m3.predict(df=future)

# Check that the forecasts are the same
pd.testing.assert_frame_equal(forecast, forecast2)
pd.testing.assert_frame_equal(forecast, forecast3)

# TODO: add functionality to continue training
# def test_continue_training():
Expand Down

0 comments on commit d9b9cd5

Please sign in to comment.