diff --git a/src/bentoml/_internal/client/__init__.py b/src/bentoml/_internal/client/__init__.py index 4250e42b2da..2be587975b4 100644 --- a/src/bentoml/_internal/client/__init__.py +++ b/src/bentoml/_internal/client/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import functools import logging import typing as t @@ -17,8 +16,12 @@ from types import TracebackType from ..service import Service + from .grpc import AsyncGrpcClient from .grpc import GrpcClient + from .grpc import SyncGrpcClient + from .http import AsyncHTTPClient from .http import HTTPClient + from .http import SyncHTTPClient class Client(ABC): @@ -26,7 +29,13 @@ class Client(ABC): _svc: Service endpoints: list[str] + _sync_client: SyncClient + _async_client: AsyncClient + def __init__(self, svc: Service, server_url: str): + logger.warning( + "Client is deprecated and will be removed in BentoML 2.0, please use AsyncClient or SyncClient instead." + ) self._svc = svc self.server_url = server_url @@ -34,46 +43,158 @@ def __init__(self, svc: Service, server_url: str): raise BentoMLException("No APIs were found when constructing client.") self.endpoints = [] - for name, api in self._svc.apis.items(): + for name in self._svc.apis.keys(): self.endpoints.append(name) if not hasattr(self, name): - setattr( - self, name, functools.partial(self._sync_call, _bentoml_api=api) - ) + setattr(self, name, functools.partial(self.call, bentoml_api_name=name)) if not hasattr(self, f"async_{name}"): setattr( self, f"async_{name}", - functools.partial(self._call, _bentoml_api=api), + functools.partial(self.async_call, bentoml_api_name=name), ) - def call(self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any) -> t.Any: - return self._sync_call( - inp, _bentoml_api=self._svc.apis[bentoml_api_name], **kwargs + def call( + self, + bentoml_api_name: str, + inp: t.Any = None, + **kwargs: t.Any, + ) -> t.Any: + return self._sync_client.call( + inp=inp, bentoml_api_name=bentoml_api_name, **kwargs ) async def async_call( self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any + ) -> t.Any: + return await self._async_client.call( + inp=inp, bentoml_api_name=bentoml_api_name, **kwargs + ) + + @staticmethod + def wait_until_server_ready( + host: str, port: int, timeout: float = 30, **kwargs: t.Any + ) -> None: + SyncClient.wait_until_server_ready(host, port, timeout, **kwargs) + + @staticmethod + async def async_wait_until_server_ready( + host: str, port: int, timeout: float = 30, **kwargs: t.Any + ) -> None: + await AsyncClient.wait_until_server_ready(host, port, timeout, **kwargs) + + @t.overload + @staticmethod + def from_url(server_url: str, *, kind: None | t.Literal["auto"] = ...) -> Client: + ... + + @t.overload + @staticmethod + def from_url(server_url: str, *, kind: t.Literal["http"] = ...) -> HTTPClient: + ... + + @t.overload + @staticmethod + def from_url(server_url: str, *, kind: t.Literal["grpc"] = ...) -> GrpcClient: + ... 
+ + @staticmethod + def from_url( + server_url: str, + *, + kind: t.Literal["auto", "http", "grpc"] | None = None, + **kwargs: t.Any, + ) -> Client: + return SyncClient.from_url(server_url, kind=kind, **kwargs) + + def __enter__(self): + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + return self._sync_client.close() + + async def __aenter__(self): + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + return await self._async_client.close() + + +class AsyncClient(ABC): + server_url: str + _svc: Service + endpoints: list[str] + + def __init__(self, svc: Service, server_url: str): + self._svc = svc + self.server_url = server_url + + if len(svc.apis) == 0: + raise BentoMLException("No APIs were found when constructing client.") + + self.endpoints = [] + for name, api in self._svc.apis.items(): + self.endpoints.append(name) + + if not hasattr(self, name): + setattr( + self, + name, + functools.partial(self._call, _bentoml_api=api), + ) + + async def call( + self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any ) -> t.Any: return await self._call( inp, _bentoml_api=self._svc.apis[bentoml_api_name], **kwargs ) + @abstractmethod + async def _call( + self, inp: t.Any = None, *, _bentoml_api: InferenceAPI[t.Any], **kwargs: t.Any + ) -> t.Any: + raise NotImplementedError() + + async def close(self): + pass + + async def __aenter__(self): + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + await self.close() + @staticmethod - def wait_until_server_ready( + async def wait_until_server_ready( host: str, port: int, timeout: float = 30, **kwargs: t.Any ) -> None: try: - from .http import HTTPClient + from .http import AsyncHTTPClient - HTTPClient.wait_until_server_ready(host, port, timeout, **kwargs) + await AsyncHTTPClient.wait_until_server_ready(host, port, timeout, **kwargs) except BadStatusLine: # when address is a RPC - from .grpc import GrpcClient + from .grpc import AsyncGrpcClient - GrpcClient.wait_until_server_ready(host, port, timeout, **kwargs) + await AsyncGrpcClient.wait_until_server_ready(host, port, timeout, **kwargs) except Exception as err: # caught all other exceptions logger.error("Failed to connect to server %s:%s", host, port) @@ -81,60 +202,95 @@ def wait_until_server_ready( raise @t.overload - @staticmethod - def from_url(server_url: str, *, kind: None | t.Literal["auto"] = ...) -> Client: + @classmethod + async def from_url( + cls, server_url: str, *, kind: None | t.Literal["auto"] = ... + ) -> AsyncGrpcClient | AsyncHTTPClient: ... @t.overload - @staticmethod - def from_url(server_url: str, *, kind: t.Literal["http"] = ...) -> HTTPClient: + @classmethod + async def from_url( + cls, server_url: str, *, kind: t.Literal["http"] = ... + ) -> AsyncHTTPClient: ... @t.overload - @staticmethod - def from_url(server_url: str, *, kind: t.Literal["grpc"] = ...) -> GrpcClient: + @classmethod + async def from_url( + cls, server_url: str, *, kind: t.Literal["grpc"] = ... + ) -> AsyncGrpcClient: ... 
- @staticmethod - def from_url( - server_url: str, *, kind: str | None = None, **kwargs: t.Any - ) -> Client: + @classmethod + async def from_url( + cls, + server_url: str, + *, + kind: t.Literal["auto", "http", "grpc"] | None = None, + **kwargs: t.Any, + ) -> AsyncClient: if kind is None or kind == "auto": try: - from .http import HTTPClient + from .http import AsyncHTTPClient - return HTTPClient.from_url(server_url, **kwargs) + return await AsyncHTTPClient.from_url(server_url, **kwargs) except BadStatusLine: - from .grpc import GrpcClient + from .grpc import AsyncGrpcClient - return GrpcClient.from_url(server_url, **kwargs) + return await AsyncGrpcClient.from_url(server_url, **kwargs) except Exception as e: # pylint: disable=broad-except raise BentoMLException( f"Failed to create a BentoML client from given URL '{server_url}': {e} ({e.__class__.__name__})" ) from e elif kind == "http": - from .http import HTTPClient + from .http import AsyncHTTPClient - return HTTPClient.from_url(server_url, **kwargs) + return await AsyncHTTPClient.from_url(server_url, **kwargs) elif kind == "grpc": - from .grpc import GrpcClient + from .grpc import AsyncGrpcClient - return GrpcClient.from_url(server_url, **kwargs) + return await AsyncGrpcClient.from_url(server_url, **kwargs) else: raise BentoMLException( f"Invalid client kind '{kind}'. Must be one of 'http', 'grpc', or 'auto'." ) - def _sync_call( - self, inp: t.Any = None, *, _bentoml_api: InferenceAPI[t.Any], **kwargs: t.Any - ): - return asyncio.run(self._call(inp, _bentoml_api=_bentoml_api, **kwargs)) + +class SyncClient(Client): + server_url: str + _svc: Service + endpoints: list[str] + + def __init__(self, svc: Service, server_url: str): + self._svc = svc + self.server_url = server_url + + if len(svc.apis) == 0: + raise BentoMLException("No APIs were found when constructing client.") + + self.endpoints = [] + for name, api in self._svc.apis.items(): + self.endpoints.append(name) + + if not hasattr(self, name): + setattr( + self, + name, + functools.partial(self._call, _bentoml_api=api), + ) + + def call(self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any) -> t.Any: + return self._call(inp, _bentoml_api=self._svc.apis[bentoml_api_name], **kwargs) @abstractmethod - async def _call( + def _call( self, inp: t.Any = None, *, _bentoml_api: InferenceAPI[t.Any], **kwargs: t.Any ) -> t.Any: - raise NotImplementedError + raise NotImplementedError() + + def close(self) -> None: + pass def __enter__(self): return self @@ -145,15 +301,78 @@ def __exit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> bool | None: - pass + self.close() - async def __aenter__(self): - return self + @staticmethod + def wait_until_server_ready( + host: str, port: int, timeout: float = 30, **kwargs: t.Any + ) -> None: + try: + from .http import SyncHTTPClient - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> bool | None: - pass + SyncHTTPClient.wait_until_server_ready(host, port, timeout, **kwargs) + except BadStatusLine: + # when address is a RPC + from .grpc import SyncGrpcClient + + SyncGrpcClient.wait_until_server_ready(host, port, timeout, **kwargs) + except Exception as err: + # caught all other exceptions + logger.error("Failed to connect to server %s:%s", host, port) + logger.error(err) + raise + + @t.overload + @classmethod + def from_url( + cls, server_url: str, *, kind: None | t.Literal["auto"] = ... 
+ ) -> SyncGrpcClient | SyncHTTPClient: + ... + + @t.overload + @classmethod + def from_url( + cls, server_url: str, *, kind: t.Literal["http"] = ... + ) -> SyncHTTPClient: + ... + + @t.overload + @classmethod + def from_url( + cls, server_url: str, *, kind: t.Literal["grpc"] = ... + ) -> SyncGrpcClient: + ... + + @classmethod + def from_url( + cls, + server_url: str, + *, + kind: t.Literal["auto", "http", "grpc"] | None = None, + **kwargs: t.Any, + ) -> SyncClient: + if kind is None or kind == "auto": + try: + from .http import SyncHTTPClient + + return SyncHTTPClient.from_url(server_url, **kwargs) + except BadStatusLine: + from .grpc import SyncGrpcClient + + return SyncGrpcClient.from_url(server_url, **kwargs) + except Exception as e: # pylint: disable=broad-except + raise BentoMLException( + f"Failed to create a BentoML client from given URL '{server_url}': {e} ({e.__class__.__name__})" + ) from e + elif kind == "http": + from .http import SyncHTTPClient + + return SyncHTTPClient.from_url(server_url, **kwargs) + elif kind == "grpc": + from .grpc import SyncGrpcClient + + return SyncGrpcClient.from_url(server_url, **kwargs) + else: + raise BentoMLException( + f"Invalid client kind '{kind}'. Must be one of 'http', 'grpc', or 'auto'." + ) diff --git a/src/bentoml/_internal/client/grpc.py b/src/bentoml/_internal/client/grpc.py index 17f939b0977..75fed0ac465 100644 --- a/src/bentoml/_internal/client/grpc.py +++ b/src/bentoml/_internal/client/grpc.py @@ -1,6 +1,6 @@ from __future__ import annotations -import functools +import asyncio import logging import time import typing as t @@ -18,7 +18,9 @@ from ..service import Service from ..service.inference_api import InferenceAPI from ..utils import LazyLoader +from . import AsyncClient from . import Client +from . 
import SyncClient logger = logging.getLogger(__name__) @@ -30,6 +32,7 @@ from google.protobuf import json_format as _json_format from grpc import aio from grpc._channel import Channel as GrpcSyncChannel + from grpc.aio._channel import Channel as GrpcAsyncChannel from grpc_health.v1 import health_pb2 as pb_health from ...grpc.v1.service_pb2 import Response @@ -52,8 +55,15 @@ class ClientCredentials(t.TypedDict): ) -# TODO: xDS support class GrpcClient(Client): + def __init__(self, svc: Service, server_url: str): + self._sync_client = SyncGrpcClient(svc=svc, server_url=server_url) + self._async_client = AsyncGrpcClient(svc=svc, server_url=server_url) + super().__init__(svc, server_url) + + +# TODO: xDS support +class AsyncGrpcClient(AsyncClient): def __init__( self, server_url: str, @@ -88,8 +98,8 @@ def __init__( self._call_rpc = f"/bentoml.grpc.{protocol_version}.BentoService/Call" super().__init__(svc, server_url) - @property - def channel(self): + @cached_property + def channel(self) -> GrpcAsyncChannel: if self._credentials is not None: return aio.secure_channel( self.server_url, @@ -98,27 +108,26 @@ def channel(self): compression=self._compression, interceptors=self._interceptors, ) - else: - return aio.insecure_channel( - self.server_url, - options=self._options, - compression=self._compression, - interceptors=self._interceptors, - ) + return aio.insecure_channel( + self.server_url, + options=self._options, + compression=self._compression, + interceptors=self._interceptors, + ) @staticmethod - def _create_sync_channel( + def _create_channel( server_url: str, ssl: bool = False, ssl_client_credentials: ClientCredentials | None = None, channel_options: t.Any | None = None, compression: grpc.Compression | None = None, - ) -> GrpcSyncChannel: + ) -> GrpcAsyncChannel: if ssl: assert ( ssl_client_credentials is not None ), "'ssl=True' requires 'ssl_client_credentials'" - return grpc.secure_channel( + return aio.secure_channel( server_url, credentials=grpc.ssl_channel_credentials( **{ @@ -129,12 +138,12 @@ def _create_sync_channel( options=channel_options, compression=compression, ) - return grpc.insecure_channel( + return aio.insecure_channel( server_url, options=channel_options, compression=compression ) @staticmethod - def wait_until_server_ready( + async def wait_until_server_ready( host: str, port: int, timeout: float = 30, @@ -144,21 +153,38 @@ def wait_until_server_ready( ) -> None: protocol_version = kwargs.get("protocol_version", LATEST_PROTOCOL_VERSION) - channel = GrpcClient._create_sync_channel( + async with AsyncGrpcClient._create_channel( f"{host}:{port}", ssl=kwargs.get("ssl", False), ssl_client_credentials=kwargs.get("ssl_client_credentials", None), channel_options=kwargs.get("channel_options", None), compression=kwargs.get("compression", None), - ) - rpc = channel.unary_unary( - "/grpc.health.v1.Health/Check", - request_serializer=pb_health.HealthCheckRequest.SerializeToString, - response_deserializer=pb_health.HealthCheckResponse.FromString, - ) + ) as channel: + rpc = channel.unary_unary( + "/grpc.health.v1.Health/Check", + request_serializer=pb_health.HealthCheckRequest.SerializeToString, + response_deserializer=pb_health.HealthCheckResponse.FromString, + ) + + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = t.cast( + pb_health.HealthCheckResponse, + rpc( + pb_health.HealthCheckRequest( + service=f"bentoml.grpc.{protocol_version}.BentoService" + ) + ), + ) + if response.status == pb_health.HealthCheckResponse.SERVING: + break + 
else: + await asyncio.sleep(check_interval) + except grpc.RpcError: + logger.debug("Server is not ready. Retrying...") + await asyncio.sleep(check_interval) - start_time = time.time() - while time.time() - start_time < timeout: try: response = t.cast( pb_health.HealthCheckResponse, rpc( pb_health.HealthCheckRequest( service=f"bentoml.grpc.{protocol_version}.BentoService" ) ), ) - if response.status == pb_health.HealthCheckResponse.SERVING: - break - else: - time.sleep(check_interval) - except grpc.RpcError: - logger.debug("Server is not ready. Retrying...") - time.sleep(check_interval) - - try: - response = t.cast( - pb_health.HealthCheckResponse, - rpc( - pb_health.HealthCheckRequest( - service=f"bentoml.grpc.{protocol_version}.BentoService + if response.status != pb_health.HealthCheckResponse.SERVING: + raise TimeoutError( f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready." ) - ), - ) - if response.status != pb_health.HealthCheckResponse.SERVING: - raise TimeoutError( - f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready." - ) - except (grpc.RpcError, TimeoutError) as err: - logger.error("Caught exception while connecting to %s:%s:", host, port) - logger.error(err) - raise - finally: - channel.close() + except (grpc.RpcError, TimeoutError) as err: + logger.error("Caught exception while connecting to %s:%s:", host, port) + logger.error(err) + raise @cached_property def _rpc_metadata(self) -> dict[str, dict[str, t.Any]]: @@ -221,41 +228,66 @@ def _rpc_metadata(self) -> dict[str, dict[str, t.Any]]: ) } - async def health(self, service_name: str, *, timeout: int = 30) -> t.Any: - return await self._invoke( - method_name="/grpc.health.v1.Health/Check", - service=service_name, - _grpc_channel_timeout=timeout, - ) + @cached_property + def _rpc_methods( + self, + ) -> dict[str, t.Callable[..., t.Awaitable["Response"]]]: + def make_async_fn( + method_name: str, + input_type: t.Any, + output_type: t.Any, + ) -> t.Callable[..., t.Awaitable["Response"]]: + rpc = self.channel.unary_unary( + method_name, + request_serializer=input_type.SerializeToString, + response_deserializer=output_type.FromString, + ) - async def _invoke(self, method_name: str, **attrs: t.Any): - # channel kwargs include timeout, metadata, credentials, wait_for_ready and compression - # to pass it in kwargs add prefix _grpc_channel_ - channel_kwargs = { - k: attrs.pop(f"_grpc_channel_{k}", None) - for k in { - "timeout", - "metadata", - "credentials", - "wait_for_ready", - "compression", - } - } - if method_name not in self._rpc_metadata: - raise ValueError( - f"'{method_name}' is a yet supported rpc. 
Current supported are: {self._rpc_metadata}" + def fn( + channel_kwargs: t.Dict[str, t.Any], + method_kwargs: t.Dict[str, t.Any], + ) -> t.Awaitable["Response"]: + return t.cast( + t.Awaitable["Response"], + rpc(input_type(**method_kwargs), **channel_kwargs), + ) + + return fn + + return { + method_name: make_async_fn( + method_name, + input_type=metadata["input_type"], + output_type=metadata["output_type"], ) - metadata = self._rpc_metadata[method_name] - rpc = self.channel.unary_unary( - method_name, - request_serializer=metadata["input_type"].SerializeToString, - response_deserializer=metadata["output_type"].FromString, + for method_name, metadata in self._rpc_metadata.items() + } + + async def health(self, service_name: str, *, timeout: int = 30) -> t.Any: + return await self._rpc_methods["/grpc.health.v1.Health/Check"]( + method_kwargs={"service": service_name}, + channel_kwargs={"timeout": timeout}, ) - return await t.cast( - "t.Awaitable[Response]", - rpc(metadata["input_type"](**attrs), **channel_kwargs), + @staticmethod + def _split_channel_args( + **kwargs: t.Any, + ) -> tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]: + channel_kwarg_names = ( + "timeout", + "metadata", + "credentials", + "wait_for_ready", + "compression", ) + channel_kwargs: t.Dict[str, t.Any] = {} + other_kwargs: t.Dict[str, t.Any] = {} + for k, v in kwargs.items(): + if k in channel_kwarg_names: + channel_kwargs[k] = v + else: + other_kwargs[k] = v + return other_kwargs, channel_kwargs async def _call( self, @@ -269,41 +301,392 @@ async def _call( # create a blocking call to wait til channel is ready. await self.channel.channel_ready() - fn = functools.partial( - self._invoke, - method_name=f"/bentoml.grpc.{self._protocol_version}.BentoService/Call", - **{ - f"_grpc_channel_{k}": attrs.pop(f"_grpc_channel_{k}", None) - for k in { - "timeout", - "metadata", - "credentials", - "wait_for_ready", - "compression", - } + if _bentoml_api.multi_input: + if inp is not None: + raise BentoMLException( + f"'{_bentoml_api.name}' takes multiple inputs; all inputs must be passed as keyword arguments." + ) + serialized_req = await _bentoml_api.input.to_proto(attrs) + else: + serialized_req = await _bentoml_api.input.to_proto(inp) + + # A call includes api_name and given proto_fields + api_fn = {v: k for k, v in self._svc.apis.items()} + kwargs, channel_kwargs = self._split_channel_args(**attrs) + kwargs.update( + { + "api_name": api_fn[_bentoml_api], + _bentoml_api.input.proto_fields[0]: serialized_req, }, ) + if self._call_rpc not in self._rpc_methods: + raise ValueError( + f"'{self._call_rpc}' is a yet supported rpc. Current supported are: {self._rpc_metadata}" + ) + proto = await self._rpc_methods[self._call_rpc]( + channel_kwargs=channel_kwargs, + method_kwargs=kwargs, + ) + return await _bentoml_api.output.from_proto( + getattr(proto, proto.WhichOneof("content")) + ) + + @classmethod + async def from_url(cls, server_url: str, **kwargs: t.Any) -> AsyncGrpcClient: + protocol_version = kwargs.get("protocol_version", LATEST_PROTOCOL_VERSION) + + # Since v1, we introduce a ServiceMetadata rpc to retrieve bentoml.Service metadata. + # then `client.predict` or `client.classify` won't be available. + # client.Call will still persist for both protocol version. + if parse(protocol_version) < parse("v1"): + exception_message = [ + f"Using protocol version {protocol_version} older than v1. 'bentoml.client.Client' will only support protocol version v1 onwards. 
To create client with protocol version '{protocol_version}', do the following:\n" + f"""\ + +from bentoml.grpc.utils import import_generated_stubs, import_grpc + +pb, services = import_generated_stubs("{protocol_version}") + +grpc, _ = import_grpc() + +def run(): + with grpc.insecure_channel("localhost:3000") as channel: + stubs = services.BentoServiceStub(channel) + req = stubs.Call( + request=pb.Request( + api_name="predict", + ndarray=pb.NDArray( + dtype=pb.NDArray.DTYPE_FLOAT, + shape=(1, 4), + float_values=[5.9, 3, 5.1, 1.8], + ), + ) + ) + print(req) + +if __name__ == '__main__': + run() +""" + ] + raise BentoMLException("\n".join(exception_message)) + pb, _ = import_generated_stubs(protocol_version) + + async with AsyncGrpcClient._create_channel( + server_url.replace(r"localhost", "0.0.0.0"), + ssl=kwargs.get("ssl", False), + ssl_client_credentials=kwargs.get("ssl_client_credentials", None), + channel_options=kwargs.get("channel_options", None), + compression=kwargs.get("compression", None), + ) as channel: + # create an insecure channel to invoke ServiceMetadata rpc + metadata = t.cast( + "ServiceMetadataResponse", + channel.unary_unary( + f"/bentoml.grpc.{protocol_version}.BentoService/ServiceMetadata", + request_serializer=pb.ServiceMetadataRequest.SerializeToString, + response_deserializer=pb.ServiceMetadataResponse.FromString, + )(pb.ServiceMetadataRequest()), + ) + dummy_service = Service(metadata.name) + + for api in metadata.apis: + try: + dummy_service.apis[api.name] = InferenceAPI[t.Any]( + None, + io_descriptors.from_spec( + { + "id": api.input.descriptor_id, + "args": _json_format.MessageToDict( + api.input.attributes + ).get("args", None), + } + ), + io_descriptors.from_spec( + { + "id": api.output.descriptor_id, + "args": _json_format.MessageToDict( + api.output.attributes + ).get("args", None), + } + ), + name=api.name, + doc=api.docs, + ) + except BentoMLException as e: + logger.error("Failed to instantiate client for API %s: ", api.name, e) + + return cls(server_url, dummy_service, **kwargs) + + async def close(self): + await self.channel.close() + return await super().close() + + +# TODO: xDS support +class SyncGrpcClient(SyncClient): + def __init__( + self, + server_url: str, + svc: Service, + # gRPC specific options + ssl: bool = False, + channel_options: aio.ChannelArgumentType | None = None, + interceptors: t.Sequence[aio.ClientInterceptor] | None = None, + compression: grpc.Compression | None = None, + ssl_client_credentials: ClientCredentials | None = None, + *, + protocol_version: str = LATEST_PROTOCOL_VERSION, + **kwargs: t.Any, + ): + self._pb, _ = import_generated_stubs(protocol_version) + + self._protocol_version = protocol_version + self._compression = compression + self._options = channel_options + self._interceptors = interceptors + self._credentials = None + if ssl: + assert ( + ssl_client_credentials is not None + ), "'ssl=True' requires 'ssl_client_credentials'" + self._credentials = grpc.ssl_channel_credentials( + **{ + k: load_from_file(v) if isinstance(v, str) else v + for k, v in ssl_client_credentials.items() + } + ) + self._call_rpc = f"/bentoml.grpc.{protocol_version}.BentoService/Call" + super().__init__(svc, server_url) + + @cached_property + def channel(self) -> GrpcSyncChannel: + if self._credentials is not None: + return grpc.secure_channel( + self.server_url, + credentials=self._credentials, + options=self._options, + compression=self._compression, + ) + return grpc.insecure_channel( + self.server_url, + options=self._options, + 
compression=self._compression, + ) + + @staticmethod + def _create_channel( + server_url: str, + ssl: bool = False, + ssl_client_credentials: ClientCredentials | None = None, + channel_options: t.Any | None = None, + compression: grpc.Compression | None = None, + ) -> GrpcSyncChannel: + if ssl: + assert ( + ssl_client_credentials is not None + ), "'ssl=True' requires 'ssl_client_credentials'" + return grpc.secure_channel( + server_url, + credentials=grpc.ssl_channel_credentials( + **{ + k: load_from_file(v) if isinstance(v, str) else v + for k, v in ssl_client_credentials.items() + } + ), + options=channel_options, + compression=compression, + ) + return grpc.insecure_channel( + server_url, options=channel_options, compression=compression + ) + + @staticmethod + def wait_until_server_ready( + host: str, + port: int, + timeout: float = 30, + check_interval: int = 1, + # set kwargs here to omit gRPC kwargs + **kwargs: t.Any, + ) -> None: + protocol_version = kwargs.get("protocol_version", LATEST_PROTOCOL_VERSION) + + with SyncGrpcClient._create_channel( + f"{host}:{port}", + ssl=kwargs.get("ssl", False), + ssl_client_credentials=kwargs.get("ssl_client_credentials", None), + channel_options=kwargs.get("channel_options", None), + compression=kwargs.get("compression", None), + ) as channel: + rpc = channel.unary_unary( + "/grpc.health.v1.Health/Check", + request_serializer=pb_health.HealthCheckRequest.SerializeToString, + response_deserializer=pb_health.HealthCheckResponse.FromString, + ) + + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = t.cast( + pb_health.HealthCheckResponse, + rpc( + pb_health.HealthCheckRequest( + service=f"bentoml.grpc.{protocol_version}.BentoService" + ) + ), + ) + if response.status == pb_health.HealthCheckResponse.SERVING: + break + else: + time.sleep(check_interval) + except grpc.RpcError: + logger.debug("Server is not ready. Retrying...") + time.sleep(check_interval) + + try: + response = t.cast( + pb_health.HealthCheckResponse, + rpc( + pb_health.HealthCheckRequest( + service=f"bentoml.grpc.{protocol_version}.BentoService" + ) + ), + ) + if response.status != pb_health.HealthCheckResponse.SERVING: + raise TimeoutError( + f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready." + ) + except (grpc.RpcError, TimeoutError) as err: + logger.error("Caught exception while connecting to %s:%s:", host, port) + logger.error(err) + raise + + @cached_property + def _rpc_metadata(self) -> dict[str, dict[str, t.Any]]: + # Currently all RPCs in BentoService are unary-unary + # NOTE: we will set the types of the stubs to be Any. 
+ return { + method: {"input_type": input_type, "output_type": output_type} + for method, input_type, output_type in ( + ( + self._call_rpc, + self._pb.Request, + self._pb.Response, + ), + ( + f"/bentoml.grpc.{self._protocol_version}.BentoService/ServiceMetadata", + self._pb.ServiceMetadataRequest, + self._pb.ServiceMetadataResponse, + ), + ( + "/grpc.health.v1.Health/Check", + pb_health.HealthCheckRequest, + pb_health.HealthCheckResponse, + ), + ) + } + + @cached_property + def _rpc_methods(self) -> dict[str, t.Callable[..., "Response"]]: + def make_sync_fn( + method_name: str, + input_type: t.Any, + output_type: t.Any, + ) -> t.Callable[..., "Response"]: + rpc = self.channel.unary_unary( + method_name, + request_serializer=input_type.SerializeToString, + response_deserializer=output_type.FromString, + ) + + def fn( + channel_kwargs: t.Dict[str, t.Any], + method_kwargs: t.Dict[str, t.Any], + ) -> Response: + return t.cast( + "Response", + rpc(input_type(**method_kwargs), **channel_kwargs), + ) + + return fn + + return { + method_name: make_sync_fn( + method_name, + input_type=metadata["input_type"], + output_type=metadata["output_type"], + ) + for method_name, metadata in self._rpc_metadata.items() + } + + def health(self, service_name: str, *, timeout: int = 30) -> t.Any: + return self._rpc_methods["/grpc.health.v1.Health/Check"]( + method_kwargs={"service": service_name}, + channel_kwargs={"timeout": timeout}, + ) + + @staticmethod + def _split_channel_args( + **kwargs: t.Any, + ) -> tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]: + channel_kwarg_names = ( + "timeout", + "metadata", + "credentials", + "wait_for_ready", + "compression", + ) + channel_kwargs: t.Dict[str, t.Any] = {} + other_kwargs: t.Dict[str, t.Any] = {} + for k, v in kwargs.items(): + if k in channel_kwarg_names: + channel_kwargs[k] = v + else: + other_kwargs[k] = v + return other_kwargs, channel_kwargs + + def _call( + self, + inp: t.Any = None, + *, + _bentoml_api: InferenceAPI[t.Any], + **attrs: t.Any, + ): if _bentoml_api.multi_input: if inp is not None: raise BentoMLException( f"'{_bentoml_api.name}' takes multiple inputs; all inputs must be passed as keyword arguments." ) - serialized_req = await _bentoml_api.input.to_proto(attrs) + serialized_req = asyncio.run(_bentoml_api.input.to_proto(attrs)) else: - serialized_req = await _bentoml_api.input.to_proto(inp) + serialized_req = asyncio.run(_bentoml_api.input.to_proto(inp)) # A call includes api_name and given proto_fields api_fn = {v: k for k, v in self._svc.apis.items()} - return await fn( - **{ + kwargs, channel_kwargs = self._split_channel_args(**attrs) + kwargs.update( + { "api_name": api_fn[_bentoml_api], _bentoml_api.input.proto_fields[0]: serialized_req, }, ) + if self._call_rpc not in self._rpc_methods: + raise ValueError( + f"'{self._call_rpc}' is a yet supported rpc. Current supported are: {self._rpc_metadata}" + ) + proto = self._rpc_methods[self._call_rpc]( + channel_kwargs=channel_kwargs, + method_kwargs=kwargs, + ) + + return asyncio.run( + _bentoml_api.output.from_proto(getattr(proto, proto.WhichOneof("content"))) + ) + @classmethod - def from_url(cls, server_url: str, **kwargs: t.Any) -> GrpcClient: + def from_url(cls, server_url: str, **kwargs: t.Any) -> SyncGrpcClient: protocol_version = kwargs.get("protocol_version", LATEST_PROTOCOL_VERSION) # Since v1, we introduce a ServiceMetadata rpc to retrieve bentoml.Service metadata. 
@@ -342,7 +725,7 @@ def run(): raise BentoMLException("\n".join(exception_message)) pb, _ = import_generated_stubs(protocol_version) - with GrpcClient._create_sync_channel( + with SyncGrpcClient._create_channel( server_url.replace(r"localhost", "0.0.0.0"), ssl=kwargs.get("ssl", False), ssl_client_credentials=kwargs.get("ssl_client_credentials", None), @@ -387,3 +770,7 @@ def run(): logger.error("Failed to instantiate client for API %s: ", api.name, e) return cls(server_url, dummy_service, **kwargs) + + def close(self): + self.channel.close() + return super().close() diff --git a/src/bentoml/_internal/client/http.py b/src/bentoml/_internal/client/http.py index 307095c77e9..042f9ce58f6 100644 --- a/src/bentoml/_internal/client/http.py +++ b/src/bentoml/_internal/client/http.py @@ -3,30 +3,190 @@ import asyncio import json import logging -import socket import time import typing as t -import urllib.error -import urllib.request -from http.client import HTTPConnection -from urllib.parse import urlparse +from functools import cached_property -import aiohttp +import httpx import starlette.datastructures import starlette.requests from ...exceptions import BentoMLException from ...exceptions import RemoteException from .. import io_descriptors as io -from ..configuration import get_debug_mode from ..service import Service from ..service.inference_api import InferenceAPI +from . import AsyncClient from . import Client +from . import SyncClient logger = logging.getLogger(__name__) class HTTPClient(Client): + def __init__(self, svc: Service, server_url: str): + self._sync_client = SyncHTTPClient(svc=svc, server_url=server_url) + self._async_client = AsyncHTTPClient(svc=svc, server_url=server_url) + super().__init__(svc, server_url) + + +class AsyncHTTPClient(AsyncClient): + @cached_property + def client(self) -> httpx.AsyncClient: + return httpx.AsyncClient(base_url=self.server_url, timeout=300) + + @staticmethod + async def wait_until_server_ready( + host: str, + port: int, + timeout: float = 30, + check_interval: int = 1, + # set kwargs here to omit gRPC kwargs + **kwargs: t.Any, + ) -> None: + start_time = time.time() + + logger.debug("Waiting for host %s to be ready.", f"{host}:{port}") + while time.time() - start_time < timeout: + try: + async with httpx.AsyncClient(base_url=f"http://{host}:{port}") as session: + resp = await session.get("/readyz") + if resp.status_code == 200: + break + else: + await asyncio.sleep(check_interval) + except ( + httpx.TimeoutException, + httpx.NetworkError, + httpx.HTTPStatusError, + ): + logger.debug("Server is not ready. Retrying...") + await asyncio.sleep(check_interval) + + # try to connect one more time and raise exception. + try: + async with httpx.AsyncClient(base_url=f"http://{host}:{port}") as session: + resp = await session.get("/readyz") + if resp.status_code != 200: + raise TimeoutError( + f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready." 
+ ) + except ( + httpx.TimeoutException, + httpx.NetworkError, + httpx.HTTPStatusError, + ) as err: + logger.error("Timed out while connecting to %s:%s:", host, port) + logger.error(err) + raise + + async def health(self) -> httpx.Response: + return await self.client.get("/readyz") + + @classmethod + async def from_url(cls, server_url: str, **kwargs: t.Any) -> AsyncHTTPClient: + server_url = server_url if "://" in server_url else "http://" + server_url + + async with httpx.AsyncClient(base_url=server_url) as session: + resp = await session.get("/docs.json") + if resp.status_code != 200: + raise RemoteException( + f"Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{await resp.aread()}" + ) + openapi_spec = json.loads(await resp.aread()) + + dummy_service = Service(openapi_spec["info"]["title"]) + + for route, spec in openapi_spec["paths"].items(): + for meth_spec in spec.values(): + if "tags" in meth_spec and "Service APIs" in meth_spec["tags"]: + if "x-bentoml-io-descriptor" not in meth_spec["requestBody"]: + # TODO: better message stating min version for from_url to work + raise BentoMLException( + f"Malformed BentoML spec received from BentoML server {server_url}" + ) + if "x-bentoml-io-descriptor" not in meth_spec["responses"]["200"]: + raise BentoMLException( + f"Malformed BentoML spec received from BentoML server {server_url}" + ) + if "x-bentoml-name" not in meth_spec: + raise BentoMLException( + f"Malformed BentoML spec received from BentoML server {server_url}" + ) + try: + api = InferenceAPI[t.Any]( + None, + io.from_spec( + meth_spec["requestBody"]["x-bentoml-io-descriptor"] + ), + io.from_spec( + meth_spec["responses"]["200"]["x-bentoml-io-descriptor"] + ), + name=meth_spec["x-bentoml-name"], + doc=meth_spec["description"], + route=route.lstrip("/"), + ) + dummy_service.apis[meth_spec["x-bentoml-name"]] = api + except BentoMLException as e: + logger.error( + "Failed to instantiate client for API %s: %s", + meth_spec["x-bentoml-name"], + e, + ) + + return cls(dummy_service, server_url) + + async def _call( + self, inp: t.Any = None, *, _bentoml_api: InferenceAPI[t.Any], **kwargs: t.Any + ) -> t.Any: + # All gRPC kwargs should be popped out. + kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_grpc_")} + api = _bentoml_api + + if api.multi_input: + if inp is not None: + raise BentoMLException( + f"'{api.name}' takes multiple inputs; all inputs must be passed as keyword arguments." + ) + fake_resp = await api.input.to_http_response(kwargs, None) + else: + fake_resp = await api.input.to_http_response(inp, None) + + # TODO: Temporary workaround before moving everything to StreamingResponse + if isinstance(fake_resp, starlette.responses.StreamingResponse): + req_body = "".join([s async for s in fake_resp.body_iterator]) + else: + req_body = fake_resp.body + + resp = await self.client.post( + api.route, + content=req_body, + headers={"content-type": fake_resp.headers["content-type"]}, + ) + if resp.status_code != 200: + raise BentoMLException( + f"Error making request: {resp.status_code}: {str(await resp.aread())}" + ) + + fake_req = starlette.requests.Request(scope={"type": "http"}) + headers = starlette.datastructures.Headers(headers=resp.headers) + fake_req._body = await resp.aread() + # Request.headers sets a _headers variable. We will need to set this + # value to our fake request object. 
+ fake_req._headers = headers # type: ignore (request._headers is property) + + return await api.output.from_http_request(fake_req) + + async def close(self): + await self.client.aclose() + return await super().close() + + +class SyncHTTPClient(SyncClient): + @cached_property + def client(self) -> httpx.Client: + return httpx.Client(base_url=self.server_url, timeout=300) + @staticmethod def wait_until_server_ready( host: str, @@ -37,79 +197,52 @@ def wait_until_server_ready( **kwargs: t.Any, ) -> None: start_time = time.time() - status = None logger.debug("Waiting for host %s to be ready.", f"{host}:{port}") while time.time() - start_time < timeout: try: - conn = HTTPConnection(host, port) - conn.request("GET", "/readyz") - status = conn.getresponse().status + status = httpx.get(f"http://{host}:{port}/readyz").status_code if status == 200: break else: time.sleep(check_interval) except ( - ConnectionError, - urllib.error.URLError, - socket.timeout, - ConnectionRefusedError, + httpx.TimeoutException, + httpx.NetworkError, + httpx.HTTPStatusError, ): logger.debug("Server is not ready. Retrying...") - time.sleep(check_interval) # try to connect one more time and raise exception. try: - conn = HTTPConnection(host, port) - conn.request("GET", "/readyz") - status = conn.getresponse().status + status = httpx.get(f"http://{host}:{port}/readyz").status_code if status != 200: raise TimeoutError( f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready." ) except ( - ConnectionError, - urllib.error.URLError, - socket.timeout, - ConnectionRefusedError, - TimeoutError, + httpx.TimeoutException, + httpx.NetworkError, + httpx.HTTPStatusError, ) as err: logger.error("Timed out while connecting to %s:%s:", host, port) logger.error(err) raise - async def async_health(self) -> t.Any: - async with aiohttp.ClientSession(self.server_url) as sess: - async with sess.get("/readyz") as resp: - return resp - - def health(self) -> t.Any: - return asyncio.run(self.async_health()) + def health(self) -> httpx.Response: + return self.client.get("/readyz") @classmethod - def from_url(cls, server_url: str, **kwargs: t.Any) -> HTTPClient: + def from_url(cls, server_url: str, **kwargs: t.Any) -> SyncHTTPClient: server_url = server_url if "://" in server_url else "http://" + server_url - url_parts = urlparse(server_url) - - # TODO: SSL support - conn = HTTPConnection(url_parts.netloc) - conn.set_debuglevel(logging.DEBUG if get_debug_mode() else 0) - - # we want to preserve as much of the user path as possible, so we don't really want to use - # a path join here. 
- if url_parts.path.endswith("/"): - path = url_parts.path + "docs.json" - else: - path = url_parts.path + "/docs.json" - conn.request("GET", path) - resp = conn.getresponse() - if resp.status != 200: - raise RemoteException( - f"Failed to get OpenAPI schema from the server: {resp.status} {resp.reason}:\n{resp.read()}" - ) - openapi_spec = json.load(resp) - conn.close() + with httpx.Client(base_url=server_url) as session: + resp = session.get("docs.json") + if resp.status_code != 200: + raise RemoteException( + f"Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{resp.content}" + ) + openapi_spec = json.loads(resp.content) dummy_service = Service(openapi_spec["info"]["title"]) @@ -152,7 +285,7 @@ def from_url(cls, server_url: str, **kwargs: t.Any) -> HTTPClient: return cls(dummy_service, server_url) - async def _call( + def _call( self, inp: t.Any = None, *, _bentoml_api: InferenceAPI[t.Any], **kwargs: t.Any ) -> t.Any: # All gRPC kwargs should be poped out. @@ -164,32 +297,40 @@ async def _call( raise BentoMLException( f"'{api.name}' takes multiple inputs; all inputs must be passed as keyword arguments." ) - fake_resp = await api.input.to_http_response(kwargs, None) + # TODO: remove asyncio run after descriptor rework + fake_resp = asyncio.run(api.input.to_http_response(kwargs, None)) else: - fake_resp = await api.input.to_http_response(inp, None) + fake_resp = asyncio.run(api.input.to_http_response(inp, None)) # TODO: Temporary workaround before moving everything to StreamingResponse if isinstance(fake_resp, starlette.responses.StreamingResponse): - req_body = "".join([s async for s in fake_resp.body_iterator]) + + async def get_body(): + return "".join([s async for s in fake_resp.body_iterator]) + + req_body = asyncio.run(get_body()) else: req_body = fake_resp.body - async with aiohttp.ClientSession(self.server_url) as sess: - async with sess.post( - "/" + api.route if not api.route.startswith("/") else api.route, - data=req_body, - headers={"content-type": fake_resp.headers["content-type"]}, - ) as resp: - if resp.status != 200: - raise BentoMLException( - f"Error making request: {resp.status}: {str(await resp.read())}" - ) + resp = self.client.post( + api.route, + content=req_body, + headers={"content-type": fake_resp.headers["content-type"]}, + ) + if resp.status_code != 200: + raise BentoMLException( + f"Error making request: {resp.status_code}: {str(resp.content)}" + ) - fake_req = starlette.requests.Request(scope={"type": "http"}) - headers = starlette.datastructures.Headers(headers=resp.headers) - fake_req._body = await resp.read() - # Request.headers sets a _headers variable. We will need to set this - # value to our fake request object. - fake_req._headers = headers # type: ignore (request._headers is property) + fake_req = starlette.requests.Request(scope={"type": "http"}) + headers = starlette.datastructures.Headers(headers=resp.headers) + fake_req._body = resp.content + # Request.headers sets a _headers variable. We will need to set this + # value to our fake request object. 
+ fake_req._headers = headers # type: ignore (request._headers is property) - return await api.output.from_http_request(fake_req) + return asyncio.run(api.output.from_http_request(fake_req)) + + def close(self): + self.client.close() + return super().close() diff --git a/src/bentoml/client.py b/src/bentoml/client.py index 29185094f82..1ce97dafda7 100644 --- a/src/bentoml/client.py +++ b/src/bentoml/client.py @@ -14,8 +14,24 @@ """ from __future__ import annotations +from ._internal.client import AsyncClient from ._internal.client import Client +from ._internal.client import SyncClient +from ._internal.client.grpc import AsyncGrpcClient from ._internal.client.grpc import GrpcClient +from ._internal.client.grpc import SyncGrpcClient +from ._internal.client.http import AsyncHTTPClient from ._internal.client.http import HTTPClient +from ._internal.client.http import SyncHTTPClient -__all__ = ["Client", "HTTPClient", "GrpcClient"] +__all__ = [ + "AsyncClient", + "SyncClient", + "Client", + "AsyncHTTPClient", + "SyncHTTPClient", + "HTTPClient", + "AsyncGrpcClient", + "SyncGrpcClient", + "GrpcClient", +]
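A minimal usage sketch of the split clients introduced above, for reference; the server address localhost:3000, the classify API name, and the sample payload are assumptions for illustration, not part of the change.

import asyncio

from bentoml.client import AsyncHTTPClient, SyncHTTPClient

# Blocking client: from_url() rebuilds a dummy Service from the server's
# OpenAPI spec and exposes each API via call() and a generated method.
client = SyncHTTPClient.from_url("http://localhost:3000")
print(client.health().status_code)
print(client.call("classify", [[5.9, 3.0, 5.1, 1.8]]))
client.close()

# Async client: from_url() is a coroutine and every request is awaited.
async def main() -> None:
    aclient = await AsyncHTTPClient.from_url("http://localhost:3000")
    print(await aclient.call("classify", [[5.9, 3.0, 5.1, 1.8]]))
    await aclient.close()

asyncio.run(main())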