Skip to content

Commit

Permalink
feat: maintain backwards compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
judahrand committed Aug 15, 2023
1 parent 52ff0d2 commit 34940ff
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 24 deletions.
182 changes: 164 additions & 18 deletions src/bentoml/_internal/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import typing as t
from abc import ABC
from abc import abstractproperty
from http.client import BadStatusLine

from ...exceptions import BentoMLException
Expand All @@ -17,8 +18,150 @@
from ..service import Service
from .grpc import AsyncGrpcClient
from .grpc import GrpcClient
from .grpc import SyncGrpcClient
from .http import AsyncHTTPClient
from .http import HTTPClient
from .http import SyncHTTPClient


class Client(ABC):
server_url: str
_svc: Service
endpoints: list[str]

def __init__(self, svc: Service, server_url: str):
self._svc = svc
self.server_url = server_url

if len(svc.apis) == 0:
raise BentoMLException("No APIs were found when constructing client.")

self.endpoints = []
for name in self._svc.apis.keys():
self.endpoints.append(name)

if not hasattr(self, name):
setattr(self, name, functools.partial(self.call, bentoml_api_name=name))

if not hasattr(self, f"async_{name}"):
setattr(
self,
f"async_{name}",
functools.partial(self.async_call, bentoml_api_name=name),
)

@abstractproperty
def sync_client(self) -> SyncClient:
raise NotImplementedError()

@abstractproperty
def async_client(self) -> AsyncClient:
raise NotImplementedError()

def call(
self,
bentoml_api_name: str,
inp: t.Any = None,
**kwargs: t.Any,
) -> t.Any:
return self.sync_client.call(
inp=inp, bentoml_api_name=bentoml_api_name, **kwargs
)

async def async_call(
self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any
) -> t.Any:
return await self.async_client.call(
inp=inp, bentoml_api_name=bentoml_api_name, **kwargs
)

@staticmethod
def wait_until_server_ready(
host: str, port: int, timeout: float = 30, **kwargs: t.Any
) -> None:
try:
from .http import SyncHTTPClient

SyncHTTPClient.wait_until_server_ready(host, port, timeout, **kwargs)
except BadStatusLine:
# when address is a RPC
from .grpc import SyncGrpcClient

SyncGrpcClient.wait_until_server_ready(host, port, timeout, **kwargs)
except Exception as err:
# caught all other exceptions
logger.error("Failed to connect to server %s:%s", host, port)
logger.error(err)
raise

@t.overload
@staticmethod
def from_url(
server_url: str, *, kind: None | t.Literal["auto"] = ...
) -> GrpcClient | HTTPClient:
...

@t.overload
@staticmethod
def from_url(server_url: str, *, kind: t.Literal["http"] = ...) -> HTTPClient:
...

@t.overload
@staticmethod
def from_url(server_url: str, *, kind: t.Literal["grpc"] = ...) -> GrpcClient:
...

@staticmethod
def from_url(
server_url: str, *, kind: str | None = None, **kwargs: t.Any
) -> Client:
if kind is None or kind == "auto":
try:
from .http import HTTPClient

return HTTPClient.from_url(server_url, **kwargs)
except BadStatusLine:
from .grpc import GrpcClient

return GrpcClient.from_url(server_url, **kwargs)
except Exception as e: # pylint: disable=broad-except
raise BentoMLException(
f"Failed to create a BentoML client from given URL '{server_url}': {e} ({e.__class__.__name__})"
) from e
elif kind == "http":
from .http import HTTPClient

return HTTPClient.from_url(server_url, **kwargs)
elif kind == "grpc":
from .grpc import GrpcClient

return GrpcClient.from_url(server_url, **kwargs)
else:
raise BentoMLException(
f"Invalid client kind '{kind}'. Must be one of 'http', 'grpc', or 'auto'."
)

def __enter__(self):
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool | None:
return self.sync_client.close()

async def __aenter__(self):
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool | None:
return await self.async_client.close()


class AsyncClient(ABC):
Expand Down Expand Up @@ -142,7 +285,7 @@ async def from_url(
)


class Client(ABC):
class SyncClient(ABC):
server_url: str
_svc: Service
endpoints: list[str]
Expand All @@ -154,7 +297,6 @@ def __init__(self, svc: Service, server_url: str):
if len(svc.apis) == 0:
raise BentoMLException("No APIs were found when constructing client.")

self.endpoints = []
self.endpoints = []
for name, api in self._svc.apis.items():
self.endpoints.append(name)
Expand Down Expand Up @@ -193,14 +335,14 @@ def wait_until_server_ready(
host: str, port: int, timeout: float = 30, **kwargs: t.Any
) -> None:
try:
from .http import HTTPClient
from .http import SyncHTTPClient

HTTPClient.wait_until_server_ready(host, port, timeout, **kwargs)
SyncHTTPClient.wait_until_server_ready(host, port, timeout, **kwargs)
except BadStatusLine:
# when address is a RPC
from .grpc import GrpcClient
from .grpc import SyncGrpcClient

GrpcClient.wait_until_server_ready(host, port, timeout, **kwargs)
SyncGrpcClient.wait_until_server_ready(host, port, timeout, **kwargs)
except Exception as err:
# caught all other exceptions
logger.error("Failed to connect to server %s:%s", host, port)
Expand All @@ -211,44 +353,48 @@ def wait_until_server_ready(
@classmethod
def from_url(
cls, server_url: str, *, kind: None | t.Literal["auto"] = ...
) -> GrpcClient | HTTPClient:
) -> SyncGrpcClient | SyncHTTPClient:
...

