Skip to content

Commit

Permalink
Initial work to support streams
Browse files Browse the repository at this point in the history
  • Loading branch information
marrony committed Feb 29, 2024
1 parent 196bf41 commit 42f0dee
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 6 deletions.
23 changes: 23 additions & 0 deletions fauna/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fauna.client.headers import _DriverEnvironment, _Header, _Auth, Header
from fauna.http.http_client import HTTPClient
from fauna.query import Query, Page, fql
from fauna.query.models import StreamToken
from fauna.client.utils import _Environment, LastTxnTs
from fauna.encoding import FaunaEncoder, FaunaDecoder
from fauna.encoding import QuerySuccess, ConstraintFailure, QueryTags, QueryStats
Expand Down Expand Up @@ -375,6 +376,28 @@ def _query(
schema_version=schema_version,
)

def stream(self, token: StreamToken):
# todo: pass a token or a Query

if not isinstance(token, StreamToken):
err_msg = f"'token' must be a StreamToken but was a {type(token)}."
raise TypeError(err_msg)

headers = self._headers.copy()
headers[_Header.Format] = "tagged"
headers[_Header.Authorization] = self._auth.bearer()

data = {"token": token.token}

response = self._session.stream(
url=self._endpoint + "/stream/1",
headers=headers,
data=data,
)

for line in response:
yield FaunaDecoder.decode(line)

def _check_protocol(self, response_json: Any, status_code):
# TODO: Logic to validate wire protocol belongs elsewhere.
should_raise = False
Expand Down
8 changes: 7 additions & 1 deletion fauna/encoding/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from iso8601 import parse_date

from fauna.query.models import Module, DocumentReference, Document, NamedDocument, NamedDocumentReference, Page, \
NullDocument
NullDocument, StreamToken


class FaunaDecoder:
Expand Down Expand Up @@ -42,6 +42,8 @@ class FaunaDecoder:
+--------------------+---------------+
| Page | @set |
+--------------------+---------------+
| StreamToken | @stream |
+--------------------+---------------+
"""

Expand All @@ -59,6 +61,7 @@ def decode(obj: Any):
- { "@ref": ... } decodes to a DocumentReference or NamedDocumentReference
- { "@mod": ... } decodes to a Module
- { "@set": ... } decodes to a Page
- { "@stream": ... } decodes to a StreamToken
:param obj: the object to decode
"""
Expand Down Expand Up @@ -165,4 +168,7 @@ def _decode_dict(dct: dict, escaped: bool):

return Page(data=data, after=after)

if "@stream" in dct:
return StreamToken(dct["@stream"])

return {k: FaunaDecoder._decode(v) for k, v in dct.items()}
11 changes: 10 additions & 1 deletion fauna/encoding/encoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime, date
from typing import Any, Optional, Set

from fauna.query.models import DocumentReference, Module, Document, NamedDocument, NamedDocumentReference, NullDocument
from fauna.query.models import DocumentReference, Module, Document, NamedDocument, NamedDocumentReference, NullDocument, StreamToken
from fauna.query.query_builder import Query, Fragment, LiteralFragment, ValueFragment

