Adds StreamOptions class + integration test
marrony committed Mar 8, 2024
1 parent 9eed59a commit c956ed1
Showing 8 changed files with 126 additions and 25 deletions.
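
As a quick orientation before the diff, here is a minimal sketch of how the new stream options are meant to be used from the caller's side. It relies only on names introduced below (StreamOptions, Client.stream, the event "type" field); the collection name, the printed handling, and the assumption that Client() picks up its secret and endpoint from the environment are illustrative.

from datetime import timedelta

from fauna import fql
from fauna.client import Client, StreamOptions

client = Client()  # assumes the secret/endpoint come from the environment

# Give up on an idle stream after 5 seconds, but reconnect automatically.
opts = StreamOptions(idle_timeout=timedelta(seconds=5), retry_on_timeout=True)

with client.stream(fql("Product.all().toStream()"), opts) as events:
    for event in events:
        print(event["type"])  # e.g. "start", "add", "remove"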
2 changes: 1 addition & 1 deletion fauna/client/__init__.py
@@ -1,3 +1,3 @@
from .client import Client, QueryOptions
from .client import Client, QueryOptions, StreamOptions
from .endpoints import Endpoints
from .headers import Header
58 changes: 46 additions & 12 deletions fauna/client/client.py
@@ -8,7 +8,7 @@
from fauna.errors import AuthenticationError, ClientError, ProtocolError, ServiceError, AuthorizationError, \
ServiceInternalError, ServiceTimeoutError, ThrottlingError, QueryTimeoutError, QueryRuntimeError, \
QueryCheckError, ContendedTransactionError, AbortError, InvalidRequestError, RetryableFaunaException, \
NetworkError
NetworkError, StreamTimeout
from fauna.client.headers import _DriverEnvironment, _Header, _Auth, Header
from fauna.http.http_client import HTTPClient
from fauna.query import Query, Page, fql
@@ -51,6 +51,19 @@ class QueryOptions:
additional_headers: Optional[Dict[str, str]] = None


@dataclass
class StreamOptions:
"""
A dataclass representing options available for a stream.
  * idle_timeout - The maximum amount of time the driver will wait on an idle stream before timing out.
  * retry_on_timeout - If true, the stream is automatically reconnected after a timeout.
"""

idle_timeout: Optional[timedelta] = None
retry_on_timeout: bool = True


class Client:

def __init__(
@@ -278,7 +291,7 @@ def query(
except Exception as e:
raise ClientError("Failed to encode Query") from e

retryable = Retryable(
retryable = Retryable[QuerySuccess](
self._max_attempts,
self._max_backoff,
self._query,
@@ -378,7 +391,11 @@ def _query(
schema_version=schema_version,
)

def stream(self, fql: Union[StreamToken, Query]) -> "StreamIterator":
def stream(
self,
fql: Union[StreamToken, Query],
opts: StreamOptions = StreamOptions()
) -> "StreamIterator":
if isinstance(fql, Query):
token = self.query(fql).data
else:
@@ -393,7 +410,7 @@ def stream(self, fql: Union[StreamToken, Query]) -> "StreamIterator":
headers[_Header.Authorization] = self._auth.bearer()

return StreamIterator(self._session, headers, self._endpoint + "/stream/1",
self._max_attempts, self._max_backoff, token)
self._max_attempts, self._max_backoff, opts, token)

def _check_protocol(self, response_json: Any, status_code):
# TODO: Logic to validate wire protocol belongs elsewhere.
@@ -602,14 +619,15 @@ def _set_endpoint(self, endpoint):
class StreamIterator:
"""A class that mixes a ContextManager and an Iterator so we can detected retryable errors."""

def __init__(self, http_client: HTTPClient, headers: Dict[str,
str], endpoint: str,
max_attempts: int, max_backoff: int, token: StreamToken):
def __init__(self, http_client: HTTPClient, headers: Dict[str, str],
endpoint: str, max_attempts: int, max_backoff: int,
opts: StreamOptions, token: StreamToken):
self.http_client = http_client
self.headers = headers
self.endpoint = endpoint
self.max_attempts = max_attempts
self.max_backoff = max_backoff
self.opts = opts
self.token = token
self.stream = None
self.last_ts = None
@@ -628,8 +646,8 @@ def __iter__(self):
return self

def __next__(self):
retryable = Retryable(self.max_attempts, self.max_backoff,
self._next_element)
retryable = Retryable[Any](self.max_attempts, self.max_backoff,
self._next_element)
return retryable.run().response

def _next_element(self):
@@ -641,12 +659,27 @@ def _next_element(self):
return event

raise StopIteration
except StreamTimeout as e:
if self.opts.retry_on_timeout:
self._retry_stream(e)
raise StopIteration

except NetworkError as e:
self.ctx = self._create_stream()
self.stream = self.ctx.__enter__()
raise RetryableFaunaException from e
self._retry_stream(e)

def _retry_stream(self, e):
if self.stream is not None:
self.stream.close()
self.ctx = self._create_stream()
self.stream = self.ctx.__enter__()
raise RetryableFaunaException from e

def _create_stream(self):
if self.opts.idle_timeout:
timeout = self.opts.idle_timeout.total_seconds()
else:
timeout = None

data: Dict[str, Any] = {"token": self.token.token}
if self.last_ts is not None:
data["start_ts"] = self.last_ts
@@ -655,6 +688,7 @@ def _create_stream(self):
url=self.endpoint,
headers=self.headers,
data=data,
timeout=timeout,
)

def close(self):
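Putting the StreamIterator changes above together, the caller-visible difference between the two options is roughly the following. This is a sketch, not driver documentation: the collection name is illustrative and `client` is assumed to be a configured Client.

from datetime import timedelta

from fauna import fql
from fauna.client import Client, StreamOptions

client = Client()  # assumes a configured secret/endpoint

# retry_on_timeout=False: an idle timeout ends the iteration (StopIteration),
# so a plain for-loop simply finishes once the stream has been quiet too long.
# retry_on_timeout=True: each timeout reconnects the stream, which then emits
# a fresh "start" event, so the caller decides when to stop.
opts = StreamOptions(idle_timeout=timedelta(seconds=1), retry_on_timeout=True)
with client.stream(fql("Product.all().toStream()"), opts) as events:
    for event in events:
        if event["type"] == "start":
            events.close()  # StreamIterator.close() shuts the underlying stream
            break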
19 changes: 11 additions & 8 deletions fauna/client/retryable.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass
from random import random
from time import sleep
from typing import Callable, Optional
from typing import Callable, Optional, TypeVar, Generic

from fauna.encoding import QuerySuccess
from fauna.errors import RetryableFaunaException, ClientError
@@ -28,15 +28,18 @@ def wait(self) -> float:
return min(backoff, self._max_backoff)


T = TypeVar('T')


@dataclass
class RetryableResponse:
class RetryableResponse(Generic[T]):
attempts: int
response: QuerySuccess
response: T


class Retryable:
class Retryable(Generic[T]):
"""
Retryable is a wrapper class that acts on a Callable that returns a QuerySuccess.
Retryable is a wrapper class that acts on a Callable that returns a T type.
"""
_strategy: RetryStrategy
_error: Optional[Exception]
@@ -45,7 +48,7 @@ def __init__(
self,
max_attempts: int,
max_backoff: int,
func: Callable[..., QuerySuccess],
func: Callable[..., T],
*args,
**kwargs,
):
@@ -56,7 +59,7 @@ def __init__(
self._kwargs = kwargs
self._error = None

def run(self) -> RetryableResponse:
def run(self) -> RetryableResponse[T]:
"""Runs the wrapped function. Retries up to max_attempts if the function throws a RetryableFaunaException. It propagates
the thrown exception if max_attempts is reached or if a non-retryable is thrown.
@@ -70,7 +73,7 @@ def run(self) -> RetryableResponse:
try:
attempt += 1
qs = self._func(*self._args, **self._kwargs)
return RetryableResponse(attempt, qs)
return RetryableResponse[T](attempt, qs)
except RetryableFaunaException as e:
if attempt >= self._max_attempts:
raise e
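
The generics change means Retryable is no longer tied to QuerySuccess, which is what lets the stream iterator above reuse it for event payloads. A minimal sketch of the parameterized usage; fetch_count is an illustrative stand-in, not part of the driver.

from fauna.client.retryable import Retryable, RetryableResponse

def fetch_count() -> int:
    # Stand-in for any callable that may raise RetryableFaunaException.
    return 42

retryable = Retryable[int](3, 20, fetch_count)   # max_attempts=3, max_backoff=20
result: RetryableResponse[int] = retryable.run()
print(result.attempts, result.response)          # -> 1 42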
2 changes: 1 addition & 1 deletion fauna/errors/__init__.py
@@ -3,4 +3,4 @@
from .errors import ProtocolError, ServiceError
from .errors import AuthenticationError, AuthorizationError, QueryCheckError, QueryRuntimeError, \
QueryTimeoutError, ServiceInternalError, ServiceTimeoutError, ThrottlingError, ContendedTransactionError, \
InvalidRequestError, AbortError, RetryableFaunaException
InvalidRequestError, AbortError, RetryableFaunaException, StreamTimeout
5 changes: 5 additions & 0 deletions fauna/errors/errors.py
@@ -19,6 +19,11 @@ class ClientError(FaunaException):
pass


class StreamTimeout(FaunaException):
"""An error representing Straming timeouts."""
pass


class NetworkError(FaunaException):
"""An error representing a failure due to the network.
This indicates Fauna was never reached."""
3 changes: 2 additions & 1 deletion fauna/http/http_client.py
@@ -1,7 +1,7 @@
import abc
import contextlib

from typing import Iterator, Mapping, Any
from typing import Iterator, Mapping, Any, Optional
from dataclasses import dataclass


@@ -69,6 +69,7 @@ def stream(
url: str,
headers: Mapping[str, str],
data: Mapping[str, Any],
timeout: Optional[float],
) -> Iterator[Any]:
pass

8 changes: 6 additions & 2 deletions fauna/http/httpx_client.py
@@ -5,7 +5,7 @@

import httpx

from fauna.errors import ClientError, NetworkError
from fauna.errors import ClientError, NetworkError, StreamTimeout
from fauna.http.http_client import HTTPResponse, HTTPClient


@@ -100,8 +100,10 @@ def stream(
url: str,
headers: Mapping[str, str],
data: Mapping[str, Any],
timeout: Optional[float] = None,
) -> Iterator[Any]:
stream = self._c.stream("POST", url=url, headers=headers, json=data)
stream = self._c.stream(
"POST", url=url, headers=headers, json=data, timeout=timeout)
with stream as response:
yield self._transform(response)

@@ -111,6 +113,8 @@ def _transform(self, response):
yield json.loads(line)
except httpx.StreamError as e:
raise StopIteration
except httpx.ReadTimeout as e:
raise StreamTimeout("Stream timeout") from e
except (httpx.HTTPError, httpx.InvalidURL) as e:
raise NetworkError("Exception re-raised from HTTP request") from e

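Because the abstract HTTPClient.stream() signature gains a timeout parameter, any custom HTTPClient implementation now needs to accept (and ideally honour) it. A sketch of the required shape, assuming the class's other abstract methods are implemented elsewhere; the body here is a placeholder, not driver code.

from typing import Any, Iterator, Mapping, Optional

from fauna.http.http_client import HTTPClient

class MyHTTPClient(HTTPClient):
    # ...other abstract methods elided for brevity...

    def stream(
        self,
        url: str,
        headers: Mapping[str, str],
        data: Mapping[str, Any],
        timeout: Optional[float] = None,  # new parameter added by this commit
    ) -> Iterator[Any]:
        # Forward `timeout` to the transport as a read timeout and raise
        # StreamTimeout when it elapses, so the driver can reconnect.
        raise NotImplementedError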
54 changes: 54 additions & 0 deletions tests/integration/test_stream.py
@@ -0,0 +1,54 @@
import threading
from datetime import timedelta

import pytest

from fauna import fql
from fauna.client import StreamOptions


def test_stream(client, a_collection):

opts = StreamOptions(
idle_timeout=timedelta(seconds=1), retry_on_timeout=False)

def thread_fn():
stream = client.stream(
fql("${coll}.all().toStream()", coll=a_collection), opts)

with stream as iter:
events = [evt["type"] for evt in iter]

assert events == ["start", "add", "remove", "add"]

stream_thread = threading.Thread(target=thread_fn)
stream_thread.start()

id = client.query(fql("${coll}.create({}).id", coll=a_collection)).data
client.query(fql("${coll}.byId(${id})!.delete()", coll=a_collection, id=id))
client.query(fql("${coll}.create({}).id", coll=a_collection))

stream_thread.join()


def test_retry_on_timeout(client, a_collection):

opts = StreamOptions(
idle_timeout=timedelta(seconds=0.1), retry_on_timeout=True)

def thread_fn():
stream = client.stream(
fql("${coll}.all().toStream()", coll=a_collection), opts)

events = []
with stream as iter:
for evt in iter:
events.append(evt["type"])
if len(events) == 4:
iter.close()

assert events == ["start", "start", "start", "start"]

stream_thread = threading.Thread(target=thread_fn)
stream_thread.start()
stream_thread.join()
