Merge branch 'develop' into mem_profile
Showing 13 changed files with 468 additions and 211 deletions.
@@ -53,3 +53,4 @@ tabulate
```
# Model inference launcher from the dedicated inference server
ovmsclient
tritonclient[all]
```
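A quick sanity check that the new optional dependencies resolve (a minimal sketch; nothing beyond the package names above and the modules imported later in this commit is assumed):

```python
# Minimal import check for the two new optional dependencies.
import ovmsclient          # OVMS client bindings
import tritonclient.grpc   # Triton gRPC client (installed via tritonclient[all])
import tritonclient.http   # Triton HTTP client
```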
@@ -0,0 +1,8 @@ (new file; path inferred from the imports: datumaro/plugins/inference_server_plugin/__init__.py)
```python
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from datumaro.plugins.inference_server_plugin.ovms import OVMSLauncher
from datumaro.plugins.inference_server_plugin.triton import TritonLauncher

__all__ = ["OVMSLauncher", "TritonLauncher"]
```
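Since both launchers are re-exported in `__all__`, downstream code can import them at the package level without referencing the submodules:

```python
# Package-level import, as enabled by the __init__.py above.
from datumaro.plugins.inference_server_plugin import OVMSLauncher, TritonLauncher
```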
@@ -0,0 +1,114 @@ (new file; path inferred from the imports: datumaro/plugins/inference_server_plugin/base.py)
```python
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, Generic, Optional, TypeVar

from grpc import ChannelCredentials, ssl_channel_credentials
from ovmsclient.tfs_compat.base.serving_client import ServingClient

from datumaro.components.errors import DatumaroError, MediaTypeError
from datumaro.components.launcher import LauncherWithModelInterpreter
from datumaro.components.media import Image


class ProtocolType(IntEnum):
    """Protocol type for communication with the dedicated inference server"""

    grpc = 0
    http = 1


@dataclass(frozen=True)
class TLSConfig:
    """TLS configuration dataclass

    Parameters:
        client_key_path: Path to the client key file
        client_cert_path: Path to the client certificate file
        server_cert_path: Path to the server certificate file
    """

    client_key_path: str
    client_cert_path: str
    server_cert_path: str

    def as_dict(self) -> Dict[str, str]:
        return {
            "client_key_path": self.client_key_path,
            "client_cert_path": self.client_cert_path,
            "server_cert_path": self.server_cert_path,
        }

    def as_grpc_creds(self) -> ChannelCredentials:
        server_cert, client_cert, client_key = ServingClient._prepare_certs(
            self.server_cert_path, self.client_cert_path, self.client_key_path
        )
        return ssl_channel_credentials(
            root_certificates=server_cert, private_key=client_key, certificate_chain=client_cert
        )
```
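For illustration, a minimal sketch of building a `TLSConfig`; the certificate paths are hypothetical:

```python
# Hypothetical certificate paths; substitute the deployment's real files.
tls_config = TLSConfig(
    client_key_path="certs/client.key",
    client_cert_path="certs/client.crt",
    server_cert_path="certs/server.crt",
)

# as_dict() is consumed by the OVMS client; as_grpc_creds() reads the
# certificate files and builds gRPC channel credentials for Triton.
print(tls_config.as_dict())
```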
The remainder of base.py defines the common launcher base class:

```python
TClient = TypeVar("TClient")


class LauncherForDedicatedInferenceServer(Generic[TClient], LauncherWithModelInterpreter):
    """Inference launcher for a dedicated inference server

    Parameters:
        model_name: Name of the model. It should match the model name loaded in the server instance.
        model_interpreter_path: Path to the Python source code that implements a model interpreter.
            The model interpreter implements the pre-processing of the model input and the
            post-processing of the model output.
        model_version: Version of the model loaded in the server instance
        host: Host address of the server instance
        port: Port number of the server instance
        timeout: Timeout limit during communication between the client and the server instance
        tls_config: Configuration required if the server instance is in secure mode
        protocol_type: Communication protocol type with the server instance
    """

    def __init__(
        self,
        model_name: str,
        model_interpreter_path: str,
        model_version: int = 0,
        host: str = "localhost",
        port: int = 9000,
        timeout: float = 10.0,
        tls_config: Optional[TLSConfig] = None,
        protocol_type: ProtocolType = ProtocolType.grpc,
    ):
        super().__init__(model_interpreter_path=model_interpreter_path)

        self.model_name = model_name
        self.model_version = model_version
        self.url = f"{host}:{port}"
        self.timeout = timeout
        self.tls_config = tls_config
        self.protocol_type = protocol_type

        try:
            self._client = self._init_client()
            self._check_server_health()
            self._init_metadata()
        except Exception as e:
            raise DatumaroError(
                f"Health check failed for model_name={self.model_name}, "
                f"model_version={self.model_version}, url={self.url} and tls_config={self.tls_config}"
            ) from e

    def _init_client(self) -> TClient:
        raise NotImplementedError()

    def _check_server_health(self) -> None:
        raise NotImplementedError()

    def _init_metadata(self) -> None:
        raise NotImplementedError()

    def type_check(self, item):
        if not isinstance(item.media, Image):
            raise MediaTypeError(f"Media type should be Image. Current type: {type(item.media)}")
        return True
```
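The constructor drives the three hooks in order (client init, health check, metadata), so a concrete launcher only fills those in. A minimal illustrative sketch of that subclass contract; the `_EchoClient` stand-in is hypothetical and not part of the commit:

```python
class _EchoClient:
    """Hypothetical stand-in for a protocol-specific client."""

    def __init__(self, url: str):
        self.url = url


class _SketchLauncher(LauncherForDedicatedInferenceServer[_EchoClient]):
    # Illustrative only: the three hooks the base constructor calls in order.
    def _init_client(self) -> _EchoClient:
        # Build and return the protocol-specific client for self.url.
        return _EchoClient(self.url)

    def _check_server_health(self) -> None:
        # Raise here if the server reports the model as not ready.
        pass

    def _init_metadata(self) -> None:
        # Cache the model's input/output signature for later use in infer().
        self._metadata = {}
```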
@@ -0,0 +1,126 @@ (new file; path inferred from the imports: datumaro/plugins/inference_server_plugin/ovms.py)
```python
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import logging as log
from typing import List, Union

import numpy as np
from ovmsclient import make_grpc_client, make_http_client
from ovmsclient.tfs_compat.grpc.serving_client import GrpcClient
from ovmsclient.tfs_compat.http.serving_client import HttpClient

from datumaro.components.abstracts.model_interpreter import ModelPred
from datumaro.components.errors import DatumaroError
from datumaro.plugins.inference_server_plugin.base import (
    LauncherForDedicatedInferenceServer,
    ProtocolType,
)

__all__ = ["OVMSLauncher"]

TClient = Union[GrpcClient, HttpClient]


class OVMSLauncher(LauncherForDedicatedInferenceServer[TClient]):
    """Inference launcher for OVMS (OpenVINO™ Model Server)
    (https://github.com/openvinotoolkit/model_server)

    Parameters:
        model_name: Name of the model. It should match the model name loaded in the server instance.
        model_interpreter_path: Path to the Python source code that implements a model interpreter.
            The model interpreter implements the pre-processing of the model input and the
            post-processing of the model output.
        model_version: Version of the model loaded in the server instance
        host: Host address of the server instance
        port: Port number of the server instance
        timeout: Timeout limit during communication between the client and the server instance
        tls_config: Configuration required if the server instance is in secure mode
        protocol_type: Communication protocol type with the server instance
    """

    def _init_client(self) -> TClient:
        tls_config = self.tls_config.as_dict() if self.tls_config is not None else None

        if self.protocol_type == ProtocolType.grpc:
            return make_grpc_client(self.url, tls_config)
        if self.protocol_type == ProtocolType.http:
            return make_http_client(self.url, tls_config)

        raise NotImplementedError(self.protocol_type)

    def _check_server_health(self) -> None:
        status = self._client.get_model_status(
            model_name=self.model_name,
            model_version=self.model_version,
            timeout=self.timeout,
        )
        log.info(f"Health check succeeded: {status}")

    def _init_metadata(self):
        self._metadata = self._client.get_model_metadata(
            model_name=self.model_name,
            model_version=self.model_version,
            timeout=self.timeout,
        )
        log.info(f"Received metadata: {self._metadata}")

    def infer(self, inputs: np.ndarray) -> List[ModelPred]:
        # Please see the following link for the input and output types of self._client.predict():
        # https://github.com/openvinotoolkit/model_server/blob/releases/2022/3/client/python/ovmsclient/lib/docs/grpc_client.md#method-predict
        # The input is Dict[str, np.ndarray].
        # The output is Dict[str, np.ndarray] (if the model has multiple outputs),
        # or np.ndarray (if the model has one single output).
        results = self._client.predict(
            inputs={self._input_key: inputs},
            model_name=self.model_name,
            model_version=self.model_version,
            timeout=self.timeout,
        )

        # If there is only one output key,
        # it returns `np.ndarray` rather than `Dict[str, np.ndarray]`.
        # Please see ovmsclient.tfs_compat.grpc.responses.GrpcPredictResponse
        if isinstance(results, np.ndarray):
            results = {self._output_key: results}

        # Regroup the batched, per-key outputs into one Dict[str, np.ndarray] per item.
        outputs_group_by_item = [
            {key: output for key, output in zip(results.keys(), outputs)}
            for outputs in zip(*results.values())
        ]

        return outputs_group_by_item

    @property
    def _input_key(self):
        # The cache attribute uses a single underscore on purpose: a
        # double-underscore attribute would be name-mangled by the class,
        # so hasattr(self, "__input_key") could never find it.
        if hasattr(self, "_cached_input_key"):
            return self._cached_input_key

        metadata_inputs = self._metadata.get("inputs")

        if metadata_inputs is None:
            raise DatumaroError("Cannot get metadata of the inputs.")

        if len(metadata_inputs.keys()) > 1:
            raise DatumaroError(
                f"More than one model input is not allowed: {metadata_inputs.keys()}."
            )

        self._cached_input_key = next(iter(metadata_inputs.keys()))
        return self._cached_input_key

    @property
    def _output_key(self):
        if hasattr(self, "_cached_output_key"):
            return self._cached_output_key

        metadata_outputs = self._metadata.get("outputs")

        if metadata_outputs is None:
            raise DatumaroError("Cannot get metadata of the outputs.")

        if len(metadata_outputs.keys()) > 1:
            raise DatumaroError(
                f"More than one model output is not allowed: {metadata_outputs.keys()}."
            )

        self._cached_output_key = next(iter(metadata_outputs.keys()))
        return self._cached_output_key
```
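A hedged usage sketch for the launcher above; the model name, version, and interpreter path are hypothetical and must match what the server actually serves:

```python
# Hypothetical values: assumes an OVMS instance at localhost:9000 serving
# a model registered as "resnet50", plus a user-written interpreter script.
launcher = OVMSLauncher(
    model_name="resnet50",
    model_interpreter_path="my_interpreter.py",
    model_version=1,
    host="localhost",
    port=9000,
    protocol_type=ProtocolType.grpc,
)
```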
File renamed without changes.
File renamed without changes.
@@ -0,0 +1,110 @@ (new file; path inferred from the imports: datumaro/plugins/inference_server_plugin/triton.py)
```python
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import logging as log
from typing import Dict, List, Type, Union

import numpy as np
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient

from datumaro.components.abstracts.model_interpreter import ModelPred
from datumaro.components.errors import DatumaroError
from datumaro.plugins.inference_server_plugin.base import (
    LauncherForDedicatedInferenceServer,
    ProtocolType,
)

__all__ = ["TritonLauncher"]

TClient = Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient]
TInferInput = Union[grpcclient.InferInput, httpclient.InferInput]
TInferOutput = Union[grpcclient.InferResult, httpclient.InferResult]


class TritonLauncher(LauncherForDedicatedInferenceServer[TClient]):
    """Inference launcher for Triton Inference Server
    (https://github.com/triton-inference-server)

    Parameters:
        model_name: Name of the model. It should match the model name loaded in the server instance.
        model_interpreter_path: Path to the Python source code that implements a model interpreter.
            The model interpreter implements the pre-processing of the model input and the
            post-processing of the model output.
        model_version: Version of the model loaded in the server instance
        host: Host address of the server instance
        port: Port number of the server instance
        timeout: Timeout limit during communication between the client and the server instance
        tls_config: Configuration required if the server instance is in secure mode
        protocol_type: Communication protocol type with the server instance
    """

    def _init_client(self) -> TClient:
        # Note: only the gRPC client receives TLS credentials here.
        creds = self.tls_config.as_grpc_creds() if self.tls_config is not None else None

        if self.protocol_type == ProtocolType.grpc:
            return grpcclient.InferenceServerClient(url=self.url, creds=creds)
        if self.protocol_type == ProtocolType.http:
            return httpclient.InferenceServerClient(url=self.url)

        raise NotImplementedError(self.protocol_type)

    def _check_server_health(self) -> None:
        status = self._client.is_model_ready(
            model_name=self.model_name,
            model_version=str(self.model_version),
        )
        if not status:
            raise DatumaroError("Model is not ready.")
        log.info(f"Health check succeeded: {status}")

    def _init_metadata(self) -> None:
        self._metadata = self._client.get_model_metadata(
            model_name=self.model_name,
            model_version=str(self.model_version),
        )
        log.info(f"Received metadata: {self._metadata}")

    def _get_infer_input(self, inputs: np.ndarray) -> List[TInferInput]:
        def _fix_dynamic_batch_dim(shape):
            # A leading -1 marks a dynamic batch axis; pin it to the actual batch size.
            if shape[0] == -1:
                shape[0] = inputs.shape[0]
            return shape

        def _create(infer_input_cls: Type[TInferInput]) -> List[TInferInput]:
            infer_inputs = [
                infer_input_cls(
                    name=inp.name,
                    shape=_fix_dynamic_batch_dim(inp.shape),
                    datatype=inp.datatype,
                )
                for inp in self._metadata.inputs
            ]
            for infer_input in infer_inputs:
                infer_input.set_data_from_numpy(inputs)
            return infer_inputs

        if self.protocol_type == ProtocolType.grpc:
            return _create(grpcclient.InferInput)
        if self.protocol_type == ProtocolType.http:
            return _create(httpclient.InferInput)

        raise NotImplementedError(self.protocol_type)

    def infer(self, inputs: np.ndarray) -> List[ModelPred]:
        infer_outputs: TInferOutput = self._client.infer(
            inputs=self._get_infer_input(inputs),
            model_name=self.model_name,
            model_version=str(self.model_version),
        )

        results: Dict[str, np.ndarray] = {
            output.name: infer_outputs.as_numpy(name=output.name)
            for output in self._metadata.outputs
        }

        # Regroup the batched, per-key outputs into one Dict[str, np.ndarray] per item.
        outputs_group_by_item = [
            {key: output for key, output in zip(results.keys(), outputs)}
            for outputs in zip(*results.values())
        ]

        return outputs_group_by_item
```
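Both launchers end with the same regrouping step: a dict of batched arrays is transposed into a list of per-item dicts. A small self-contained demonstration of that transpose, with made-up output names:

```python
import numpy as np

# Made-up model outputs: two heads, batch size 3.
results = {
    "boxes": np.zeros((3, 4)),  # (batch, 4)
    "scores": np.zeros((3,)),   # (batch,)
}

# Same comprehension as in infer(): "dict of batches" -> "list of per-item dicts".
outputs_group_by_item = [
    {key: output for key, output in zip(results.keys(), outputs)}
    for outputs in zip(*results.values())
]

assert len(outputs_group_by_item) == 3
assert set(outputs_group_by_item[0]) == {"boxes", "scores"}
```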