Skip to content

Commit

Permalink
fix: do not allow non C-contiguous buffer
Browse files Browse the repository at this point in the history
Non C-contiguous buffers were not supposed to be accepted by the C extension, but because of a bug in PyPy they could be accepted and produce a wrong result.

Add a workaround for the PyPy bug and disallow non C-contiguous buffers in the pure-Python fallback implementation.

The `encodebytes` buffer check in the C extension is also slightly stricter, to mimic CPython's behavior.
  • Loading branch information
mayeut committed Dec 31, 2023
1 parent 7446ef0 commit 386eb91
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 30 deletions.
10 changes: 7 additions & 3 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

nox.options.sessions = ["lint", "test"]

ALL_CPYTHON = [f"3.{minor}" for minor in range(6, 12 + 1)]
ALL_PYPY = [f"pypy3.{minor}" for minor in range(8, 10 + 1)]
ALL_PYTHON = ALL_CPYTHON + ALL_PYPY


@nox.session
def lint(session: nox.Session) -> None:
Expand Down Expand Up @@ -46,19 +50,19 @@ def remove_extension(session: nox.Session, in_place: bool = False) -> None:
assert removed


@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10"])
@nox.session(python=ALL_PYTHON)
def test(session: nox.Session) -> None:
"""Run tests."""
session.install("-r", "requirements-test.txt")
# make extension mandatory by exporting CIBUILDWHEEL=1
env = {"CIBUILDWHEEL": "1"}
update_env_macos(session, env)
session.install(".", env=env)
session.run("pytest", env=env)
session.run("pytest", *session.posargs, env=env)
# run without extension as well
env.pop("CIBUILDWHEEL")
remove_extension(session)
session.run("pytest", env=env)
session.run("pytest", *session.posargs, env=env)


