Skip to content

Commit

Permalink
fix: decoding non C-contiguous buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeut committed Dec 30, 2023
1 parent 7446ef0 commit b82ed38
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 12 deletions.
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def remove_extension(session: nox.Session, in_place: bool = False) -> None:
assert removed


@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10"])
@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"])
def test(session: nox.Session) -> None:
"""Run tests."""
session.install("-r", "requirements-test.txt")
Expand Down
32 changes: 25 additions & 7 deletions src/pybase64/_pybase64.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ static uint32_t active_simd_flag = 0U;
static uint32_t simd_flags;

/* returns 0 on success */
static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlphabet)
static int parse_alphabet(PyObject* alphabetObject, int allow_non_contiguous, char* alphabet, int* useAlphabet)
{
Py_buffer buffer;

Expand All @@ -49,7 +49,7 @@ static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlph
Py_INCREF(alphabetObject);
}

if (PyObject_GetBuffer(alphabetObject, &buffer, PyBUF_SIMPLE) < 0) {
if (PyObject_GetBuffer(alphabetObject, &buffer, allow_non_contiguous ? PyBUF_FULL_RO : PyBUF_SIMPLE) < 0) {
Py_DECREF(alphabetObject);
return -1;
}
Expand All @@ -61,9 +61,13 @@ static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlph
return -1;
}

if (PyBuffer_ToContiguous(alphabet, &buffer, 2, 'C') != 0) {
PyBuffer_Release(&buffer);
Py_DECREF(alphabetObject);
return -1;
}

*useAlphabet = 1;
alphabet[0] = ((const char*)buffer.buf)[0];
alphabet[1] = ((const char*)buffer.buf)[1];

if ((alphabet[0] == '+') && (alphabet[1] == '/')) {
*useAlphabet = 0;
Expand Down Expand Up @@ -310,7 +314,7 @@ static PyObject* pybase64_encode_impl(PyObject* self, PyObject* args, PyObject *
return NULL;
}

if (parse_alphabet(in_alphabet, alphabet, &use_alphabet) != 0) {
if (parse_alphabet(in_alphabet, 1, alphabet, &use_alphabet) != 0) {
return NULL;
}

Expand Down Expand Up @@ -411,7 +415,7 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject *
return NULL;
}

if (parse_alphabet(in_alphabet, alphabet, &use_alphabet) != 0) {
if (parse_alphabet(in_alphabet, 1, alphabet, &use_alphabet) != 0) {
return NULL;
}

Expand All @@ -434,10 +438,24 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject *
Py_INCREF(in_object);
}
if (source == NULL) {
if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
if (PyObject_GetBuffer(in_object, &buffer, PyBUF_FULL_RO) < 0) {
Py_DECREF(in_object);
return NULL;
}
if (!PyBuffer_IsContiguous(&buffer, 'C')) {
PyObject* contiguous_object;
PyBuffer_Release(&buffer);
contiguous_object = PyMemoryView_GetContiguous(in_object, PyBUF_READ, 'C');
Py_DECREF(in_object);
if (contiguous_object == NULL) {
return NULL;
}
in_object = contiguous_object;
if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
Py_DECREF(in_object);
return NULL;
}
}
source = buffer.buf;
source_len = buffer.len;
source_use_buffer = 1;
Expand Down
69 changes: 65 additions & 4 deletions tests/test_pybase64.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,25 @@ def simd_setup(simd_id):
param_encode_functions = pytest.mark.parametrize(
"efn, ecast",
[
(pybase64.b64encode, lambda x: x),
(pybase64.b64encode_as_string, lambda x: x.encode("ascii")),
pytest.param(pybase64.b64encode, lambda x: x, id="b64encode"),
pytest.param(
pybase64.b64encode_as_string,
lambda x: x.encode("ascii"),
id="b64encode_as_string",
),
],
)


param_decode_functions = pytest.mark.parametrize(
"dfn, dcast",
[
(pybase64.b64decode, lambda x: x),
(pybase64.b64decode_as_bytearray, lambda x: bytes(x)),
pytest.param(pybase64.b64decode, lambda x: x, id="b64decode"),
pytest.param(
pybase64.b64decode_as_bytearray,
lambda x: bytes(x),
id="b64decode_as_bytearray",
),
],
)

Expand Down Expand Up @@ -460,3 +468,56 @@ def test_flags(request):
"hsw": 1 | 2 | 4 | 8 | 16 | 32 | 64, # AVX2
"spr": 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, # AVX512VBMI
}[cpu] == runtime_flags


@param_encode_functions
def test_enc_multi_dimensional(efn, ecast):
source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"
vector = memoryview(source).cast("B", (4, len(source) // 4))
assert vector.c_contiguous
test = ecast(efn(vector, None))
base = base64.b64encode(source)
assert test == base


@param_decode_functions
def test_dec_multi_dimensional(dfn, dcast):
source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"
vector = memoryview(source).cast("B", (4, len(source) // 4))
assert vector.c_contiguous
test = dcast(dfn(vector, None))
base = base64.b64decode(source)
assert test == base


@param_validate
@param_decode_functions
def test_dec_stride(dfn, dcast, validate):
source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"
source_altchars = b"- _"
vector = memoryview(source)[::2]
vector_altchars = memoryview(source_altchars)[::2]
assert not (vector.contiguous or vector_altchars.contiguous)
test = dcast(dfn(vector, vector_altchars, validate))
base = base64.b64decode(vector.tobytes(), vector_altchars.tobytes(), validate)
assert test == base


@param_encode_functions
def test_enc_stride_data_not_supported(efn, ecast):
source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"
vector = memoryview(source)[::2]
assert not vector.contiguous
with pytest.raises(BufferError):
efn(vector, None)


@param_encode_functions
def test_enc_stride_altchars(efn, ecast):
source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"
source_altchars = b"- _"
vector_altchars = memoryview(source_altchars)[::2]
assert not vector_altchars.contiguous
test = ecast(efn(source, vector_altchars))
base = base64.b64encode(source, vector_altchars.tobytes())
assert test == base

0 comments on commit b82ed38

Please sign in to comment.