Skip to content

Commit

Permalink
Add support for @bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
pnwpedro committed May 17, 2024
1 parent 7270ac8 commit 5bf577e
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 1 deletion.
9 changes: 8 additions & 1 deletion fauna/encoding/decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import io
from typing import Any, List, Union

from iso8601 import parse_date

from fauna.query.models import Module, DocumentReference, Document, NamedDocument, NamedDocumentReference, Page, \
Expand Down Expand Up @@ -34,6 +35,8 @@ class FaunaDecoder:
+--------------------+---------------+
| None | null |
+--------------------+---------------+
| bytearray | @bytes |
+--------------------+---------------+
| *DocumentReference | @ref |
+--------------------+---------------+
| *Document | @doc |
Expand Down Expand Up @@ -62,6 +65,7 @@ def decode(obj: Any):
- { "@mod": ... } decodes to a Module
- { "@set": ... } decodes to a Page
- { "@stream": ... } decodes to a StreamToken
- { "@bytes": ... } decodes to a bytearray
:param obj: the object to decode
"""
Expand Down Expand Up @@ -103,6 +107,9 @@ def _decode_dict(dct: dict, escaped: bool):
return parse_date(dct["@time"])
if "@date" in dct:
return parse_date(dct["@date"]).date()
if "@bytes" in dct:
bts = base64.b64decode(dct["@bytes"])
return bytearray(bts)
if "@doc" in dct:
value = dct["@doc"]
if isinstance(value, str):
Expand Down
9 changes: 9 additions & 0 deletions fauna/encoding/encoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from datetime import datetime, date
from typing import Any, Optional, List

Expand Down Expand Up @@ -46,6 +47,8 @@ class FaunaEncoder:
+-------------------------------+---------------+
| None | None |
+-------------------------------+---------------+
| bytes / bytearray | @bytes |
+-------------------------------+---------------+
| *Document | @ref |
+-------------------------------+---------------+
| *DocumentReference | @ref |
Expand Down Expand Up @@ -117,6 +120,10 @@ def from_datetime(obj: datetime):
def from_date(obj: date):
return {"@date": obj.isoformat()}

@staticmethod
def from_bytes(obj: bytearray | bytes):
return {"@bytes": base64.b64encode(obj).decode('ascii')}

@staticmethod
def from_doc_ref(obj: DocumentReference):
return {"@ref": {"id": obj.id, "coll": FaunaEncoder.from_mod(obj.coll)}}
Expand Down Expand Up @@ -185,6 +192,8 @@ def _encode(o: Any, _markers: Optional[List] = None):
return FaunaEncoder.from_datetime(o)
elif isinstance(o, date):
return FaunaEncoder.from_date(o)
elif isinstance(o, bytearray) or isinstance(o, bytes):
return FaunaEncoder.from_bytes(o)
elif isinstance(o, Document):
return FaunaEncoder.from_doc_ref(DocumentReference(o.coll, o.id))
elif isinstance(o, NamedDocument):
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/test_data_type_roundtrips.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from datetime import datetime, timezone, timedelta

from fauna import fql, Document, NamedDocument
Expand Down Expand Up @@ -79,3 +80,16 @@ def test_named_document_roundtrip(client, a_collection):
assert type(test.data) == NamedDocument
result = client.query(fql("${doc}", doc=test.data))
assert test.data == result.data


def test_bytes_roundtrip(client):
test_str = "This is a test string 🚀 with various characters: !@#$%^&*()_+=-`~[]{}|;:'\",./<>?"
test_bytearray = test_str.encode('utf-8')
test = client.query(fql("${bts}", bts=test_bytearray))
assert test.data == test_bytearray
assert test.data.decode('utf-8') == test_str

test_bytes = bytes(test_bytearray)
test = client.query(fql("${bts}", bts=test_bytes))
assert test.data == test_bytearray
assert test.data.decode('utf-8') == test_str
20 changes: 20 additions & 0 deletions tests/unit/test_encoding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import re
from datetime import date, datetime, timezone, timedelta
from typing import Any
Expand Down Expand Up @@ -117,6 +118,25 @@ def test_encode_decode_primitives(subtests):
assert test == decoded


def test_encode_bytes(subtests):
test_str = "This is a test string 🚀 with various characters: !@#$%^&*()_+=-`~[]{}|;:'\",./<>?"
test_bytes = test_str.encode('utf-8')
test_b64_bytes = base64.b64encode(test_bytes)
test_b64_str = test_b64_bytes.decode('utf-8')

with subtests.test(msg="encode/decode bytes"):
encoded = FaunaEncoder.encode(test_bytes)
assert {"@bytes": test_b64_str} == encoded
decoded = FaunaDecoder.decode(encoded)
assert test_bytes == decoded

with subtests.test(msg="encode/decode bytearray"):
encoded = FaunaEncoder.encode(bytearray(test_bytes))
assert {"@bytes": test_b64_str} == encoded
decoded = FaunaDecoder.decode(encoded)
assert test_bytes == decoded


def test_encode_dates_times(subtests):
with subtests.test(msg="encode date into @date"):
test = date(2023, 2, 28)
Expand Down

0 comments on commit 5bf577e

Please sign in to comment.