diff --git a/fauna/encoding/decoder.py b/fauna/encoding/decoder.py index 5e1d9426..0536be39 100644 --- a/fauna/encoding/decoder.py +++ b/fauna/encoding/decoder.py @@ -1,9 +1,10 @@ +import base64 from typing import Any, List, Union from iso8601 import parse_date from fauna.query.models import Module, DocumentReference, Document, NamedDocument, NamedDocumentReference, Page, \ - NullDocument, StreamToken + NullDocument, StreamToken class FaunaDecoder: @@ -34,6 +35,8 @@ class FaunaDecoder: +--------------------+---------------+ | None | null | +--------------------+---------------+ + | bytearray | @bytes | + +--------------------+---------------+ | *DocumentReference | @ref | +--------------------+---------------+ | *Document | @doc | @@ -62,6 +65,7 @@ def decode(obj: Any): - { "@mod": ... } decodes to a Module - { "@set": ... } decodes to a Page - { "@stream": ... } decodes to a StreamToken + - { "@bytes": ... } decodes to a bytearray :param obj: the object to decode """ @@ -103,6 +107,9 @@ def _decode_dict(dct: dict, escaped: bool): return parse_date(dct["@time"]) if "@date" in dct: return parse_date(dct["@date"]).date() + if "@bytes" in dct: + bts = base64.b64decode(dct["@bytes"]) + return bytearray(bts) if "@doc" in dct: value = dct["@doc"] if isinstance(value, str): diff --git a/fauna/encoding/encoder.py b/fauna/encoding/encoder.py index f2e12b85..bac96e93 100644 --- a/fauna/encoding/encoder.py +++ b/fauna/encoding/encoder.py @@ -1,5 +1,6 @@ +import base64 from datetime import datetime, date -from typing import Any, Optional, List +from typing import Any, Optional, List, Union from fauna.query.models import DocumentReference, Module, Document, NamedDocument, NamedDocumentReference, NullDocument, StreamToken from fauna.query.query_builder import Query, Fragment, LiteralFragment, ValueFragment @@ -46,6 +47,8 @@ class FaunaEncoder: +-------------------------------+---------------+ | None | None | +-------------------------------+---------------+ + | bytes / bytearray | @bytes | + +-------------------------------+---------------+ | *Document | @ref | +-------------------------------+---------------+ | *DocumentReference | @ref | @@ -117,6 +120,10 @@ def from_datetime(obj: datetime): def from_date(obj: date): return {"@date": obj.isoformat()} + @staticmethod + def from_bytes(obj: Union[bytearray, bytes]): + return {"@bytes": base64.b64encode(obj).decode('ascii')} + @staticmethod def from_doc_ref(obj: DocumentReference): return {"@ref": {"id": obj.id, "coll": FaunaEncoder.from_mod(obj.coll)}} @@ -185,6 +192,8 @@ def _encode(o: Any, _markers: Optional[List] = None): return FaunaEncoder.from_datetime(o) elif isinstance(o, date): return FaunaEncoder.from_date(o) + elif isinstance(o, bytearray) or isinstance(o, bytes): + return FaunaEncoder.from_bytes(o) elif isinstance(o, Document): return FaunaEncoder.from_doc_ref(DocumentReference(o.coll, o.id)) elif isinstance(o, NamedDocument): diff --git a/tests/integration/test_data_type_roundtrips.py b/tests/integration/test_data_type_roundtrips.py index 8943cafc..0a09e73c 100644 --- a/tests/integration/test_data_type_roundtrips.py +++ b/tests/integration/test_data_type_roundtrips.py @@ -1,3 +1,4 @@ +import base64 from datetime import datetime, timezone, timedelta from fauna import fql, Document, NamedDocument @@ -79,3 +80,16 @@ def test_named_document_roundtrip(client, a_collection): assert type(test.data) == NamedDocument result = client.query(fql("${doc}", doc=test.data)) assert test.data == result.data + + +def test_bytes_roundtrip(client): + test_str = "This is a test string 🚀 with various characters: !@#$%^&*()_+=-`~[]{}|;:'\",./<>?" + test_bytearray = test_str.encode('utf-8') + test = client.query(fql("${bts}", bts=test_bytearray)) + assert test.data == test_bytearray + assert test.data.decode('utf-8') == test_str + + test_bytes = bytes(test_bytearray) + test = client.query(fql("${bts}", bts=test_bytes)) + assert test.data == test_bytearray + assert test.data.decode('utf-8') == test_str diff --git a/tests/unit/test_encoding.py b/tests/unit/test_encoding.py index e54d29cb..2edcdb43 100644 --- a/tests/unit/test_encoding.py +++ b/tests/unit/test_encoding.py @@ -1,3 +1,4 @@ +import base64 import re from datetime import date, datetime, timezone, timedelta from typing import Any @@ -117,6 +118,25 @@ def test_encode_decode_primitives(subtests): assert test == decoded +def test_encode_bytes(subtests): + test_str = "This is a test string 🚀 with various characters: !@#$%^&*()_+=-`~[]{}|;:'\",./<>?" + test_bytes = test_str.encode('utf-8') + test_b64_bytes = base64.b64encode(test_bytes) + test_b64_str = test_b64_bytes.decode('utf-8') + + with subtests.test(msg="encode/decode bytes"): + encoded = FaunaEncoder.encode(test_bytes) + assert {"@bytes": test_b64_str} == encoded + decoded = FaunaDecoder.decode(encoded) + assert test_bytes == decoded + + with subtests.test(msg="encode/decode bytearray"): + encoded = FaunaEncoder.encode(bytearray(test_bytes)) + assert {"@bytes": test_b64_str} == encoded + decoded = FaunaDecoder.decode(encoded) + assert test_bytes == decoded + + def test_encode_dates_times(subtests): with subtests.test(msg="encode date into @date"): test = date(2023, 2, 28)