_RESERVED_TAGS = [
Expand Down Expand Up @@ -58,6 +58,8 @@ class FaunaEncoder:
+-------------------------------+---------------+
| TemplateFragment | string |
+-------------------------------+---------------+
| StreamToken | string |
+-------------------------------+---------------+
"""

Expand All @@ -76,6 +78,7 @@ def encode(obj: Any) -> Any:
- Query encodes to { "fql": [...] }
- ValueFragment encodes to { "value": <encoded_val> }
- LiteralFragment encodes to a string
- StreamToken encodes to a string
:raises ValueError: If value cannot be encoded, cannot be encoded safely, or there's a circular reference.
:param obj: the object to decode
Expand Down Expand Up @@ -151,6 +154,10 @@ def from_fragment(obj: Fragment):
def from_query_interpolation_builder(obj: Query):
return {"fql": [FaunaEncoder.from_fragment(f) for f in obj.fragments]}

@staticmethod
def from_streamtoken(obj: StreamToken):
return {"@stream": obj.token}

@staticmethod
def _encode(o: Any, _markers: Optional[Set] = None):
if _markers is None:
Expand Down Expand Up @@ -191,6 +198,8 @@ def _encode(o: Any, _markers: Optional[Set] = None):
return FaunaEncoder._encode_dict(o, _markers)
elif isinstance(o, Query):
return FaunaEncoder.from_query_interpolation_builder(o)
elif isinstance(o, StreamToken):
return FaunaEncoder.from_streamtoken(o)
else:
raise ValueError(f"Object {o} of type {type(o)} cannot be encoded")

Expand Down
7 changes: 5 additions & 2 deletions fauna/http/httpx_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,16 @@ def _send_with_retry(
else:
return self._send_with_retry(retryCount - 1, request)

# todo: decorate with context manager
def stream(
self,
url: str,
headers: Mapping[str, str],
data: Mapping[str, Any],
) -> Iterator[HTTPResponse]:
raise NotImplementedError()
) -> Iterator[Any]:
with self._c.stream("POST", url=url, headers=headers, json=data) as r:
for line in r.iter_lines():
yield json.loads(line)

def close(self):
self._c.close()
13 changes: 13 additions & 0 deletions fauna/query/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ def __ne__(self, other):
return not self.__eq__(other)


class StreamToken:
"""A class represeting a Stream in Fauna."""

def __init__(self, token: str = None):
self.token = token

def __eq__(self, other):
return isinstance(other, StreamToken) and self.token == other.token

def __hash__(self):
hash(self.token)


class Module:
"""A class representing a Module in Fauna. Examples of modules include Collection, Math, and a user-defined
collection, among others.
Expand Down
18 changes: 17 additions & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
from datetime import timedelta
from typing import Dict

import httpx
import pytest
import pytest_subtests
from pytest_httpx import HTTPXMock
from pytest_httpx import HTTPXMock, IteratorStream

import fauna
from fauna import fql
from fauna.client import Client, Header, QueryOptions, Endpoints
from fauna.errors import QueryCheckError, ProtocolError, QueryRuntimeError
from fauna.query.models import StreamToken
from fauna.http import HTTPXClient


Expand Down Expand Up @@ -413,3 +415,17 @@ def test_call_query_with_string():
match="'fql' must be a Query but was a <class 'str'>. You can build a Query by "
"calling fauna.fql()"):
c.query("fake") # type: ignore


def test_client_stream(subtests, httpx_mock: HTTPXMock):
response = ['{"@int": "10"}\n', '{"@long": "20"}\n']

httpx_mock.add_response(
stream=IteratorStream([bytes(r, 'utf-8') for r in response]))

with httpx.Client() as mockClient:
http_client = HTTPXClient(mockClient)
c = Client(http_client=http_client)
ret = [obj for obj in c.stream(StreamToken("token"))]

assert ret == [10, 20]
16 changes: 15 additions & 1 deletion tests/unit/test_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fauna import fql
from fauna.encoding import FaunaEncoder, FaunaDecoder
from fauna.query.models import DocumentReference, NamedDocumentReference, Document, NamedDocument, Module, Page, \
NullDocument
NullDocument, StreamToken

fixed_datetime = datetime.fromisoformat("2023-03-17T00:00:00+00:00")

Expand Down Expand Up @@ -755,3 +755,17 @@ def test_encode_query_builder_sub_queries(subtests):
}

assert expected == actual


def test_decode_stream(subtests):
with subtests.test(msg="decode @stream into StreamToken"):
test = {"@stream": "asdflkj"}
decoded = FaunaDecoder.decode(test)
assert decoded == StreamToken("asdflkj")


def test_encode_stream(subtests):
with subtests.test(msg="encode StreamToken into @stream"):
test = {"@stream": "asdflkj"}
encoded = FaunaEncoder.encode(StreamToken("asdflkj"))
assert encoded == test
23 changes: 23 additions & 0 deletions tests/unit/test_httpx_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import json

import httpx
from pytest_httpx import HTTPXMock, IteratorStream

from fauna.client import Client
from fauna.http import HTTPXClient


def test_httx_client_stream(subtests, httpx_mock: HTTPXMock):
expected = [{"@int": "10"}, {"@long": "20"}]

def to_json_bytes(obj):
return bytes(json.dumps(obj) + "\n", "utf-8")

httpx_mock.add_response(
stream=IteratorStream([to_json_bytes(obj) for obj in expected]))

with httpx.Client() as mockClient:
http_client = HTTPXClient(mockClient)
ret = [obj for obj in http_client.stream("http://localhost:8443", {}, {})]

assert ret == expected

0 comments on commit 42f0dee

Please sign in to comment.