diff --git a/eagr/reflection/grpc_reflection_interface.py b/eagr/reflection/grpc_reflection_interface.py index 874bc92..5346c82 100644 --- a/eagr/reflection/grpc_reflection_interface.py +++ b/eagr/reflection/grpc_reflection_interface.py @@ -13,7 +13,7 @@ def _make_json_to_json_method_invocation( - self, method: Callable, proto_type: Callable, max_retries: int + method: Callable, proto_type: Callable, max_retries: int ) -> Callable: """Make function wrapping grpc method into json conversion""" # This invocation definition serves as a lambda for a GRPC method invocation @@ -29,23 +29,23 @@ def invocation(input_dict: Dict) -> Dict: return invocation -def make_json_grpc_client(self, host: str, service_name: str) -> Dict[str, Callable]: +def make_json_grpc_client(host: str, service_name: str) -> Dict[str, Callable]: """Mount all GRPC methods for service onto client.""" client = {} # map method names to method invocations - self._channel = grpc.insecure_channel(host, (("grpc.lb_policy_name", "round_robin"),)) + channel = grpc.insecure_channel(host, (("grpc.lb_policy_name", "round_robin"),)) build_database_from_channel = reflection_descriptor_database.build_database_from_channel - descriptor_pool_instance, symbol_database_instance = build_database_from_channel(self._channel) + descriptor_pool_instance, symbol_database_instance = build_database_from_channel(channel) service_descriptor = descriptor_pool_instance.FindServiceByName(service_name) for method_descriptor in service_descriptor.methods: method_name = method_descriptor.name method_callable = method.make_grpc_unary_method( - self._channel, service_name, method_descriptor, symbol_database_instance + channel, service_name, method_descriptor, symbol_database_instance ) input_type = method_descriptor.input_type input_prototype = symbol_database_instance.GetPrototype(input_type) - method_invocation = self._make_json_to_json_method_invocation( + method_invocation = _make_json_to_json_method_invocation( method_callable, input_prototype, max_retries=MAX_RETRIES ) diff --git a/eagr/tests/reflection/test_grpc_reflection_interface.py b/eagr/tests/reflection/test_grpc_reflection_interface.py new file mode 100644 index 0000000..1e4215a --- /dev/null +++ b/eagr/tests/reflection/test_grpc_reflection_interface.py @@ -0,0 +1,84 @@ +# Copyright 2020-present Kensho Technologies, LLC. +import unittest +from unittest.mock import MagicMock, patch + +from eagr.reflection import grpc_reflection_interface, reflection_descriptor_database +from eagr.tests.reflection import utils + + +# Mocks to patch out context-specific method calls in eagr.reflection.grpc_reflection_interface +def build_database_from_channel_mock(channel): + """Returns a built DescriptorPool from a mock ServerReflectionStub, given any channel""" + return reflection_descriptor_database.build_database_from_stub(utils.reflection_client_mock) + + +def find_service_by_name_mock(service_name): + """Returns a mock GRPC service given a valid service name""" + method_mock = MagicMock() + method_mock.name = "MyTestMethod" + method_mock.input_type = MagicMock() + method_mock.input_type.prototype = build_database_from_channel_mock + + service_mock = MagicMock() + service_mock.methods = [method_mock] + if service_name == "good_service": + return service_mock + + else: + raise KeyError("Bad service {} given".format(service_name)) + + +def make_grpc_unary_method_mock(channel, service_name, method_descriptor, symbol_database_instance): + """Returns a mock callable method regardless of params provided""" + + def callable_mock(): + """Return a mock string value invariably""" + return "return_value" + + return callable_mock + + +def get_prototype_mock(input_type): + """Returns a mock callable method regardless of params provided""" + + def callable_mock(): + """Return a mock string value invariably""" + return "return_value" + + return callable_mock + + +class TestGRPCReflectionInterface(unittest.TestCase): + @patch("eagr.reflection.reflection_descriptor_database.SymbolDatabase", autospec=True) + @patch("eagr.grpc_utils.method.make_grpc_unary_method", autospec=True) + @patch("eagr.reflection.reflection_descriptor_database.DescriptorPool", autospec=True) + @patch( + "eagr.reflection.reflection_descriptor_database.build_database_from_channel", autospec=True + ) + def test_make_json_grpc_client( + self, + build_database_mock_method, + descriptor_pool_mock, + make_unary_method_mock_method, + symbol_database_mock, + ): + build_database_mock_method.side_effect = build_database_from_channel_mock + + descriptor_pool_mock_return_value = MagicMock() + descriptor_pool_mock_return_value.FindServiceByName.side_effect = find_service_by_name_mock + descriptor_pool_mock.return_value = descriptor_pool_mock_return_value + + make_unary_method_mock_method.side_effect = make_grpc_unary_method_mock + + symbol_database_mock_return_value = MagicMock() + symbol_database_mock_return_value.GetPrototype.side_effect = get_prototype_mock + symbol_database_mock.return_value.return_value = symbol_database_mock_return_value + + my_client = grpc_reflection_interface.make_json_grpc_client("my_host", "good_service") + + self.assertIn("MyTestMethod", my_client) + + with self.assertRaises(KeyError) as context: + my_client = grpc_reflection_interface.make_json_grpc_client("my_host", "service_b") + + self.assertIn("service_b", context.exception.args[0])