Skip to content

Commit

Permalink
chore: refactor mixins call method
Browse files Browse the repository at this point in the history
  • Loading branch information
ohmayr committed Sep 23, 2024
1 parent d49df24 commit abdba51
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,61 @@ class {{ async_method_name_prefix }}{{ service.name }}RestInterceptor:
{% endfor %}
{% endif %}
{% endmacro %}

{% macro generate_mixin_call_method(service, api, name, sig, is_async) %}
{% set async_prefix = "async " if is_async else "" %}
{% set async_method_name_prefix = "Async" if is_async else "" %}
{% set async_suffix = "_async" if is_async else "" %}
{% set await_prefix = "await " if is_async else "" %}

@property
def {{ name|snake_case }}(self):
return self._{{ name }}(self._session, self._host, self._interceptor) # type: ignore

class _{{ name }}(_Base{{ service.name }}RestTransport._Base{{name}}, {{ async_method_name_prefix }}{{service.name}}RestStub):
{% set body_spec = api.mixin_http_options["{}".format(name)][0].body %}
{{ response_method(body_spec) | indent(4) }}

{{ async_prefix }}def __call__(self,
request: {{ sig.request_type }}, *,
retry: OptionalRetry=gapic_v1.method.DEFAULT,
timeout: Optional[float]=None,
metadata: Sequence[Tuple[str, str]]=(),
) -> {{ sig.response_type }}:

r"""Call the {{- ' ' -}}
{{ (name|snake_case).replace('_',' ')|wrap(
width=70, offset=45, indent=8) }}
{{- ' ' -}} method over HTTP.

Args:
request ({{ sig.request_type }}):
The request object for {{ name }} method.
retry (google.api_core.retry{{ async_suffix }}.{{ async_method_name_prefix }}Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
{% if sig.response_type != 'None' %}

Returns:
{{ sig.response_type }}: Response from {{ name }} method.
{% endif %}
"""
{{ rest_call_method_common(body_spec, name, service.name, is_async)|indent(4) }}

{% if sig.response_type == "None" %}
return {{ await_prefix }}self._interceptor.post_{{ name|snake_case }}(None)
{% else %}
{% if is_async %}
content = await response.read()
{% else %}
content = response.content.decode("utf-8")
{% endif %}
resp = {{ sig.response_type }}()
resp = json_format.Parse(content, resp)
resp = {{ await_prefix }}self._interceptor.post_{{ name|snake_case }}(resp)
return resp
{% endif %}

{% endmacro %}
Original file line number Diff line number Diff line change
Expand Up @@ -17,53 +17,7 @@
{% import "%namespace/%name_%version/%sub/services/%service/_shared_macros.j2" as shared_macros %}

{% if "rest" in opts.transport %}

{% for name, sig in api.mixin_api_signatures.items() %}
@property
def {{ name|snake_case }}(self):
return self._{{ name }}(self._session, self._host, self._interceptor) # type: ignore

class _{{ name }}(_Base{{ service.name }}RestTransport._Base{{name}}, {{service.name}}RestStub):
{% set body_spec = api.mixin_http_options["{}".format(name)][0].body %}
{{ shared_macros.response_method(body_spec)|indent(8) }}

def __call__(self,
request: {{ sig.request_type }}, *,
retry: OptionalRetry=gapic_v1.method.DEFAULT,
timeout: Optional[float]=None,
metadata: Sequence[Tuple[str, str]]=(),
) -> {{ sig.response_type }}:

r"""Call the {{- ' ' -}}
{{ (name|snake_case).replace('_',' ')|wrap(
width=70, offset=45, indent=8) }}
{{- ' ' -}} method over HTTP.

Args:
request ({{ sig.request_type }}):
The request object for {{ name }} method.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
{% if sig.response_type != 'None' %}

Returns:
{{ sig.response_type }}: Response from {{ name }} method.
{% endif %}
"""
{{ shared_macros.rest_call_method_common(body_spec, name, service.name)|indent(8) }}

{% if sig.response_type == "None" %}
return self._interceptor.post_{{ name|snake_case }}(None)
{% else %}

resp = {{ sig.response_type }}()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = self._interceptor.post_{{ name|snake_case }}(resp)
return resp
{% endif %}

{{ shared_macros.generate_mixin_call_method(service, api, name, sig, is_async=False) | indent(4) }}
{% endfor %}
{% endif %} {# rest in opts.transport #}
Original file line number Diff line number Diff line change
Expand Up @@ -2677,8 +2677,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = operations_pb2.Operation()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_get_operation(resp)
return resp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2405,8 +2405,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = locations_pb2.Location()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_get_location(resp)
return resp

Expand Down Expand Up @@ -2474,8 +2475,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = locations_pb2.ListLocationsResponse()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_list_locations(resp)
return resp

Expand Down Expand Up @@ -2543,8 +2545,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = policy_pb2.Policy()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_get_iam_policy(resp)
return resp

Expand Down Expand Up @@ -2615,8 +2618,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = policy_pb2.Policy()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_set_iam_policy(resp)
return resp

Expand Down Expand Up @@ -2687,8 +2691,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = iam_policy_pb2.TestIamPermissionsResponse()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_test_iam_permissions(resp)
return resp

Expand Down Expand Up @@ -2885,8 +2890,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = operations_pb2.Operation()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_get_operation(resp)
return resp

Expand Down Expand Up @@ -2954,8 +2960,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = operations_pb2.ListOperationsResponse()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_list_operations(resp)
return resp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1593,8 +1593,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = locations_pb2.Location()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_get_location(resp)
return resp

Expand Down Expand Up @@ -1662,8 +1663,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = locations_pb2.ListLocationsResponse()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_list_locations(resp)
return resp

Expand Down Expand Up @@ -1857,8 +1859,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = operations_pb2.Operation()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_get_operation(resp)
return resp

Expand Down Expand Up @@ -1926,8 +1929,9 @@ def __call__(self,
if response.status_code >= 400:
raise core_exceptions.from_http_response(response)

content = response.content.decode("utf-8")
resp = operations_pb2.ListOperationsResponse()
resp = json_format.Parse(response.content.decode("utf-8"), resp)
resp = json_format.Parse(content, resp)
resp = self._interceptor.post_list_operations(resp)
return resp

Expand Down

0 comments on commit abdba51

Please sign in to comment.