Skip to content

Commit

Permalink
fix: do not allow non C-contiguous buffer
Browse files Browse the repository at this point in the history
They were not allowed in the C-extension but given a bug exists in PyPy, they could be accepted and return the wrong result.

Add a workaround for the PyPy bug and disallow non C-contiguous buffer in the pure python fallback implementation.

The `encodebytes` buffer check is also a bit stricter in the C-extension to mimic CPython behavior.
  • Loading branch information
mayeut committed Dec 31, 2023
1 parent 7446ef0 commit 1674a3a
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 67 deletions.
10 changes: 7 additions & 3 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

nox.options.sessions = ["lint", "test"]

ALL_CPYTHON = [f"3.{minor}" for minor in range(6, 12 + 1)]
ALL_PYPY = [f"pypy3.{minor}" for minor in range(8, 10 + 1)]
ALL_PYTHON = ALL_CPYTHON + ALL_PYPY


@nox.session
def lint(session: nox.Session) -> None:
Expand Down Expand Up @@ -46,19 +50,19 @@ def remove_extension(session: nox.Session, in_place: bool = False) -> None:
assert removed


@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10"])
@nox.session(python=ALL_PYTHON)
def test(session: nox.Session) -> None:
"""Run tests."""
session.install("-r", "requirements-test.txt")
# make extension mandatory by exporting CIBUILDWHEEL=1
env = {"CIBUILDWHEEL": "1"}
update_env_macos(session, env)
session.install(".", env=env)
session.run("pytest", env=env)
session.run("pytest", *session.posargs, env=env)
# run without extension as well
env.pop("CIBUILDWHEEL")
remove_extension(session)
session.run("pytest", env=env)
session.run("pytest", *session.posargs, env=env)


@nox.session(python=["3.8", "3.11"])
Expand Down
37 changes: 23 additions & 14 deletions src/pybase64/_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def _get_bytes(s: Any) -> Union[bytes, bytearray]:
if isinstance(s, _bytes_types):
return s
try:
return memoryview(s).tobytes()
mv = memoryview(s)
if not mv.c_contiguous:
raise BufferError("memoryview: underlying buffer is not C-contiguous")
return mv.tobytes()
except TypeError:
raise TypeError(
"argument should be a bytes-like object or ASCII "
Expand Down Expand Up @@ -63,24 +66,25 @@ def b64decode(s: Any, altchars: Any = None, validate: bool = False) -> bytes:
A :exc:`binascii.Error` is raised if ``s`` is incorrectly padded.
"""
s = _get_bytes(s)
if altchars is not None:
altchars = _get_bytes(altchars)
if validate:
if len(s) % 4 != 0:
raise BinAsciiError("Incorrect padding")
s = _get_bytes(s)
if altchars is not None:
altchars = _get_bytes(altchars)
assert len(altchars) == 2, repr(altchars)
map = bytes.maketrans(altchars, b"+/")
s = s.translate(map)
result = builtin_decode(s, altchars, validate=False)

# check length of result vs length of input
padding = 0
if len(s) > 1 and s[-2] in (b"=", 61):
padding = padding + 1
if len(s) > 0 and s[-1] in (b"=", 61):
padding = padding + 1
if 3 * (len(s) / 4) - padding != len(result):
expected_len = 0
if len(s) > 0:
padding = 0
# len(s) % 4 != 0 implies len(s) >= 4 here
if s[-2] == 61: # 61 == ord("=")
padding += 1
if s[-1] == 61:
padding += 1
expected_len = 3 * (len(s) // 4) - padding
if expected_len != len(result):
raise BinAsciiError("Non-base64 digit found")
return result
return builtin_decode(s, altchars, validate=False)
Expand Down Expand Up @@ -122,9 +126,11 @@ def b64encode(s: Any, altchars: Any = None) -> bytes:
The result is returned as a :class:`bytes` object.
"""
mv = memoryview(s)
if not mv.c_contiguous:
raise BufferError("memoryview: underlying buffer is not C-contiguous")
if altchars is not None:
altchars = _get_bytes(altchars)
assert len(altchars) == 2, repr(altchars)
return builtin_encode(s, altchars)


Expand All @@ -151,4 +157,7 @@ def encodebytes(s: Any) -> bytes:
The result is returned as a :class:`bytes` object.
"""
mv = memoryview(s)
if not mv.c_contiguous:
raise BufferError("memoryview: underlying buffer is not C-contiguous")
return builtin_encodebytes(s)
38 changes: 32 additions & 6 deletions src/pybase64/_pybase64.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,25 @@ static int libbase64_simd_flag = 0;
static uint32_t active_simd_flag = 0U;
static uint32_t simd_flags;


/* returns 0 on success */
static int get_buffer(PyObject* object, Py_buffer* buffer)
{
if (PyObject_GetBuffer(object, buffer, PyBUF_RECORDS_RO | PyBUF_C_CONTIGUOUS) != 0) {
return -1;
}
#if defined(PYPY_VERSION)
/* PyPy does not respect PyBUF_C_CONTIGUOUS */
if (!PyBuffer_IsContiguous(buffer, 'C')) {
PyBuffer_Release(buffer);
PyErr_Format(PyExc_BufferError, "%R: underlying buffer is not C-contiguous", Py_TYPE(object));
return -1;
}
#endif
return 0;
}


/* returns 0 on success */
static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlphabet)
{
Expand All @@ -49,7 +68,7 @@ static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlph
Py_INCREF(alphabetObject);
}

if (PyObject_GetBuffer(alphabetObject, &buffer, PyBUF_SIMPLE) < 0) {
if (get_buffer(alphabetObject, &buffer) != 0) {
Py_DECREF(alphabetObject);
return -1;
}
Expand Down Expand Up @@ -314,7 +333,7 @@ static PyObject* pybase64_encode_impl(PyObject* self, PyObject* args, PyObject *
return NULL;
}

if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
if (get_buffer(in_object, &buffer) != 0) {
return NULL;
}

Expand Down Expand Up @@ -434,7 +453,7 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject *
Py_INCREF(in_object);
}
if (source == NULL) {
if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
if (get_buffer(in_object, &buffer) != 0) {
Py_DECREF(in_object);
return NULL;
}
Expand Down Expand Up @@ -467,7 +486,7 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject *
Py_DECREF(in_object);
}
in_object = translate_object;
if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
if (get_buffer(in_object, &buffer) != 0) {
Py_DECREF(in_object);
return NULL;
}
Expand Down Expand Up @@ -605,10 +624,17 @@ static PyObject* pybase64_encodebytes(PyObject* self, PyObject* in_object)
size_t out_len;
PyObject* out_object;

if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
if (get_buffer(in_object, &buffer) != 0) {
return NULL;
}

if (((buffer.format[0] != 'c') && (buffer.format[0] != 'b') && (buffer.format[0] != 'B')) || buffer.format[1] != '\0' ) {
PyBuffer_Release(&buffer);
return PyErr_Format(PyExc_TypeError, "expected single byte elements, not '%s' from %R", buffer.format, Py_TYPE(in_object));
}
if (buffer.ndim != 1) {
PyBuffer_Release(&buffer);
return PyErr_Format(PyExc_TypeError, "expected 1-D data, not %d-D data from %R", buffer.ndim, Py_TYPE(in_object));
}
if (buffer.len > (3 * (PY_SSIZE_T_MAX / 4))) {
PyBuffer_Release(&buffer);
return PyErr_NoMemory();
Expand Down
Loading

0 comments on commit 1674a3a

Please sign in to comment.