Skip to content

Commit

Permalink
Add error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
edenhaus committed Dec 26, 2024
1 parent 58e7929 commit 16c71cb
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 9 deletions.
27 changes: 19 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
use std::error::Error;

use base64::{engine::general_purpose, Engine as _};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

fn _decompress_7z_base64_data(input: String) -> Result<String, Box<dyn Error>> {
let mut bytes = general_purpose::STANDARD.decode(input)?;

// Insert required 0 bytes
for _ in 0..=3 {
bytes.insert(8, 0);
}

let decompressed = lzma::decompress(&bytes)?;
Ok(String::from_utf8(decompressed)?)
}

/// Decompress base64 decoded 7z compressed string.
#[pyfunction]
fn decompress_7z_base64_data(input: String) -> PyResult<String> {
fn decompress_7z_base64_data(input: String) -> Result<String, PyErr> {
// todo add error handling
let mut bytes = general_purpose::STANDARD.decode(input).unwrap();
bytes.insert(8, 0);
bytes.insert(8, 0);
bytes.insert(8, 0);
bytes.insert(8, 0);
let decompressed = lzma::decompress(&bytes).unwrap();
Ok(String::from_utf8(decompressed).unwrap())
Ok(_decompress_7z_base64_data(input).map_err(|err| PyValueError::new_err(err.to_string()))?)
}

/// Deebot client written in Rust
#[pymodule]
fn rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(decompress_7z_base64_data, m)?)?;
Expand Down
2 changes: 1 addition & 1 deletion tests/commands/json/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ async def test_getMapSetV2_virtual_walls() -> None:
"batid": "gheijg",
"serial": 1,
"index": 1,
"subsets": "XQAABADHAAAAAC2WwEHwYhHX3vWwDK80QCnaQU0mwUd9Vk34ub6OxzOk6kdFfbFvpVp4iIlKisAvp0MznQNYEZ8koxFHnO,+iM44GUKgujGQKgzl0bScbQgaon1jI3eyCRikWlkmrbwA=",
"subsets": "XQAABADHAAAAAC2WwEHwYhHX3vWwDK80QCnaQU0mwUd9Vk34ub6OxzOk6kdFfbFvpVp4iIlKisAvp0MznQNYEZ8koxFHnO+iM44GUKgujGQKgzl0bScbQgaon1jI3eyCRikWlkmrbwA=",
"infoSize": 199,
}
)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@

from __future__ import annotations

import base64
import lzma
from typing import TYPE_CHECKING, Any

import pytest

from deebot_client.rs import decompress_7z_base64_data

if TYPE_CHECKING:
from contextlib import AbstractContextManager


@pytest.mark.parametrize(
("input", "expected"),
Expand All @@ -18,8 +25,50 @@
"XQAABABBAAAAAC2WwEIwUhHX3vfFDfs1H1PUqtdWgakwVnMBz3Bb3yaoE5OYkdYA",
'[["4","-6217","3919","-6217","231","-2642","231","-2642","3919"]]',
),
(
"XQAABADHAAAAAC2WwEHwYhHX3vWwDK80QCnaQU0mwUd9Vk34ub6OxzOk6kdFfbFvpVp4iIlKisAvp0MznQNYEZ8koxFHnO+iM44GUKgujGQKgzl0bScbQgaon1jI3eyCRikWlkmrbwA=",
'[["0","-5195","-1059","-5195","-37","-5806","-37","-5806","-1059"],["1","-7959","220","-7959","1083","-9254","1083","-9254","220"],["2","-9437","347","-5387","410"],["3","-5667","317","-4888","-56"]]',
),
],
)
def test_decompress_7z_base64_data(input: str, expected: str) -> None:
"""Test decompress_7z_base64_data function."""
assert _decompress_7z_base64_data_python(input) == expected
assert decompress_7z_base64_data(input) == expected


@pytest.mark.parametrize(
("input", "error"),
[
(
"XQAABADHAAAAAC2WwEHwYhHX3vWwDK80QCnaQU0mwUd9Vk34ub6OxzOk6kdFfbFvpVp4iIlKisAvp0MznQNYEZ8koxFHnO,+iM44GUKgujGQKgzl0bScbQgaon1jI3eyCRikWlkmrbwA=",
pytest.raises(ValueError, match="Invalid symbol 44, offset 94."),
),
(
"XQAABABBAAAAAC2WwEIwUhHX3vfFDfs1H1PUqtdWgakwVnMBz3Bb3yaoE5OYkd",
pytest.raises(ValueError, match="Invalid padding"),
),
],
)
def test_decompress_7z_base64_data_errors(
input: str, error: AbstractContextManager[Any]
) -> None:
"""Test decompress_7z_base64_data function."""
with error:
assert decompress_7z_base64_data(input)


def _decompress_7z_base64_data_python(data: str) -> str:
"""Decompress base64 decoded 7z compressed string."""
final_array = bytearray()

# Decode Base64
decoded = base64.b64decode(data)

for i, idx in enumerate(decoded):
if i == 8:
final_array.extend(b"\x00\x00\x00\x00")
final_array.append(idx)

dec = lzma.LZMADecompressor(lzma.FORMAT_AUTO, None, None)
return dec.decompress(final_array).decode()

0 comments on commit 16c71cb

Please sign in to comment.