From b82ed38bd91a28b7939916dab3081512375e7945 Mon Sep 17 00:00:00 2001 From: mayeut Date: Fri, 29 Dec 2023 19:41:33 +0100 Subject: [PATCH] fix: decoding non C-contiguous buffer --- noxfile.py | 2 +- src/pybase64/_pybase64.c | 32 +++++++++++++++---- tests/test_pybase64.py | 69 +++++++++++++++++++++++++++++++++++++--- 3 files changed, 91 insertions(+), 12 deletions(-) diff --git a/noxfile.py b/noxfile.py index 3830ac54..85db2faa 100644 --- a/noxfile.py +++ b/noxfile.py @@ -46,7 +46,7 @@ def remove_extension(session: nox.Session, in_place: bool = False) -> None: assert removed -@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10"]) +@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]) def test(session: nox.Session) -> None: """Run tests.""" session.install("-r", "requirements-test.txt") diff --git a/src/pybase64/_pybase64.c b/src/pybase64/_pybase64.c index 63865e95..08799fd6 100644 --- a/src/pybase64/_pybase64.c +++ b/src/pybase64/_pybase64.c @@ -25,7 +25,7 @@ static uint32_t active_simd_flag = 0U; static uint32_t simd_flags; /* returns 0 on success */ -static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlphabet) +static int parse_alphabet(PyObject* alphabetObject, int allow_non_contiguous, char* alphabet, int* useAlphabet) { Py_buffer buffer; @@ -49,7 +49,7 @@ static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlph Py_INCREF(alphabetObject); } - if (PyObject_GetBuffer(alphabetObject, &buffer, PyBUF_SIMPLE) < 0) { + if (PyObject_GetBuffer(alphabetObject, &buffer, allow_non_contiguous ? PyBUF_FULL_RO : PyBUF_SIMPLE) < 0) { Py_DECREF(alphabetObject); return -1; } @@ -61,9 +61,13 @@ static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlph return -1; } + if (PyBuffer_ToContiguous(alphabet, &buffer, 2, 'C') != 0) { + PyBuffer_Release(&buffer); + Py_DECREF(alphabetObject); + return -1; + } + *useAlphabet = 1; - alphabet[0] = ((const char*)buffer.buf)[0]; - alphabet[1] = ((const char*)buffer.buf)[1]; if ((alphabet[0] == '+') && (alphabet[1] == '/')) { *useAlphabet = 0; @@ -310,7 +314,7 @@ static PyObject* pybase64_encode_impl(PyObject* self, PyObject* args, PyObject * return NULL; } - if (parse_alphabet(in_alphabet, alphabet, &use_alphabet) != 0) { + if (parse_alphabet(in_alphabet, 1, alphabet, &use_alphabet) != 0) { return NULL; } @@ -411,7 +415,7 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject * return NULL; } - if (parse_alphabet(in_alphabet, alphabet, &use_alphabet) != 0) { + if (parse_alphabet(in_alphabet, 1, alphabet, &use_alphabet) != 0) { return NULL; } @@ -434,10 +438,24 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject * Py_INCREF(in_object); } if (source == NULL) { - if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) { + if (PyObject_GetBuffer(in_object, &buffer, PyBUF_FULL_RO) < 0) { Py_DECREF(in_object); return NULL; } + if (!PyBuffer_IsContiguous(&buffer, 'C')) { + PyObject* contiguous_object; + PyBuffer_Release(&buffer); + contiguous_object = PyMemoryView_GetContiguous(in_object, PyBUF_READ, 'C'); + Py_DECREF(in_object); + if (contiguous_object == NULL) { + return NULL; + } + in_object = contiguous_object; + if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) { + Py_DECREF(in_object); + return NULL; + } + } source = buffer.buf; source_len = buffer.len; source_use_buffer = 1; diff --git a/tests/test_pybase64.py b/tests/test_pybase64.py index 2d817891..6d78d4a3 100644 --- a/tests/test_pybase64.py +++ b/tests/test_pybase64.py @@ -156,8 +156,12 @@ def simd_setup(simd_id): param_encode_functions = pytest.mark.parametrize( "efn, ecast", [ - (pybase64.b64encode, lambda x: x), - (pybase64.b64encode_as_string, lambda x: x.encode("ascii")), + pytest.param(pybase64.b64encode, lambda x: x, id="b64encode"), + pytest.param( + pybase64.b64encode_as_string, + lambda x: x.encode("ascii"), + id="b64encode_as_string", + ), ], ) @@ -165,8 +169,12 @@ def simd_setup(simd_id): param_decode_functions = pytest.mark.parametrize( "dfn, dcast", [ - (pybase64.b64decode, lambda x: x), - (pybase64.b64decode_as_bytearray, lambda x: bytes(x)), + pytest.param(pybase64.b64decode, lambda x: x, id="b64decode"), + pytest.param( + pybase64.b64decode_as_bytearray, + lambda x: bytes(x), + id="b64decode_as_bytearray", + ), ], ) @@ -460,3 +468,56 @@ def test_flags(request): "hsw": 1 | 2 | 4 | 8 | 16 | 32 | 64, # AVX2 "spr": 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, # AVX512VBMI }[cpu] == runtime_flags + + +@param_encode_functions +def test_enc_multi_dimensional(efn, ecast): + source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV" + vector = memoryview(source).cast("B", (4, len(source) // 4)) + assert vector.c_contiguous + test = ecast(efn(vector, None)) + base = base64.b64encode(source) + assert test == base + + +@param_decode_functions +def test_dec_multi_dimensional(dfn, dcast): + source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV" + vector = memoryview(source).cast("B", (4, len(source) // 4)) + assert vector.c_contiguous + test = dcast(dfn(vector, None)) + base = base64.b64decode(source) + assert test == base + + +@param_validate +@param_decode_functions +def test_dec_stride(dfn, dcast, validate): + source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV" + source_altchars = b"- _" + vector = memoryview(source)[::2] + vector_altchars = memoryview(source_altchars)[::2] + assert not (vector.contiguous or vector_altchars.contiguous) + test = dcast(dfn(vector, vector_altchars, validate)) + base = base64.b64decode(vector.tobytes(), vector_altchars.tobytes(), validate) + assert test == base + + +@param_encode_functions +def test_enc_stride_data_not_supported(efn, ecast): + source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV" + vector = memoryview(source)[::2] + assert not vector.contiguous + with pytest.raises(BufferError): + efn(vector, None) + + +@param_encode_functions +def test_enc_stride_altchars(efn, ecast): + source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV" + source_altchars = b"- _" + vector_altchars = memoryview(source_altchars)[::2] + assert not vector_altchars.contiguous + test = ecast(efn(source, vector_altchars)) + base = base64.b64encode(source, vector_altchars.tobytes()) + assert test == base