Skip to content

Commit

Permalink
fix: use disambiguated name for rpcs to avoid collisions (#2217)
Browse files Browse the repository at this point in the history
  • Loading branch information
parthea authored Oct 10, 2024
1 parent de46272 commit 296cd3e
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
{% if method.operation_service %}{# Extended Operations LRO #}
def {{ method.name|snake_case }}_unary(self,
{% else %}
def {{ method.name|snake_case }}(self,
def {{ method.safe_name|snake_case }}(self,
{% endif %}{# Extended Operations LRO #}
{% if not method.client_streaming %}
request: Optional[Union[{{ method.input.ident }}, dict]] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def test_{{ service.client_name|snake_case }}_create_channel_credentials_file(cl
{% endif %}


{% for method in service.methods.values() if 'grpc' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}
{% for method in service.methods.values() if 'grpc' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.safe_name|snake_case %}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
Expand Down Expand Up @@ -579,7 +579,7 @@ def test_{{ method_name }}(request_type, transport: str = 'grpc'):
)
{% endif %}
{% if method.client_streaming %}
response = client.{{ method.name|snake_case }}(iter(requests))
response = client.{{ method.safe_name|snake_case }}(iter(requests))
{% else %}
response = client.{{ method_name }}(request)
{% endif %}
Expand Down Expand Up @@ -1053,7 +1053,7 @@ def test_{{ method_name }}_raw_page_lro():

{% endfor %} {# method in methods for grpc #}

{% for method in service.methods.values() if 'rest' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}{% if method.http_options %}
{% for method in service.methods.values() if 'rest' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.safe_name|snake_case %}{% if method.http_options %}
{# TODO(kbandes): remove this if condition when client streaming are supported. #}
{% if not method.client_streaming %}
@pytest.mark.parametrize("request_type", [
Expand Down Expand Up @@ -1250,7 +1250,7 @@ def test_{{ method.name|snake_case }}_rest(request_type):
response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value
{% if method.client_streaming %}
response = client.{{ method.name|snake_case }}(iter(requests))
response = client.{{ method.safe_name|snake_case }}(iter(requests))
{% elif method.server_streaming %}
with mock.patch.object(response_value, 'iter_content') as iter_content:
iter_content.return_value = iter(json_return_value)
Expand Down Expand Up @@ -1546,7 +1546,7 @@ def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_typ
response_value.request = Request()
req.return_value = response_value
{% if method.client_streaming %}
client.{{ method.name|snake_case }}(iter(requests))
client.{{ method.safe_name|snake_case }}(iter(requests))
{% else %}
client.{{ method_name }}(request)
{% endif %}
Expand Down Expand Up @@ -1814,7 +1814,7 @@ def test_{{ method_name }}_rest_no_http_options():
{% endfor -%} {#- method in methods for rest #}

{% for method in service.methods.values() if 'rest' in opts.transport and
not method.http_options %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}
not method.http_options %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.safe_name|snake_case %}
def test_{{ method_name }}_rest_error():
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def test_{{ service.client_name|snake_case }}_create_channel_credentials_file(cl
{% endfor -%} {#- method in methods for rest #}

{% for method in service.methods.values() if 'rest' in opts.transport and
not method.http_options %}{% with method_name = (method.name + ("_unary" if method.operation_service else "")) | snake_case %}
not method.http_options %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.safe_name|snake_case %}
def test_{{ method_name }}_rest_error():
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
Expand Down
20 changes: 10 additions & 10 deletions gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ async def test_{{ method_name }}_async_use_cached_wrapped_rpc(transport: str = "

{% if method.client_streaming %}
request = [{}]
await client.{{ method.name|snake_case }}(request)
await client.{{ method.safe_name|snake_case }}(request)
{% else %}
request = {}
await client.{{ method_name }}(request)
Expand All @@ -255,7 +255,7 @@ async def test_{{ method_name }}_async_use_cached_wrapped_rpc(transport: str = "
{% endif %}

{% if method.client_streaming %}
await client.{{ method.name|snake_case }}(request)
await client.{{ method.safe_name|snake_case }}(request)
{% else %}
await client.{{ method_name }}(request)
{% endif %}
Expand Down Expand Up @@ -321,9 +321,9 @@ async def test_{{ method_name }}_async(transport: str = 'grpc_asyncio', request_
))
{% endif %}
{% if method.client_streaming and method.server_streaming %}
response = await client.{{ method.name|snake_case }}(iter(requests))
response = await client.{{ method.safe_name|snake_case }}(iter(requests))
{% elif method.client_streaming and not method.server_streaming %}
response = await (await client.{{ method.name|snake_case }}(iter(requests)))
response = await (await client.{{ method.safe_name|snake_case }}(iter(requests)))
{% else %}
response = await client.{{ method_name }}(request)
{% endif %}
Expand Down Expand Up @@ -1001,7 +1001,7 @@ def test_{{ method_name }}_raw_page_lro():
{% endmacro %}

{% macro rest_required_tests(method, service, numeric_enums=False, full_extended_lro=False) %}
{% with method_name = method.safe_name|snake_case + "_unary" if method.extended_lro and not full_extended_lro else method.name|snake_case, method_output = method.extended_lro.operation_type if method.extended_lro and not full_extended_lro else method.output %}{% if method.http_options %}
{% with method_name = method.safe_name|snake_case + "_unary" if method.extended_lro and not full_extended_lro else method.safe_name|snake_case, method_output = method.extended_lro.operation_type if method.extended_lro and not full_extended_lro else method.output %}{% if method.http_options %}
{# TODO(kbandes): remove this if condition when lro and client streaming are supported. #}
{% if not method.client_streaming %}
def test_{{ method_name }}_rest_use_cached_wrapped_rpc():
Expand Down Expand Up @@ -1460,7 +1460,7 @@ def test_{{ method_name }}_rest_no_http_options():
#}
{% macro method_call_test_generic(test_name, method, service, api, transport, request_dict, is_async=False, routing_param=None) %}
{% set transport_name = get_transport_name(transport, is_async) %}
{% with method_name = (method.name + ("_unary" if method.operation_service else "")) | snake_case %}
{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.safe_name|snake_case %}
{% set async_method_prefix = "async " if is_async else "" %}
{% if is_async %}
@pytest.mark.asyncio
Expand Down Expand Up @@ -1713,7 +1713,7 @@ def test_unsupported_parameter_rest_asyncio():
{% set async_prefix = get_async_prefix(is_async) %}
{% set async_decorator = get_async_decorator(is_async) %}
{% set transport_name = get_transport_name(transport, is_async) %}
{% set method_name = method.name|snake_case %}
{% set method_name = method.safe_name|snake_case %}
{{async_decorator}}
{{async_prefix}}def test_{{ method_name }}_{{transport_name}}_error():
{% if transport_name == 'rest_asyncio' %}
Expand Down Expand Up @@ -1763,7 +1763,7 @@ def test_initialize_client_w_{{transport_name}}():
{% set async_prefix = get_async_prefix(is_async) %}
{% set async_decorator = get_async_decorator(is_async) %}
{% set transport_name = get_transport_name(transport, is_async) %}
{% set method_name = method.name|snake_case %}
{% set method_name = method.safe_name|snake_case %}
{% set mocked_session = "AsyncAuthorizedSession" if is_async else "Session" %}
{{ async_decorator }}
{{ async_prefix }}def test_{{ method_name }}_{{transport_name}}_bad_request(request_type={{ method.input.ident }}):
Expand Down Expand Up @@ -1862,7 +1862,7 @@ def test_initialize_client_w_{{transport_name}}():
{% set async_prefix = get_async_prefix(is_async) %}
{% set async_decorator = get_async_decorator(is_async) %}
{% set transport_name = get_transport_name(transport, is_async) %}
{% set method_name = method.name|snake_case %}
{% set method_name = method.safe_name|snake_case %}
{# NOTE: set method_output to method.extended_lro.operation_type for the following method types:
# (method.extended_lro and not full_extended_lro)
#}
Expand Down Expand Up @@ -2183,7 +2183,7 @@ def test_initialize_client_w_{{transport_name}}():
{% set async_prefix = get_async_prefix(is_async) %}
{% set async_decorator = get_async_decorator(is_async) %}
{% set transport_name = get_transport_name(transport, is_async) %}
{% set method_name = method.name|snake_case %}
{% set method_name = method.safe_name|snake_case %}
{% set async_method_prefix = "Async" if is_async else "" %}
{{async_decorator}}
@pytest.mark.parametrize("null_interceptor", [True, False])
Expand Down
15 changes: 15 additions & 0 deletions tests/fragments/test_reserved_method_names.proto
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ service MyService {
};
};

rpc Import(CreateImportRequest) returns (CreateImportResponse) {
option (google.api.http) = {
body: "*"
post: "/import/v1"
};
};

rpc GrpcChannel(GrpcChannelRequest) returns (GrpcChannelResponse) {
option (google.api.http) = {
body: "*"
Expand Down Expand Up @@ -59,6 +66,14 @@ message CreateChannelResponse {
string info = 1;
}

message CreateImportRequest {
string info = 1;
}

message CreateImportResponse {
string info = 1;
}

message GrpcChannelRequest {
string grpc_channel = 1;
string info = 2;
Expand Down

0 comments on commit 296cd3e

Please sign in to comment.