Skip to content

Commit

Permalink
fix: decoding non C-contiguous buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeut committed Dec 30, 2023
1 parent 7446ef0 commit b3656cf
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 17 deletions.
6 changes: 3 additions & 3 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,19 @@ def remove_extension(session: nox.Session, in_place: bool = False) -> None:
assert removed


@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10"])
@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"])
def test(session: nox.Session) -> None:
"""Run tests."""
session.install("-r", "requirements-test.txt")
# make extension mandatory by exporting CIBUILDWHEEL=1
env = {"CIBUILDWHEEL": "1"}
update_env_macos(session, env)
session.install(".", env=env)
session.run("pytest", env=env)
session.run("pytest", *session.posargs, env=env)
# run without extension as well
env.pop("CIBUILDWHEEL")
remove_extension(session)
session.run("pytest", env=env)
session.run("pytest", *session.posargs, env=env)


@nox.session(python=["3.8", "3.11"])
Expand Down
3 changes: 3 additions & 0 deletions src/pybase64/_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def b64encode(s: Any, altchars: Any = None) -> bytes:
if altchars is not None:
altchars = _get_bytes(altchars)
assert len(altchars) == 2, repr(altchars)
mv = memoryview(s)
if not mv.c_contiguous:
s = mv.tobytes()
return builtin_encode(s, altchars)


Expand Down
52 changes: 43 additions & 9 deletions src/pybase64/_pybase64.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ static uint32_t active_simd_flag = 0U;
static uint32_t simd_flags;

/* returns 0 on success */
static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlphabet)
static int parse_alphabet(PyObject* alphabetObject, int allow_non_contiguous, char* alphabet, int* useAlphabet)
{
Py_buffer buffer;

Expand All @@ -49,7 +49,7 @@ static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlph
Py_INCREF(alphabetObject);
}

if (PyObject_GetBuffer(alphabetObject, &buffer, PyBUF_SIMPLE) < 0) {
if (PyObject_GetBuffer(alphabetObject, &buffer, allow_non_contiguous ? PyBUF_FULL_RO : PyBUF_SIMPLE) < 0) {
Py_DECREF(alphabetObject);
return -1;
}
Expand All @@ -61,9 +61,13 @@ static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlph
return -1;
}

if (PyBuffer_ToContiguous(alphabet, &buffer, 2, 'C') != 0) {
PyBuffer_Release(&buffer);
Py_DECREF(alphabetObject);
return -1;
}

*useAlphabet = 1;
alphabet[0] = ((const char*)buffer.buf)[0];
alphabet[1] = ((const char*)buffer.buf)[1];

