From a7899b64a50fa035caf44705cd8acec1a00b8e45 Mon Sep 17 00:00:00 2001 From: Andrew Luzan Date: Tue, 4 Jun 2024 13:12:19 +0300 Subject: [PATCH] correctly estimate symbol bit width for cases with 2,4,8,16... unique symbols --- src/encode.h | 2 +- tests/test_fastmask.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/encode.h b/src/encode.h index 241799d..73d976e 100644 --- a/src/encode.h +++ b/src/encode.h @@ -235,7 +235,7 @@ std::vector get_unique_symbols(const std::vector& rle_li } uint8_t estimate_symbol_bit_width(const std::vector& unique_symbols) { - return get_bit_width(unique_symbols.size()); + return get_bit_width(unique_symbols.size() - 1); } diff --git a/tests/test_fastmask.py b/tests/test_fastmask.py index dc5d289..d98b0f3 100644 --- a/tests/test_fastmask.py +++ b/tests/test_fastmask.py @@ -158,3 +158,19 @@ def test_info_on_small_file_produces_error(self): with self.assertRaises(ValueError): pf.info(f.name) + + +class TestSymbolBitWidth(unittest.TestCase): + def test_info_for_binary_image_returns_1bits_symbol_bit_width(self): + mask = np.array([[0, 1], [1, 0]], dtype=np.uint8) + with TempFile() as f: + pf.write(f, mask) + info = pf.info(f) + self.assertEqual(info['symbol_bit_width'], 1) + + def test_info_for_256color_image_returns_8bits_symbol_bit_width(self): + mask = np.arange(256, dtype=np.uint8).reshape(16, 16) + with TempFile() as f: + pf.write(f, mask) + info = pf.info(f) + self.assertEqual(info['symbol_bit_width'], 8)