Skip to content

Commit

Permalink
add info_bytes and get_shape functions (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
JIy3AHKO authored Jul 25, 2024
1 parent 0659481 commit c6e67e1
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 5 deletions.
47 changes: 42 additions & 5 deletions pyfastmask/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pathlib
from typing import Dict, Union
from typing import Dict, Tuple, Union

import numpy as _np

Expand All @@ -26,7 +26,7 @@ def read(path: Union[str, pathlib.Path]) -> _np.ndarray:
Read pfm file and return mask as numpy array.
Args:
path: (str, pathlib.Path) Path to pfm file.
path: (Union[str, pathlib.Path]) Path to pfm file.
Returns:
mask: (np.ndarray) Mask as numpy array.
Expand Down Expand Up @@ -61,7 +61,7 @@ def write(path: Union[str, pathlib.Path], mask: _np.ndarray) -> None:
Mask must be numpy array with shape (W, H) or (W, H, 1) and dtype np.uint8.
Args:
path: (str) Path to save pfm file.
path: (Union[str, pathlib.Path]) Path to save pfm file.
mask: (np.ndarray) Mask to save.
Examples:
Expand Down Expand Up @@ -95,7 +95,7 @@ def encode(mask: _np.ndarray) -> bytes:
return _pyfastmask.encode(mask)


def info(path: Union[str, pathlib.Path]) -> Dict[str, int]:
def info(path: Union[str, pathlib.Path, bytes]) -> Dict[str, Union[int, Tuple[int, int]]]:
"""
Get info of pfm file.
Expand All @@ -107,10 +107,47 @@ def info(path: Union[str, pathlib.Path]) -> Dict[str, int]:
- symbol_bit_width: (int) Number of bits used to represent symbol value.
Args:
path: (str) Path to pfm file.
path: (Union[str, pathlib.Path]) Path to pfm file.
Returns:
info: (Dict[str, int]) Dictionary with file info.
"""
path = str(path)
return _pyfastmask.info(path)


def info_bytes(encoded_pfm: bytes) -> Dict[str, Union[int, Tuple[int, int]]]:
"""
Get info of encoded pfm file.
See :func: `pyfastmask.info` for header description.
Args:
encoded_pfm: (bytes) encoded pfm file.
Returns:
info: (Dict[str, int]) Dictionary with file info.
"""

return _pyfastmask.info_buffer(encoded_pfm)


def get_shape(pfm: Union[str, pathlib.Path, bytes]) -> Tuple[int, int]:
"""
Get shape of pfm file.
Convenient function to get shape from pfm file given as a path on disk or an encoded buffer.
Args:
pfm: (Union[str, pathlib.Path, bytes]) pfm file. In case of bytes, it is treated as encoded data; otherwise it's
treated as path.
Returns:
shape: (int, int) shape of an image in pfm file.
"""

if isinstance(pfm, bytes):
info_data = info_bytes(pfm)
else:
info_data = info(pfm)

return info_data['shape']
22 changes: 22 additions & 0 deletions src/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,27 @@ py::dict read_header_from_file(const std::string& filename) {
"shape"_a = py::make_tuple(header.mask_height, header.mask_width)
);
}



py::dict read_header_from_buffer(const py::buffer& data_bytes) {
py::buffer_info info = data_bytes.request();

const char* buffer = static_cast<const char*>(info.ptr);

validate_buffer_size(info.size);

Header header = read_header(buffer);
validate_header(header);

return py::dict(
"symbol_bit_width"_a = header.symbol_bit_width,
"count_bit_width"_a = header.count_bit_width,
"unique_symbols_count"_a = header.unique_symbols_count,
"line_count_bit_width"_a = header.line_count_bit_width,
"shape"_a = py::make_tuple(header.mask_height, header.mask_width)
);
}



Expand All @@ -157,6 +178,7 @@ PYBIND11_MODULE(_pyfastmask, m) {
m.def("decode", &read_mask_from_buffer, "Decodes mask from buffer", py::arg("buffer"), py::return_value_policy::move);

m.def("info", &read_header_from_file, "Read mask header from file", py::arg("filename"));
m.def("info_buffer", &read_header_from_buffer, "Read mask header from buffer", py::arg("buffer"));

py::class_<Header>(m, "Header")
.def_readonly("symbol_bit_width", &Header::symbol_bit_width)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_fastmask.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,22 @@ def test_info_pathlib_path_successful(self):
info = pf.info(path)
self.assertEqual(info['shape'], (256, 128))

def test_info_buffer_returns_correct_shape(self):
mask = np.random.randint(0, 256, (12, 34), dtype=np.uint8)

encoded = pf.encode(mask)
info = pf.info_bytes(encoded)

self.assertEqual(info['shape'], (12, 34))

def test_info_buffer_small_buffer_produces_error(self):
with self.assertRaises(ValueError):
pf.info_bytes(b'0')

def test_info_buffer_wrong_magic_bytes_produces_error(self):
with self.assertRaises(ValueError):
pf.info_bytes(b'0123456789ancdefghigklmnop')


class TestSymbolBitWidth(unittest.TestCase):
def test_info_for_binary_image_returns_1bits_symbol_bit_width(self):
Expand All @@ -193,3 +209,29 @@ def test_info_for_256color_image_returns_8bits_symbol_bit_width(self):
pf.write(f, mask)
info = pf.info(f)
self.assertEqual(info['symbol_bit_width'], 8)


class TestGetShape(unittest.TestCase):
def test_get_shape_returns_correct_shape_for_str(self):
mask = np.random.randint(0, 256, (256, 128), dtype=np.uint8)
with TempFile() as f:
pf.write(f, mask)
shape = pf.get_shape(f)

self.assertEqual(shape, (256, 128))

def test_get_shape_returns_correct_shape_for_pathlib(self):
mask = np.random.randint(0, 256, (256, 128), dtype=np.uint8)
with TempFile() as f:
f = pathlib.Path(f)
pf.write(f, mask)
shape = pf.get_shape(f)

self.assertEqual(shape, (256, 128))

def test_get_shape_returns_correct_shape_for_bytes(self):
mask = np.random.randint(0, 256, (256, 128), dtype=np.uint8)
encoded = pf.encode(mask)
shape = pf.get_shape(encoded)

self.assertEqual(shape, (256, 128))

0 comments on commit c6e67e1

Please sign in to comment.