diff --git a/numalogic/registry/_serialize.py b/numalogic/registry/_serialize.py
index 8c726a37..6d966bc8 100644
--- a/numalogic/registry/_serialize.py
+++ b/numalogic/registry/_serialize.py
@@ -1,14 +1,41 @@
 import io
+import pickle
+from typing import Union
+
 import torch
 
+from numalogic.tools.types import artifact_t, state_dict_t
+
 
-# TODO: ADD other techniques and support for other serialization techniques
-def dumps(deserialized_object):
-    buf = io.BytesIO()
-    torch.save(deserialized_object, buf)
-    return buf.getvalue()
+def dumps(
+    deserialized_object: Union[artifact_t, state_dict_t],
+    pickle_protocol: int = pickle.HIGHEST_PROTOCOL,
+) -> bytes:
+    """Serialize an artifact or state dict to bytes using ``torch.save``.
+
+    Args:
+        deserialized_object: object to serialize.
+        pickle_protocol: pickle protocol forwarded to ``torch.save``.
+
+    Returns:
+        The serialized payload.
+    """
+    # The context manager releases the buffer even if torch.save raises.
+    with io.BytesIO() as buffer:
+        torch.save(deserialized_object, buffer, pickle_protocol=pickle_protocol)
+        return buffer.getvalue()
 
 
-def loads(serialized_object):
-    buffer = io.BytesIO(serialized_object)
-    return torch.load(buffer)
+def loads(serialized_object: bytes) -> Union[artifact_t, state_dict_t]:
+    """Deserialize bytes produced by :func:`dumps` using ``torch.load``.
+
+    ``torch.load`` unpickles its input, so only pass bytes from a trusted source.
+
+    Args:
+        serialized_object: payload previously returned by :func:`dumps`.
+
+    Returns:
+        The reconstructed artifact or state dict.
+    """
+    with io.BytesIO(serialized_object) as buffer:
+        return torch.load(buffer)
diff --git a/tests/registry/test_serialize.py b/tests/registry/test_serialize.py
index 685047a4..274fb3d7 100644
--- a/tests/registry/test_serialize.py
+++ b/tests/registry/test_serialize.py
@@ -1,11 +1,12 @@
+import pickle
+import timeit
 import unittest
 
 from sklearn.preprocessing import StandardScaler
 from torchinfo import summary
 
-from numalogic.registry._serialize import loads, dumps
-
 from numalogic.models.autoencoder.variants import VanillaAE
+from numalogic.registry._serialize import loads, dumps
 
 
 class TestSerialize(unittest.TestCase):
@@ -21,3 +22,26 @@ def test_dumps_loads2(self):
         serialized_obj = dumps(model)
         deserialized_obj = loads(serialized_obj)
         self.assertEqual(model.mean_, deserialized_obj.mean_)
+
+    def test_benchmark_state_dict_vs_model(self):
+        """Loading a state dict should be faster than loading the whole model."""
+        model = VanillaAE(10, 2)
+        serialized_sd = dumps(model.state_dict())
+        serialized_obj = dumps(model)
+        # Best-of-N timings are far less sensitive to scheduler noise than a single run.
+        elapsed_obj = min(timeit.repeat(lambda: loads(serialized_obj), repeat=5, number=20))
+        elapsed_sd = min(timeit.repeat(lambda: loads(serialized_sd), repeat=5, number=20))
+        self.assertLess(elapsed_sd, elapsed_obj)
+
+    def test_benchmark_default_vs_highest_protocol(self):
+        """Both pickle protocols must round-trip to the same weights."""
+        model = VanillaAE(10, 2)
+        serialized_default = dumps(model, pickle_protocol=pickle.DEFAULT_PROTOCOL)
+        serialized_highest = dumps(model, pickle_protocol=pickle.HIGHEST_PROTOCOL)
+        # The two protocols can be identical, so a strict load-time inequality is not a
+        # stable property; check that the payloads are functionally equivalent instead.
+        weights_default = loads(serialized_default).state_dict()
+        weights_highest = loads(serialized_highest).state_dict()
+        self.assertEqual(list(weights_default), list(weights_highest))
+        for name, tensor in weights_default.items():
+            self.assertTrue(tensor.equal(weights_highest[name]))