Skip to content

Commit

Permalink
feat: implement async rest interceptor class
Browse files Browse the repository at this point in the history
  • Loading branch information
ohmayr committed Sep 23, 2024
1 parent f1c87fc commit 0d00ebc
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,7 @@ def _get_http_options():
{% set async_class_prefix = "Async" if is_async else "" %}

http_options = _Base{{ service_name }}RestTransport._Base{{method_name}}._get_http_options()
{% if not is_async %}
{# TODO (ohmayr): Make this unconditional once REST interceptors are supported for async. Googlers,
see internal tracking issue: b/362949568. #}
request, metadata = self._interceptor.pre_{{ method_name|snake_case }}(request, metadata)
{% endif %}
request, metadata = {{ await_prefix }}self._interceptor.pre_{{ method_name|snake_case }}(request, metadata)
transcoded_request = _Base{{ service_name }}RestTransport._Base{{method_name}}._get_transcoded_request(http_options, request)

{% if body_spec %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ from .rest import {{ service.name }}RestInterceptor
ASYNC_REST_CLASSES: Tuple[str, ...]
try:
from .rest_asyncio import Async{{ service.name }}RestTransport
ASYNC_REST_CLASSES = ('Async{{ service.name }}RestTransport',)
from .rest_asyncio import Async{{ service.name }}RestInterceptor
ASYNC_REST_CLASSES = ('Async{{ service.name }}RestTransport', 'Async{{ service.name }}RestInterceptor')
HAS_REST_ASYNC = True
except ImportError: # pragma: NO COVER
ASYNC_REST_CLASSES = ()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ except ImportError as e: # pragma: NO COVER
raise ImportError("async rest transport requires google-api-core >= 2.20.0. Install google-api-core using `pip install google-api-core==2.35.0`.") from e

from google.protobuf import json_format
{% if service.has_lro %}
from google.api_core import operations_v1
{% endif %}
{% if opts.add_iam_methods or api.has_iam_mixin %}
from google.iam.v1 import iam_policy_pb2 # type: ignore
from google.iam.v1 import policy_pb2 # type: ignore
{% endif %}
{% if api.has_location_mixin %}
from google.cloud.location import locations_pb2 # type: ignore
{% endif %}

import json # type: ignore
import dataclasses
Expand All @@ -55,11 +65,13 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
rest_version=google.auth.__version__
)

{# TODO: Add an `_interceptor` property once implemented #}
{{ shared_macros.create_interceptor_class(api, service, method, is_async=True) }}

@dataclasses.dataclass
class Async{{service.name}}RestStub:
_session: AsyncAuthorizedSession
_host: str
_interceptor: Async{{service.name}}RestInterceptor

class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
"""Asynchronous REST backend transport for {{ service.name }}.
Expand All @@ -78,6 +90,7 @@ class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
credentials: Optional[ga_credentials_async.Credentials] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
url_scheme: str = 'https',
interceptor: Optional[Async{{ service.name }}RestInterceptor] = None,
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -118,6 +131,7 @@ class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
# we update the type hints for credentials to include asynchronous credentials in the client layer.
#}
self._session = AsyncAuthorizedSession(self._credentials) # type: ignore
self._interceptor = interceptor or Async{{ service.name }}RestInterceptor()
self._wrap_with_kind = True
self._prep_wrapped_messages(client_info)

Expand Down Expand Up @@ -190,6 +204,7 @@ class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
content = await response.read()
json_format.Parse(content, pb_resp, ignore_unknown_fields=True)
{% endif %}{# if method.server_streaming #}
resp = await self._interceptor.post_{{ method.name|snake_case }}(resp)
return resp

{% endif %}{# method.void #}
Expand All @@ -207,7 +222,7 @@ class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
def {{method.transport_safe_name|snake_case}}(self) -> Callable[
[{{method.input.ident}}],
{{method.output.ident}}]:
return self._{{method.name}}(self._session, self._host) # type: ignore
return self._{{method.name}}(self._session, self._host, self._interceptor) # type: ignore

{% endfor %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2264,7 +2264,6 @@ def test_initialize_client_w_{{transport_name}}():

{# inteceptor_class_test generates tests for rest interceptors. #}
{% macro inteceptor_class_test(service, method, transport, is_async) %}
{% if not is_async %}{# TODO: Remove this guard once support for async rest interceptors is added. #}
{% set await_prefix = get_await_prefix(is_async) %}
{% set async_prefix = get_async_prefix(is_async) %}
{% set async_decorator = get_async_decorator(is_async) %}
Expand All @@ -2277,6 +2276,15 @@ def test_initialize_client_w_{{transport_name}}():
{% if 'grpc' in transport %}
raise NotImplementedError("gRPC is currently not supported for this test case.")
{% else %}{# 'rest' in transport #}
{% if transport_name == 'rest_asyncio' %}
if not HAS_GOOGLE_AUTH_AIO:
pytest.skip("google-auth >= 2.35.0 is required for async rest transport.")
elif not HAS_AIOHTTP_INSTALLED:
pytest.skip("aiohttp is required for async rest transport.")
elif not HAS_ASYNC_REST_SUPPORT_IN_CORE:
pytest.skip("google-api-core >= 2.20.0 is required for async rest transport.")

{% endif %}
transport = transports.{{async_method_prefix}}{{ service.name }}RestTransport(
credentials={{get_credentials(is_async)}},
interceptor=None if null_interceptor else transports.{{async_method_prefix}}{{ service.name}}RestInterceptor(),
Expand Down Expand Up @@ -2345,5 +2353,4 @@ def test_initialize_client_w_{{transport_name}}():
post.assert_called_once()
{% endif %}
{% endif %}{# end 'grpc' in transport #}
{% endif %}{# end not is_async #}
{% endmacro%}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
ASYNC_REST_CLASSES: Tuple[str, ...]
try:
from .rest_asyncio import AsyncCloudRedisRestTransport
ASYNC_REST_CLASSES = ('AsyncCloudRedisRestTransport',)
from .rest_asyncio import AsyncCloudRedisRestInterceptor
ASYNC_REST_CLASSES = ('AsyncCloudRedisRestTransport', 'AsyncCloudRedisRestInterceptor')
HAS_REST_ASYNC = True
except ImportError: # pragma: NO COVER
ASYNC_REST_CLASSES = ()
Expand Down
Loading

0 comments on commit 0d00ebc

Please sign in to comment.