Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for @bytes #179

Merged
merged 16 commits into from
May 20, 2024
9 changes: 8 additions & 1 deletion fauna/encoding/decoder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import base64
from typing import Any, List, Union

from iso8601 import parse_date

from fauna.query.models import Module, DocumentReference, Document, NamedDocument, NamedDocumentReference, Page, \
NullDocument, StreamToken
NullDocument, StreamToken


class FaunaDecoder:
Expand Down Expand Up @@ -34,6 +35,8 @@ class FaunaDecoder:
+--------------------+---------------+
| None | null |
+--------------------+---------------+
| bytearray | @bytes |
+--------------------+---------------+
| *DocumentReference | @ref |
+--------------------+---------------+
| *Document | @doc |
Expand Down Expand Up @@ -62,6 +65,7 @@ def decode(obj: Any):
- { "@mod": ... } decodes to a Module
- { "@set": ... } decodes to a Page
- { "@stream": ... } decodes to a StreamToken
- { "@bytes": ... } decodes to a bytearray

:param obj: the object to decode
"""
Expand Down Expand Up @@ -103,6 +107,9 @@ def _decode_dict(dct: dict, escaped: bool):
return parse_date(dct["@time"])
if "@date" in dct:
return parse_date(dct["@date"]).date()
if "@bytes" in dct:
bts = base64.b64decode(dct["@bytes"])
return bytearray(bts)
if "@doc" in dct:
value = dct["@doc"]
if isinstance(value, str):
Expand Down
11 changes: 10 additions & 1 deletion fauna/encoding/encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
from datetime import datetime, date
from typing import Any, Optional, List
from typing import Any, Optional, List, Union

from fauna.query.models import DocumentReference, Module, Document, NamedDocument, NamedDocumentReference, NullDocument, StreamToken
from fauna.query.query_builder import Query, Fragment, LiteralFragment, ValueFragment
Expand Down Expand Up @@ -46,6 +47,8 @@ class FaunaEncoder:
+-------------------------------+---------------+
| None | None |
+-------------------------------+---------------+
| bytes / bytearray | @bytes |
+-------------------------------+---------------+
| *Document | @ref |
+-------------------------------+---------------+
| *DocumentReference | @ref |
Expand Down Expand Up @@ -117,6 +120,10 @@ def from_datetime(obj: datetime):
def from_date(obj: date):
return {"@date": obj.isoformat()}

@staticmethod
def from_bytes(obj: Union[bytearray, bytes]):
return {"@bytes": base64.b64encode(obj).decode('ascii')}

@staticmethod
def from_doc_ref(obj: DocumentReference):
return {"@ref": {"id": obj.id, "coll": FaunaEncoder.from_mod(obj.coll)}}
Expand Down Expand Up @@ -185,6 +192,8 @@ def _encode(o: Any, _markers: Optional[List] = None):
return FaunaEncoder.from_datetime(o)
elif isinstance(o, date):
return FaunaEncoder.from_date(o)
elif isinstance(o, bytearray) or isinstance(o, bytes):
return FaunaEncoder.from_bytes(o)
elif isinstance(o, Document):
return FaunaEncoder.from_doc_ref(DocumentReference(o.coll, o.id))
elif isinstance(o, NamedDocument):
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/test_data_type_roundtrips.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from datetime import datetime, timezone, timedelta

from fauna import fql, Document, NamedDocument
Expand Down Expand Up @@ -79,3 +80,16 @@ def test_named_document_roundtrip(client, a_collection):
assert type(test.data) == NamedDocument
result = client.query(fql("${doc}", doc=test.data))
assert test.data == result.data


def test_bytes_roundtrip(client):
test_str = "This is a test string 🚀 with various characters: !@#$%^&*()_+=-`~[]{}|;:'\",./<>?"
test_bytearray = test_str.encode('utf-8')
test = client.query(fql("${bts}", bts=test_bytearray))
assert test.data == test_bytearray
assert test.data.decode('utf-8') == test_str

test_bytes = bytes(test_bytearray)
test = client.query(fql("${bts}", bts=test_bytes))
assert test.data == test_bytearray
assert test.data.decode('utf-8') == test_str
20 changes: 20 additions & 0 deletions tests/unit/test_encoding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import re
from datetime import date, datetime, timezone, timedelta
from typing import Any
Expand Down Expand Up @@ -117,6 +118,25 @@ def test_encode_decode_primitives(subtests):
assert test == decoded


def test_encode_bytes(subtests):
test_str = "This is a test string 🚀 with various characters: !@#$%^&*()_+=-`~[]{}|;:'\",./<>?"
test_bytes = test_str.encode('utf-8')
test_b64_bytes = base64.b64encode(test_bytes)
test_b64_str = test_b64_bytes.decode('utf-8')

with subtests.test(msg="encode/decode bytes"):
encoded = FaunaEncoder.encode(test_bytes)
assert {"@bytes": test_b64_str} == encoded
decoded = FaunaDecoder.decode(encoded)
assert test_bytes == decoded

with subtests.test(msg="encode/decode bytearray"):
encoded = FaunaEncoder.encode(bytearray(test_bytes))
assert {"@bytes": test_b64_str} == encoded
decoded = FaunaDecoder.decode(encoded)
assert test_bytes == decoded


def test_encode_dates_times(subtests):
with subtests.test(msg="encode date into @date"):
test = date(2023, 2, 28)
Expand Down
Loading