if ((alphabet[0] == '+') && (alphabet[1] == '/')) {
*useAlphabet = 0;
Expand Down Expand Up @@ -292,6 +296,31 @@ static int decode_novalidate(const uint8_t *src, size_t srclen, uint8_t *out, si
return 0;
}

/*
 * Fill `buffer` with a C-contiguous view of `input`.
 *
 * This function consumes one reference to `input`: on every failure path
 * that reference is dropped and NULL is returned.
 *
 * On success the returned object is the one `buffer` was acquired from:
 *   - `input` itself when its buffer is already C-contiguous, or
 *   - a freshly allocated bytes object holding a C-ordered copy of the data
 *     (in which case the reference to `input` has already been released).
 * The caller must eventually PyBuffer_Release(buffer) and Py_DECREF the
 * returned object.
 */
static PyObject* get_contiguous_buffer(PyObject* input, Py_buffer* buffer) {
    PyObject* packed;
    int copy_status;

    if (PyObject_GetBuffer(input, buffer, PyBUF_FULL_RO) < 0) {
        Py_DECREF(input);
        return NULL;
    }

    /* Fast path: nothing to repack, hand the original object back. */
    if (PyBuffer_IsContiguous(buffer, 'C')) {
        return input;
    }

    packed = PyBytes_FromStringAndSize(NULL, buffer->len);
    if (packed == NULL) {
        PyBuffer_Release(buffer);
        Py_DECREF(input);
        return NULL;
    }

    /* Copy the data in C order, then let go of the original view/object
     * regardless of whether the copy succeeded. */
    copy_status = PyBuffer_ToContiguous(PyBytes_AS_STRING(packed), buffer, buffer->len, 'C');
    PyBuffer_Release(buffer);
    Py_DECREF(input);

    if (copy_status != 0) {
        Py_DECREF(packed);
        return NULL;
    }

    /* Re-point `buffer` at the contiguous copy. */
    if (PyObject_GetBuffer(packed, buffer, PyBUF_SIMPLE) != 0) {
        Py_DECREF(packed);
        return NULL;
    }
    return packed;
}

static PyObject* pybase64_encode_impl(PyObject* self, PyObject* args, PyObject *kwds, int return_string)
{
static const char *kwlist[] = { "", "altchars", NULL };
Expand All @@ -310,16 +339,19 @@ static PyObject* pybase64_encode_impl(PyObject* self, PyObject* args, PyObject *
return NULL;
}

if (parse_alphabet(in_alphabet, alphabet, &use_alphabet) != 0) {
if (parse_alphabet(in_alphabet, 1, alphabet, &use_alphabet) != 0) {
return NULL;
}

if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
Py_INCREF(in_object);
in_object = get_contiguous_buffer(in_object, &buffer);
if (in_object == NULL) {
return NULL;
}

if (buffer.len > (3 * (PY_SSIZE_T_MAX / 4))) {
PyBuffer_Release(&buffer);
Py_DECREF(in_object);
return PyErr_NoMemory();
}

Expand All @@ -332,6 +364,7 @@ static PyObject* pybase64_encode_impl(PyObject* self, PyObject* args, PyObject *
}
if (out_object == NULL) {
PyBuffer_Release(&buffer);
Py_DECREF(in_object);
return NULL;
}
if (return_string) {
Expand Down Expand Up @@ -375,6 +408,7 @@ static PyObject* pybase64_encode_impl(PyObject* self, PyObject* args, PyObject *
Py_END_ALLOW_THREADS

PyBuffer_Release(&buffer);
Py_DECREF(in_object);

return out_object;
}
Expand Down Expand Up @@ -411,7 +445,7 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject *
return NULL;
}

if (parse_alphabet(in_alphabet, alphabet, &use_alphabet) != 0) {
if (parse_alphabet(in_alphabet, 1, alphabet, &use_alphabet) != 0) {
return NULL;
}

Expand All @@ -434,8 +468,8 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject *
Py_INCREF(in_object);
}
if (source == NULL) {
if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
Py_DECREF(in_object);
in_object = get_contiguous_buffer(in_object, &buffer);
if (in_object == NULL) {
return NULL;
}
source = buffer.buf;
Expand Down
62 changes: 57 additions & 5 deletions tests/test_pybase64.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
raise # pragma: no cover
_has_extension = False


STD = 0
URL = 1
ALT1 = 2
Expand Down Expand Up @@ -156,17 +155,25 @@ def simd_setup(simd_id):
param_encode_functions = pytest.mark.parametrize(
"efn, ecast",
[
(pybase64.b64encode, lambda x: x),
(pybase64.b64encode_as_string, lambda x: x.encode("ascii")),
pytest.param(pybase64.b64encode, lambda x: x, id="b64encode"),
pytest.param(
pybase64.b64encode_as_string,
lambda x: x.encode("ascii"),
id="b64encode_as_string",
),
],
)


param_decode_functions = pytest.mark.parametrize(
"dfn, dcast",
[
(pybase64.b64decode, lambda x: x),
(pybase64.b64decode_as_bytearray, lambda x: bytes(x)),
pytest.param(pybase64.b64decode, lambda x: x, id="b64decode"),
pytest.param(
pybase64.b64decode_as_bytearray,
lambda x: bytes(x),
id="b64decode_as_bytearray",
),
],
)

Expand Down Expand Up @@ -460,3 +467,48 @@ def test_flags(request):
"hsw": 1 | 2 | 4 | 8 | 16 | 32 | 64, # AVX2
"spr": 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, # AVX512VBMI
}[cpu] == runtime_flags


@param_encode_functions
def test_enc_multi_dimensional(efn, ecast):
    """Encode a C-contiguous 2-D memoryview and compare with stdlib base64."""
    raw = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"
    row_count = 4
    # Same bytes, viewed as a (4, len/4) array of unsigned bytes.
    grid = memoryview(raw).cast("B", (row_count, len(raw) // row_count))
    assert grid.c_contiguous
    expected = base64.b64encode(raw)
    actual = ecast(efn(grid, None))
    assert actual == expected


@param_decode_functions
def test_dec_multi_dimensional(dfn, dcast):
    """Decode a C-contiguous 2-D memoryview and compare with stdlib base64."""
    raw = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"
    row_count = 4
    # Same bytes, viewed as a (4, len/4) array of unsigned bytes.
    grid = memoryview(raw).cast("B", (row_count, len(raw) // row_count))
    assert grid.c_contiguous
    expected = base64.b64decode(raw)
    actual = dcast(dfn(grid, None))
    assert actual == expected


@param_validate
@param_decode_functions
def test_dec_stride(dfn, dcast, validate):
    """Decode strided (non-contiguous) data and altchars views.

    Both views take every other byte of their source, so neither is
    contiguous; the expected value is computed by the stdlib from the
    flattened bytes of those same views.
    """
    raw = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"
    raw_altchars = b"- _"
    strided = memoryview(raw)[::2]
    strided_altchars = memoryview(raw_altchars)[::2]
    assert not strided.contiguous and not strided_altchars.contiguous
    expected = base64.b64decode(
        strided.tobytes(), strided_altchars.tobytes(), validate
    )
    actual = dcast(dfn(strided, strided_altchars, validate))
    assert actual == expected


@param_encode_functions
def test_enc_stride(efn, ecast):
    """Encode strided (non-contiguous) data and altchars views.

    Both views take every other byte of their source, so neither is
    contiguous; the expected value is computed by the stdlib from the
    flattened bytes of those same views.
    """
    raw = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"
    raw_altchars = b"- _"
    strided = memoryview(raw)[::2]
    strided_altchars = memoryview(raw_altchars)[::2]
    assert not strided.contiguous and not strided_altchars.contiguous
    expected = base64.b64encode(strided.tobytes(), strided_altchars.tobytes())
    actual = ecast(efn(strided, strided_altchars))
    assert actual == expected

0 comments on commit b3656cf

Please sign in to comment.