@nox.session(python=["3.8", "3.11"])
Expand Down
37 changes: 23 additions & 14 deletions src/pybase64/_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def _get_bytes(s: Any) -> Union[bytes, bytearray]:
if isinstance(s, _bytes_types):
return s
try:
return memoryview(s).tobytes()
mv = memoryview(s)
if not mv.c_contiguous:
raise BufferError("memoryview: underlying buffer is not C-contiguous")
return mv.tobytes()
except TypeError:
raise TypeError(
"argument should be a bytes-like object or ASCII "
Expand Down Expand Up @@ -63,24 +66,25 @@ def b64decode(s: Any, altchars: Any = None, validate: bool = False) -> bytes:
A :exc:`binascii.Error` is raised if ``s`` is incorrectly padded.
"""
s = _get_bytes(s)
if altchars is not None:
altchars = _get_bytes(altchars)
if validate:
if len(s) % 4 != 0:
raise BinAsciiError("Incorrect padding")
s = _get_bytes(s)
if altchars is not None:
altchars = _get_bytes(altchars)
assert len(altchars) == 2, repr(altchars)
map = bytes.maketrans(altchars, b"+/")
s = s.translate(map)
result = builtin_decode(s, altchars, validate=False)

# check length of result vs length of input
padding = 0
if len(s) > 1 and s[-2] in (b"=", 61):
padding = padding + 1
if len(s) > 0 and s[-1] in (b"=", 61):
padding = padding + 1
if 3 * (len(s) / 4) - padding != len(result):
expected_len = 0
if len(s) > 0:
padding = 0
# len(s) % 4 == 0 and len(s) > 0 imply len(s) >= 4 here
if s[-2] == 61: # 61 == ord("=")
padding += 1
if s[-1] == 61:
padding += 1
expected_len = 3 * (len(s) // 4) - padding
if expected_len != len(result):
raise BinAsciiError("Non-base64 digit found")
return result
return builtin_decode(s, altchars, validate=False)
Expand Down Expand Up @@ -122,9 +126,11 @@ def b64encode(s: Any, altchars: Any = None) -> bytes:
The result is returned as a :class:`bytes` object.
"""
mv = memoryview(s)
if not mv.c_contiguous:
raise BufferError("memoryview: underlying buffer is not C-contiguous")
if altchars is not None:
altchars = _get_bytes(altchars)
assert len(altchars) == 2, repr(altchars)
return builtin_encode(s, altchars)


Expand All @@ -151,4 +157,7 @@ def encodebytes(s: Any) -> bytes:
The result is returned as a :class:`bytes` object.
"""
mv = memoryview(s)
if not mv.c_contiguous:
raise BufferError("memoryview: underlying buffer is not C-contiguous")
return builtin_encodebytes(s)
38 changes: 32 additions & 6 deletions src/pybase64/_pybase64.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,25 @@ static int libbase64_simd_flag = 0;
static uint32_t active_simd_flag = 0U;
static uint32_t simd_flags;


/*
 * Fill `buffer` with a read-only, C-contiguous view of `object`.
 *
 * Returns 0 on success and -1 on failure. On the PyPy failure path the
 * view is released before returning, so the caller only has to release
 * `buffer` after a successful call.
 */
static int get_buffer(PyObject* object, Py_buffer* buffer)
{
    const int flags = PyBUF_RECORDS_RO | PyBUF_C_CONTIGUOUS;
    int status = PyObject_GetBuffer(object, buffer, flags);

#if defined(PYPY_VERSION)
    /* PyPy does not respect PyBUF_C_CONTIGUOUS, so check contiguity by hand */
    if (status == 0 && !PyBuffer_IsContiguous(buffer, 'C')) {
        PyBuffer_Release(buffer);
        PyErr_Format(PyExc_BufferError, "%R: underlying buffer is not C-contiguous", Py_TYPE(object));
        status = -1;
    }
#endif

    return (status == 0) ? 0 : -1;
}


/* returns 0 on success */
static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlphabet)
{
Expand All @@ -49,7 +68,7 @@ static int parse_alphabet(PyObject* alphabetObject, char* alphabet, int* useAlph
Py_INCREF(alphabetObject);
}

if (PyObject_GetBuffer(alphabetObject, &buffer, PyBUF_SIMPLE) < 0) {
if (get_buffer(alphabetObject, &buffer) != 0) {
Py_DECREF(alphabetObject);
return -1;
}
Expand Down Expand Up @@ -314,7 +333,7 @@ static PyObject* pybase64_encode_impl(PyObject* self, PyObject* args, PyObject *
return NULL;
}

if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
if (get_buffer(in_object, &buffer) != 0) {
return NULL;
}

Expand Down Expand Up @@ -434,7 +453,7 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject *
Py_INCREF(in_object);
}
if (source == NULL) {
if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
if (get_buffer(in_object, &buffer) != 0) {
Py_DECREF(in_object);
return NULL;
}
Expand Down Expand Up @@ -467,7 +486,7 @@ static PyObject* pybase64_decode_impl(PyObject* self, PyObject* args, PyObject *
Py_DECREF(in_object);
}
in_object = translate_object;
if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
if (get_buffer(in_object, &buffer) != 0) {
Py_DECREF(in_object);
return NULL;
}
Expand Down Expand Up @@ -605,10 +624,17 @@ static PyObject* pybase64_encodebytes(PyObject* self, PyObject* in_object)
size_t out_len;
PyObject* out_object;

if (PyObject_GetBuffer(in_object, &buffer, PyBUF_SIMPLE) < 0) {
if (get_buffer(in_object, &buffer) != 0) {
return NULL;
}

if (((buffer.format[0] != 'c') && (buffer.format[0] != 'b') && (buffer.format[0] != 'B')) || buffer.format[1] != '\0' ) {
PyBuffer_Release(&buffer);
return PyErr_Format(PyExc_TypeError, "expected single byte elements, not '%s' from %R", buffer.format, Py_TYPE(in_object));
}
if (buffer.ndim != 1) {
PyBuffer_Release(&buffer);
return PyErr_Format(PyExc_TypeError, "expected 1-D data, not %d-D data from %R", buffer.ndim, Py_TYPE(in_object));
}
if (buffer.len > (3 * (PY_SSIZE_T_MAX / 4))) {
PyBuffer_Release(&buffer);
return PyErr_NoMemory();
Expand Down
71 changes: 64 additions & 7 deletions tests/test_pybase64.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,25 @@ def simd_setup(simd_id):
param_encode_functions = pytest.mark.parametrize(
"efn, ecast",
[
(pybase64.b64encode, lambda x: x),
(pybase64.b64encode_as_string, lambda x: x.encode("ascii")),
pytest.param(pybase64.b64encode, lambda x: x, id="b64encode"),
pytest.param(
pybase64.b64encode_as_string,
lambda x: x.encode("ascii"),
id="b64encode_as_string",
),
],
)


param_decode_functions = pytest.mark.parametrize(
"dfn, dcast",
[
(pybase64.b64decode, lambda x: x),
(pybase64.b64decode_as_bytearray, lambda x: bytes(x)),
pytest.param(pybase64.b64decode, lambda x: x, id="b64decode"),
pytest.param(
pybase64.b64decode_as_bytearray,
lambda x: bytes(x),
id="b64decode_as_bytearray",
),
],
)

Expand Down Expand Up @@ -337,6 +345,7 @@ def test_invalid_padding_dec(dfn, dcast, altchars_id, vector_id, validate, simd)
[b"-__", AssertionError],
[3.0, TypeError],
["-€", ValueError],
[memoryview(b"- _")[::2], BufferError],
]
params_invalid_altchars = pytest.mark.parametrize(
"altchars,exception",
Expand Down Expand Up @@ -373,6 +382,7 @@ def test_invalid_altchars_dec_validate(dfn, dcast, altchars, exception, simd):
[b"A@@@@FG", None, BinAsciiError],
["ABC€", None, ValueError],
[3.0, None, TypeError],
[memoryview(b"ABCDEFGH")[::2], None, BufferError],
]
params_invalid_data_validate = [
[b"\x00\x00\x00\x00", None, BinAsciiError],
Expand Down Expand Up @@ -431,10 +441,37 @@ def test_invalid_data_dec_validate(dfn, dcast, vector, altchars, exception, simd
dfn(vector, altchars, True)


params_invalid_data_enc = [
["this is a test", TypeError],
[memoryview(b"abcd")[::2], BufferError],
]
params_invalid_data_encodebytes = params_invalid_data_enc + [
[memoryview(b"abcd").cast("B", (2, 2)), TypeError],
[memoryview(b"abcd").cast("I"), TypeError],
]
params_invalid_data_enc = pytest.mark.parametrize(
"vector,exception",
params_invalid_data_enc,
ids=[str(i) for i in range(len(params_invalid_data_enc))],
)
params_invalid_data_encodebytes = pytest.mark.parametrize(
"vector,exception",
params_invalid_data_encodebytes,
ids=[str(i) for i in range(len(params_invalid_data_encodebytes))],
)


@params_invalid_data_enc
@param_encode_functions
def test_invalid_data_enc_0(efn, ecast):
with pytest.raises(TypeError):
efn("this is a test")
def test_invalid_data_enc(efn, ecast, vector, exception):
with pytest.raises(exception):
efn(vector)


@params_invalid_data_encodebytes
def test_invalid_data_encodebytes(vector, exception):
with pytest.raises(exception):
pybase64.encodebytes(vector)


@param_encode_functions
Expand All @@ -460,3 +497,23 @@ def test_flags(request):
"hsw": 1 | 2 | 4 | 8 | 16 | 32 | 64, # AVX2
"spr": 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128, # AVX512VBMI
}[cpu] == runtime_flags


@param_encode_functions
def test_enc_multi_dimensional(efn, ecast):
    """A C-contiguous 2-D view must encode exactly like its flat bytes."""
    raw = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"
    rows = 4
    view_2d = memoryview(raw).cast("B", (rows, len(raw) // rows))
    # Multi-dimensional is fine as long as the memory layout is C-contiguous.
    assert view_2d.c_contiguous
    expected = base64.b64encode(raw)
    assert ecast(efn(view_2d, None)) == expected


@param_decode_functions
def test_dec_multi_dimensional(dfn, dcast):
    """A C-contiguous 2-D view must decode exactly like its flat bytes."""
    raw = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"
    rows = 4
    view_2d = memoryview(raw).cast("B", (rows, len(raw) // rows))
    # Multi-dimensional is fine as long as the memory layout is C-contiguous.
    assert view_2d.c_contiguous
    expected = base64.b64decode(raw)
    assert dcast(dfn(view_2d, None)) == expected

0 comments on commit 386eb91

Please sign in to comment.