@t.overload
@classmethod
def from_url(cls, server_url: str, *, kind: t.Literal["http"] = ...) -> HTTPClient:
def from_url(
cls, server_url: str, *, kind: t.Literal["http"] = ...
) -> SyncHTTPClient:
...

@t.overload
@classmethod
def from_url(cls, server_url: str, *, kind: t.Literal["grpc"] = ...) -> GrpcClient:
def from_url(
cls, server_url: str, *, kind: t.Literal["grpc"] = ...
) -> SyncGrpcClient:
...

@classmethod
def from_url(
cls, server_url: str, *, kind: str | None = None, **kwargs: t.Any
) -> HTTPClient | GrpcClient:
) -> SyncHTTPClient | SyncGrpcClient:
if kind is None or kind == "auto":
try:
from .http import HTTPClient
from .http import SyncHTTPClient

return HTTPClient.from_url(server_url, **kwargs)
return SyncHTTPClient.from_url(server_url, **kwargs)
except BadStatusLine:
from .grpc import GrpcClient
from .grpc import SyncGrpcClient

return GrpcClient.from_url(server_url, **kwargs)
return SyncGrpcClient.from_url(server_url, **kwargs)
except Exception as e: # pylint: disable=broad-except
raise BentoMLException(
f"Failed to create a BentoML client from given URL '{server_url}': {e} ({e.__class__.__name__})"
) from e
elif kind == "http":
from .http import HTTPClient
from .http import SyncHTTPClient

return HTTPClient.from_url(server_url, **kwargs)
return SyncHTTPClient.from_url(server_url, **kwargs)
elif kind == "grpc":
from .grpc import GrpcClient
from .grpc import SyncGrpcClient

return GrpcClient.from_url(server_url, **kwargs)
return SyncGrpcClient.from_url(server_url, **kwargs)
else:
raise BentoMLException(
f"Invalid client kind '{kind}'. Must be one of 'http', 'grpc', or 'auto'."
Expand Down
20 changes: 17 additions & 3 deletions src/bentoml/_internal/client/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

from packaging.version import parse

from bentoml._internal.service.inference_api import InferenceAPI

from ...exceptions import BentoMLException
from ...grpc.utils import LATEST_PROTOCOL_VERSION
from ...grpc.utils import import_generated_stubs
Expand All @@ -22,6 +20,7 @@
from ..utils import LazyLoader
from . import AsyncClient
from . import Client
from . import SyncClient

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,6 +55,21 @@ class ClientCredentials(t.TypedDict):
)


class GrpcClient(Client):
def __init__(self, svc: Service, server_url: str):
self._sync_client = SyncGrpcClient(svc=svc, server_url=server_url)
self._async_client = AsyncGrpcClient(svc=svc, server_url=server_url)
super().__init__(svc, server_url)

@property
def sync_client(self) -> SyncGrpcClient:
return self._sync_client

@property
def async_client(self) -> AsyncGrpcClient:
return self._async_client


# TODO: xDS support
class AsyncGrpcClient(AsyncClient):
def __init__(
Expand Down Expand Up @@ -418,7 +432,7 @@ async def close(self):


# TODO: xDS support
class GrpcClient(Client):
class SyncGrpcClient(SyncClient):
def __init__(
self,
server_url: str,
Expand Down
20 changes: 17 additions & 3 deletions src/bentoml/_internal/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import starlette.datastructures
import starlette.requests

from bentoml._internal.service import Service

from ...exceptions import BentoMLException
from ...exceptions import RemoteException
from .. import io_descriptors as io
Expand All @@ -27,10 +25,26 @@
from ..service.inference_api import InferenceAPI
from . import AsyncClient
from . import Client
from . import SyncClient

logger = logging.getLogger(__name__)


class HTTPClient(Client):
def __init__(self, svc: Service, server_url: str):
self._sync_client = SyncHTTPClient(svc=svc, server_url=server_url)
self._async_client = AsyncHTTPClient(svc=svc, server_url=server_url)
super().__init__(svc, server_url)

@property
def sync_client(self) -> SyncHTTPClient:
return self._sync_client

@property
def async_client(self) -> AsyncHTTPClient:
return self._async_client


class AsyncHTTPClient(AsyncClient):
@cached_property
def client(self) -> aiohttp.ClientSession:
Expand Down Expand Up @@ -178,7 +192,7 @@ async def close(self):
return await super().close()


class HTTPClient(Client):
class SyncHTTPClient(SyncClient):
@cached_property
def client(self) -> HTTPConnection:
server_url = urlparse(self.server_url)
Expand Down
4 changes: 4 additions & 0 deletions src/bentoml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@
from ._internal.client import Client
from ._internal.client.grpc import AsyncGrpcClient
from ._internal.client.grpc import GrpcClient
from ._internal.client.grpc import SyncGrpcClient
from ._internal.client.http import AsyncHTTPClient
from ._internal.client.http import HTTPClient
from ._internal.client.http import SyncHTTPClient

__all__ = [
"AsyncClient",
"Client",
"AsyncHTTPClient",
"SyncHTTPClient",
"HTTPClient",
"AsyncGrpcClient",
"SyncGrpcClient",
"GrpcClient",
]

0 comments on commit 34940ff

Please sign in to comment.