Skip to content

Commit

Permalink
perf: improve serialization performance
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <ab93@users.noreply.github.com>
  • Loading branch information
ab93 committed Aug 16, 2023
1 parent 7ef40a0 commit 0742ce1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 9 deletions.
25 changes: 18 additions & 7 deletions numalogic/registry/_serialize.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
import io
import pickle
from typing import Union

import torch

from numalogic.tools.types import artifact_t, state_dict_t


# TODO: add support for other serialization techniques
def dumps(
    deserialized_object: Union[artifact_t, state_dict_t],
    pickle_protocol: int = pickle.HIGHEST_PROTOCOL,
) -> bytes:
    """
    Serialize an object into bytes using ``torch.save``.

    Args:
        deserialized_object: the artifact or state dict to serialize
        pickle_protocol: pickle protocol version forwarded to ``torch.save``;
            defaults to ``pickle.HIGHEST_PROTOCOL``

    Returns
    -------
        The serialized object as bytes
    """
    # The context manager releases the in-memory buffer even when
    # torch.save raises, which a trailing buffer.close() would not.
    with io.BytesIO() as buffer:
        torch.save(deserialized_object, buffer, pickle_protocol=pickle_protocol)
        return buffer.getvalue()


def loads(serialized_object: bytes) -> Union[artifact_t, state_dict_t]:
    """
    Deserialize bytes produced by ``torch.save`` back into an object.

    WARNING: ``torch.load`` unpickles the payload, which can execute
    arbitrary code. Only call this on bytes from a trusted source.

    Args:
        serialized_object: the serialized artifact or state dict as bytes

    Returns
    -------
        The deserialized artifact or state dict
    """
    # The context manager releases the in-memory buffer even when
    # torch.load raises.
    with io.BytesIO(serialized_object) as buffer:
        # Whole pickled artifacts (not just tensors / state dicts) are stored,
        # so the restricted weights-only unpickler cannot be used. Newer
        # PyTorch releases default to weights_only=True, hence the explicit
        # opt-out to keep the behaviour stable across versions.
        return torch.load(buffer, weights_only=False)
21 changes: 19 additions & 2 deletions tests/registry/test_serialize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import pickle
import timeit
import unittest

from sklearn.preprocessing import StandardScaler
from torchinfo import summary

from numalogic.registry._serialize import loads, dumps

from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.registry._serialize import loads, dumps


class TestSerialize(unittest.TestCase):
Expand All @@ -21,3 +22,19 @@ def test_dumps_loads2(self):
serialized_obj = dumps(model)
deserialized_obj = loads(serialized_obj)
self.assertEqual(model.mean_, deserialized_obj.mean_)

def test_benchmark_state_dict_vs_model(self):
    """Check that loading a serialized state dict is faster than loading the whole model."""
    model = VanillaAE(10, 2)
    serialized_sd = dumps(model.state_dict())
    serialized_obj = dumps(model)
    # Compare the best of several runs instead of one sample: the minimum is
    # the measurement least disturbed by scheduler/GC noise, which makes this
    # timing assertion less prone to flaking. Total calls stay at 100 per side.
    elapsed_obj = min(timeit.repeat(lambda: loads(serialized_obj), repeat=5, number=20))
    elapsed_sd = min(timeit.repeat(lambda: loads(serialized_sd), repeat=5, number=20))
    self.assertLess(elapsed_sd, elapsed_obj)

def test_benchmark_default_vs_highest_protocol(self):
    """Check that a model saved with the highest pickle protocol loads faster than the default."""
    model = VanillaAE(10, 2)
    serialized_default = dumps(model, pickle_protocol=pickle.DEFAULT_PROTOCOL)
    serialized_highest = dumps(model, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    # Compare the best of several runs instead of one sample: the minimum is
    # the measurement least disturbed by scheduler/GC noise, which makes this
    # timing assertion less prone to flaking. Total calls stay at 100 per side.
    elapsed_default = min(
        timeit.repeat(lambda: loads(serialized_default), repeat=5, number=20)
    )
    elapsed_highest = min(
        timeit.repeat(lambda: loads(serialized_highest), repeat=5, number=20)
    )
    self.assertLess(elapsed_highest, elapsed_default)

0 comments on commit 0742ce1

Please sign in to comment.