Skip to content

Commit

Permalink
Add API function to serialize to strings (#480)
Browse files Browse the repository at this point in the history
* Use streams instead of FILE*

* Add API functions to serialize to strings

* revise docstring
  • Loading branch information
hcho3 authored Apr 27, 2023
1 parent 7502060 commit af7baea
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 123 deletions.
43 changes: 37 additions & 6 deletions include/treelite/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,23 +504,54 @@ TREELITE_DLL int TreeliteQueryNumClass(ModelHandle handle, size_t* out);
*/
TREELITE_DLL int TreeliteSetTreeLimit(ModelHandle handle, size_t limit);

/*!
* \brief Deprecated. Please use \ref TreeliteSerializeModelToFile instead.
*/
TREELITE_DLL int TreeliteSerializeModel(const char* filename, ModelHandle handle);

/*!
* \brief Deprecated. Please use \ref TreeliteDeserializeModelFromFile instead.
*/
TREELITE_DLL int TreeliteDeserializeModel(const char* filename, ModelHandle* out);

/*!
* \brief Serialize (persist) a model object to disk
* \param filename name of the file to which to serialize the model. The file will be using a
* \param handle Handle to the model object
* \param filename Name of the file to which to serialize the model. The file will be using a
* binary format that's optimized to store the Treelite model object efficiently.
* \param handle handle to the model object
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteSerializeModel(const char* filename, ModelHandle handle);
TREELITE_DLL int TreeliteSerializeModelToFile(ModelHandle handle, const char* filename);

/*!
* \brief Deserialize (load) a model object from disk
* \param filename name of the file from which to deserialize the model. The file should be created
* by a call to TreeliteSerializeModel().
* \param filename Name of the file from which to deserialize the model. The file should be created
* by a call to \ref TreeliteSerializeModelToFile.
* \param out Handle to the model object
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDeserializeModel(const char* filename, ModelHandle* out);
TREELITE_DLL int TreeliteDeserializeModelFromFile(const char* filename, ModelHandle* out);

/*!
* \brief Serialize (persist) a model object to a (binary) string.
* \param handle Handle to the model object
* \param out_str String containing serialized model
* \param out_str_len Length of out_str
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteSerializeModelToString(
ModelHandle handle, const char** out_str, size_t* out_str_len);

/*!
* \brief Deserialize (load) a model object from disk
* \param str String containing serialized model. The string should be created by a call to
* \ref TreeliteSerializeModelToString.
* \param str_len Length of str
* \param out Handle to the model object
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDeserializeModelFromString(
const char* str, size_t str_len, ModelHandle* out);

/*!
* \brief Concatenate multiple model objects into a single model object by copying
Expand Down
18 changes: 9 additions & 9 deletions include/treelite/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <algorithm>
#include <map>
#include <memory>
#include <ostream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
Expand Down Expand Up @@ -382,12 +382,12 @@ class Tree {

inline const char* GetFormatStringForNode();
inline void GetPyBuffer(std::vector<PyBufferFrame>* dest);
inline void SerializeToFile(FILE* dest_fp);
inline void SerializeToStream(std::ostream& os);
// Load a Tree object from a sequence of PyBuffer frames
// Returns the updated position of the cursor in the sequence
inline std::vector<PyBufferFrame>::iterator
InitFromPyBuffer(std::vector<PyBufferFrame>::iterator it);
inline void DeserializeFromFile(FILE* src_fp);
inline void DeserializeFromStream(std::istream& is);

private:
// vector of nodes
Expand Down Expand Up @@ -807,8 +807,8 @@ class Model {
CreateFromPyBuffer(std::vector<PyBufferFrame> frames);

/* Serialization to a file stream */
void SerializeToFile(FILE* dest_fp);
static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
void SerializeToStream(std::ostream& os);
static std::unique_ptr<Model> DeserializeFromStream(std::istream& is);

/*!
* \brief number of features used for the model.
Expand Down Expand Up @@ -840,12 +840,12 @@ class Model {
TypeInfo leaf_output_type_{TypeInfo::kInvalid};
// Internal functions for serialization
virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
virtual void SerializeToStreamImpl(std::ostream& os) = 0;
// Load a Model object from a sequence of PyBuffer frames
// Returns the updated position of the cursor in the sequence
virtual std::vector<PyBufferFrame>::iterator InitFromPyBuffer(
std::vector<PyBufferFrame>::iterator it, std::size_t num_frame) = 0;
virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
virtual void DeserializeFromStreamImpl(std::istream& is) = 0;
template <typename HeaderPrimitiveFieldHandlerFunc>
inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
template <typename HeaderPrimitiveFieldHandlerFunc>
Expand Down Expand Up @@ -878,12 +878,12 @@ class ModelImpl : public Model {
}

inline void GetPyBuffer(std::vector<PyBufferFrame>* dest) override;
inline void SerializeToFileImpl(FILE* dest_fp) override;
inline void SerializeToStreamImpl(std::ostream& os) override;
// Load a ModelImpl object from a sequence of PyBuffer frames
// Returns the updated position of the cursor in the sequence
inline std::vector<PyBufferFrame>::iterator InitFromPyBuffer(
std::vector<PyBufferFrame>::iterator it, std::size_t num_frame) override;
inline void DeserializeFromFileImpl(FILE* src_fp) override;
inline void DeserializeFromStreamImpl(std::istream& is) override;

private:
template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
Expand Down
119 changes: 52 additions & 67 deletions include/treelite/tree_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <stdexcept>
#include <iostream>
#include <cstddef>
#include <cstdint>

namespace {

Expand Down Expand Up @@ -457,65 +458,49 @@ inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
}

template <typename T>
inline void ReadScalarFromFile(T* scalar, FILE* fp) {
inline void ReadScalarFromStream(T* scalar, std::istream& is) {
static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
if (std::fread(scalar, sizeof(T), 1, fp) < 1) {
throw Error("Could not read a scalar");
}
is.read(reinterpret_cast<char*>(scalar), sizeof(T));
}

template <typename T>
inline void WriteScalarToFile(T* scalar, FILE* fp) {
inline void WriteScalarToStream(T* scalar, std::ostream& os) {
static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
if (std::fwrite(scalar, sizeof(T), 1, fp) < 1) {
throw Error("Could not write a scalar");
}
os.write(reinterpret_cast<const char*>(scalar), sizeof(T));
}

template <typename T>
inline void ReadArrayFromFile(ContiguousArray<T>* vec, FILE* fp) {
uint64_t nelem;
if (std::fread(&nelem, sizeof(nelem), 1, fp) < 1) {
throw Error("Could not read the number of elements");
}
inline void ReadArrayFromStream(ContiguousArray<T>* vec, std::istream& is) {
std::uint64_t nelem;
is.read(reinterpret_cast<char*>(&nelem), sizeof(nelem));
vec->Clear();
vec->Resize(nelem);
if (nelem == 0) {
return; // handle empty arrays
}
const auto nelem_size_t = static_cast<std::size_t>(nelem);
if (std::fread(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
throw Error("Could not read an array");
}
is.read(reinterpret_cast<char*>(vec->Data()), sizeof(T) * nelem);
}

template <typename T>
inline void WriteArrayToFile(ContiguousArray<T>* vec, FILE* fp) {
static_assert(sizeof(uint64_t) >= sizeof(size_t), "size_t too large");
const auto nelem = static_cast<uint64_t>(vec->Size());
if (std::fwrite(&nelem, sizeof(nelem), 1, fp) < 1) {
throw Error("Could not write the number of elements");
}
inline void WriteArrayToStream(ContiguousArray<T>* vec, std::ostream& os) {
static_assert(sizeof(std::uint64_t) >= sizeof(std::size_t), "size_t too large");
const auto nelem = static_cast<std::uint64_t>(vec->Size());
os.write(reinterpret_cast<const char*>(&nelem), sizeof(nelem));
if (nelem == 0) {
return; // handle empty arrays
}
const auto nelem_size_t = vec->Size();
if (std::fwrite(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
throw Error("Could not write an array");
}
os.write(reinterpret_cast<const char*>(vec->Data()), sizeof(T) * vec->Size());
}

inline void SkipOptFieldInFile(FILE* fp) {
uint16_t elem_size;
uint64_t nelem;
ReadScalarFromFile(&elem_size, fp);
ReadScalarFromFile(&nelem, fp);
inline void SkipOptFieldInStream(std::istream& is) {
std::uint16_t elem_size;
std::uint64_t nelem;
ReadScalarFromStream(&elem_size, is);
ReadScalarFromStream(&nelem, is);

const uint64_t nbytes = elem_size * nelem;
TREELITE_CHECK_LE(nbytes, std::numeric_limits<long>::max()); // NOLINT
if (std::fseek(fp, static_cast<long>(nbytes), SEEK_CUR) != 0) { // NOLINT
throw Error("Reached end of file");
}
const std::uint64_t nbytes = elem_size * nelem;
TREELITE_CHECK_LE(nbytes, std::numeric_limits<std::streamoff>::max()); // NOLINT
is.seekg(static_cast<std::streamoff>(nbytes), std::ios::cur);
}

template <typename ThresholdType, typename LeafOutputType>
Expand Down Expand Up @@ -628,15 +613,15 @@ Tree<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* des

template <typename ThresholdType, typename LeafOutputType>
inline void
Tree<ThresholdType, LeafOutputType>::SerializeToFile(FILE* dest_fp) {
auto scalar_handler = [dest_fp](auto* field) {
WriteScalarToFile(field, dest_fp);
Tree<ThresholdType, LeafOutputType>::SerializeToStream(std::ostream& os) {
auto scalar_handler = [&os](auto* field) {
WriteScalarToStream(field, os);
};
auto primitive_array_handler = [dest_fp](auto* field) {
WriteArrayToFile(field, dest_fp);
auto primitive_array_handler = [&os](auto* field) {
WriteArrayToStream(field, os);
};
auto composite_array_handler = [dest_fp](auto* field, const char* format) {
WriteArrayToFile(field, dest_fp);
auto composite_array_handler = [&os](auto* field, const char* format) {
WriteArrayToStream(field, os);
};
SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
}
Expand All @@ -660,15 +645,15 @@ Tree<ThresholdType, LeafOutputType>::InitFromPyBuffer(std::vector<PyBufferFrame>

template <typename ThresholdType, typename LeafOutputType>
inline void
Tree<ThresholdType, LeafOutputType>::DeserializeFromFile(FILE* src_fp) {
auto scalar_handler = [src_fp](auto* field) {
ReadScalarFromFile(field, src_fp);
Tree<ThresholdType, LeafOutputType>::DeserializeFromStream(std::istream& is) {
auto scalar_handler = [&is](auto* field) {
ReadScalarFromStream(field, is);
};
auto array_handler = [src_fp](auto* field) {
ReadArrayFromFile(field, src_fp);
auto array_handler = [&is](auto* field) {
ReadArrayFromStream(field, is);
};
auto skip_opt_field_handler = [src_fp]() {
SkipOptFieldInFile(src_fp);
auto skip_opt_field_handler = [&is]() {
SkipOptFieldInStream(is);
};
DeserializeTemplate(scalar_handler, array_handler, skip_opt_field_handler);
}
Expand Down Expand Up @@ -980,16 +965,16 @@ ModelImpl<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>

template <typename ThresholdType, typename LeafOutputType>
inline void
ModelImpl<ThresholdType, LeafOutputType>::SerializeToFileImpl(FILE* dest_fp) {
num_tree_ = static_cast<uint64_t>(this->trees.size());
auto header_primitive_field_handler = [dest_fp](auto* field) {
WriteScalarToFile(field, dest_fp);
ModelImpl<ThresholdType, LeafOutputType>::SerializeToStreamImpl(std::ostream& os) {
num_tree_ = static_cast<std::uint64_t>(this->trees.size());
auto header_primitive_field_handler = [&os](auto* field) {
WriteScalarToStream(field, os);
};
auto header_composite_field_handler = [dest_fp](auto* field, const char* format) {
WriteScalarToFile(field, dest_fp);
auto header_composite_field_handler = [&os](auto* field, const char* format) {
WriteScalarToStream(field, os);
};
auto tree_handler = [dest_fp](Tree<ThresholdType, LeafOutputType>& tree) {
tree.SerializeToFile(dest_fp);
auto tree_handler = [&os](Tree<ThresholdType, LeafOutputType>& tree) {
tree.SerializeToStream(os);
};
header_primitive_field_handler(&num_tree_);
SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
Expand Down Expand Up @@ -1032,19 +1017,19 @@ ModelImpl<ThresholdType, LeafOutputType>::InitFromPyBuffer(

template <typename ThresholdType, typename LeafOutputType>
inline void
ModelImpl<ThresholdType, LeafOutputType>::DeserializeFromFileImpl(FILE* src_fp) {
ReadScalarFromFile(&num_tree_, src_fp);
ModelImpl<ThresholdType, LeafOutputType>::DeserializeFromStreamImpl(std::istream& is) {
ReadScalarFromStream(&num_tree_, is);

auto header_field_handler = [src_fp](auto* field) {
ReadScalarFromFile(field, src_fp);
auto header_field_handler = [&is](auto* field) {
ReadScalarFromStream(field, is);
};

auto skip_opt_field_handler = [src_fp]() {
SkipOptFieldInFile(src_fp);
auto skip_opt_field_handler = [&is]() {
SkipOptFieldInStream(is);
};

auto tree_handler = [src_fp](Tree<ThresholdType, LeafOutputType>& tree) {
tree.DeserializeFromFile(src_fp);
auto tree_handler = [&is](Tree<ThresholdType, LeafOutputType>& tree) {
tree.DeserializeFromStream(is);
};

DeserializeTemplate(num_tree_, header_field_handler, tree_handler, skip_opt_field_handler);
Expand Down
49 changes: 47 additions & 2 deletions python/treelite/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,23 @@ def serialize(self, filename):
filename : :py:class:`str <python:str>`
Path to checkpoint
"""
_check_call(_LIB.TreeliteSerializeModel(c_str(filename), self.handle))
_check_call(_LIB.TreeliteSerializeModelToFile(self.handle, c_str(filename)))

def serialize_bytes(self) -> bytes:
"""
Serialize (persist) the model to a byte sequence, using a fast binary representation.
To recover the model from the byte sequence, use :py:func:`deserialize_bytes` method.
.. note:: Notes on forward and backward compatibility
Please see :doc:`/notes-on-serialization`.
"""
char_ptr_t = ctypes.POINTER(ctypes.c_char)
out_str = char_ptr_t()
out_str_len = ctypes.c_size_t()
_check_call(_LIB.TreeliteSerializeModelToString(
self.handle, ctypes.byref(out_str), ctypes.byref(out_str_len)))
return ctypes.string_at(out_str, out_str_len.value)

def dump_as_json(self, *, pretty_print=True):
"""
Expand Down Expand Up @@ -126,7 +142,36 @@ def deserialize(cls, filename):
Recovered model
"""
handle = ctypes.c_void_p()
_check_call(_LIB.TreeliteDeserializeModel(c_str(filename), ctypes.byref(handle)))
_check_call(_LIB.TreeliteDeserializeModelFromFile(c_str(filename), ctypes.byref(handle)))
return Model(handle)

@classmethod
def deserialize_bytes(cls, model_bytes: bytes) -> Model:
"""
Deserialize (recover) the model from a byte sequence. It is expected that
the byte sequence was generated by a call to the :py:func:`serialize_bytes` method.
.. note:: Notes on forward and backward compatibility
Please see :doc:`/notes-on-serialization`.
Parameters
----------
model_bytes : :py:class:`bytes <python:bytes>`
Byte sequence representing the serialized model
Returns
-------
model : :py:class:`Model` object
Recovered model
"""
handle = ctypes.c_void_p()
model_bytes_len = len(model_bytes)
buffer = ctypes.create_string_buffer(model_bytes, model_bytes_len)
_check_call(_LIB.TreeliteDeserializeModelFromString(
ctypes.POINTER(ctypes.c_char)(buffer),
ctypes.c_size_t(model_bytes_len),
ctypes.byref(handle)))
return Model(handle)

@property
Expand Down
Loading

0 comments on commit af7baea

Please sign in to comment.