Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
0xbe7a committed Aug 23, 2024
1 parent cc1a25d commit a3689d6
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 16 deletions.
8 changes: 6 additions & 2 deletions slim_trees/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@


def dump_sklearn_compressed(
model: Any, file: Union[str, Path, BinaryIO], compression: Optional[Union[str, dict]] = None
model: Any,
file: Union[str, Path, BinaryIO],
compression: Optional[Union[str, dict]] = None,
):
"""
Pickles a model and saves a compressed version to the disk.
Expand Down Expand Up @@ -79,7 +81,9 @@ def dumps_sklearn_compressed(


def dump_lgbm_compressed(
model: Any, file: Union[str, Path, BinaryIO], compression: Optional[Union[str, dict]] = None
model: Any,
file: Union[str, Path, BinaryIO],
compression: Optional[Union[str, dict]] = None,
):
"""
Pickles a model and saves a compressed version to the disk.
Expand Down
4 changes: 2 additions & 2 deletions slim_trees/sklearn_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def dumps(model: Any) -> bytes:
def _tree_pickle(tree: Tree):
assert isinstance(tree, Tree)
reconstructor, args, state = tree.__reduce__()
compressed_state = _compress_tree_state(state) # type: ignore
compressed_state = _compress_tree_state(state) # type: ignore
return _tree_unpickle, (reconstructor, args, (slim_trees_version, compressed_state))


Expand Down Expand Up @@ -113,7 +113,7 @@ def _compress_tree_state(state: Dict) -> Dict:
"values": values,
},
**(
{"missing_go_to_left": np.packbits(missing_go_to_left)} # type: ignore
{"missing_go_to_left": np.packbits(missing_go_to_left)} # type: ignore
if sklearn_version_ge_130
else {}
),
Expand Down
9 changes: 3 additions & 6 deletions tests/test_lgbm_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,11 @@ def test_dump_and_load_from_file(tmp_path, lgbm_regressor):
load_compressed(file, compression="lzma")

# No compression method specified
with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open(
"rb"
) as file:
with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open("rb") as file:
load_compressed(file)

with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open(
"wb"
) as file:
with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open("wb") as file:
dump_lgbm_compressed(lgbm_regressor, file)


# todo add tests for large models
10 changes: 4 additions & 6 deletions tests/test_sklearn_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,11 @@ def test_dump_and_load_from_file(tmp_path, random_forest_regressor):
load_compressed(file, compression="lzma")

# No compression method specified
with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open(
"rb"
) as file:
with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open("rb") as file:
load_compressed(file)

with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open(
"wb"
) as file:
with pytest.raises(ValueError), (tmp_path / "model.pickle.lzma").open("wb") as file:
dump_sklearn_compressed(random_forest_regressor, file)


# todo add tests for large models

0 comments on commit a3689d6

Please sign in to comment.