Merge branch 'develop' into mem_profile
wonjuleee authored Jul 3, 2023
2 parents 04dee1c + 0758a58 commit f93a7d4
Showing 13 changed files with 468 additions and 211 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1049>, <https://github.com/openvinotoolkit/datumaro/pull/1063>)
- Add OVMSLauncher
(<https://github.com/openvinotoolkit/datumaro/pull/1056>)
- Add TritonLauncher
(<https://github.com/openvinotoolkit/datumaro/pull/1059>)

### Enhancements
- Enhance import performance for built-in plugins
1 change: 1 addition & 0 deletions requirements-core.txt
@@ -53,3 +53,4 @@ tabulate

# Model inference launcher from the dedicated inference server
ovmsclient
tritonclient[all]
8 changes: 8 additions & 0 deletions src/datumaro/plugins/inference_server_plugin/__init__.py
@@ -0,0 +1,8 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from datumaro.plugins.inference_server_plugin.ovms import OVMSLauncher
from datumaro.plugins.inference_server_plugin.triton import TritonLauncher

__all__ = ["OVMSLauncher", "TritonLauncher"]
114 changes: 114 additions & 0 deletions src/datumaro/plugins/inference_server_plugin/base.py
@@ -0,0 +1,114 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, Generic, Optional, TypeVar

from grpc import ChannelCredentials, ssl_channel_credentials
from ovmsclient.tfs_compat.base.serving_client import ServingClient

from datumaro.components.errors import DatumaroError, MediaTypeError
from datumaro.components.launcher import LauncherWithModelInterpreter
from datumaro.components.media import Image


class ProtocolType(IntEnum):
"""Protocol type for communication with dedicated inference server"""

grpc = 0
http = 1


@dataclass(frozen=True)
class TLSConfig:
"""TLS configuration dataclass
Parameters:
client_key_path: Path to client key file
client_cert_path: Path to client certificate file
server_cert_path: Path to server certificate file
"""

client_key_path: str
client_cert_path: str
server_cert_path: str

def as_dict(self) -> Dict[str, str]:
return {
"client_key_path": self.client_key_path,
"client_cert_path": self.client_cert_path,
"server_cert_path": self.server_cert_path,
}

def as_grpc_creds(self) -> ChannelCredentials:
server_cert, client_cert, client_key = ServingClient._prepare_certs(
self.server_cert_path, self.client_cert_path, self.client_key_path
)
return ssl_channel_credentials(
root_certificates=server_cert, private_key=client_key, certificate_chain=client_cert
)


TClient = TypeVar("TClient")


class LauncherForDedicatedInferenceServer(Generic[TClient], LauncherWithModelInterpreter):
"""Inference launcher for dedicated inference server
Parameters:
model_name: Name of the model. It should match with the model name loaded in the server instance.
model_interpreter_path: Python source code path which implements a model interpreter.
The model interpreter implement pre-processing of the model input and post-processing of the model output.
model_version: Version of the model loaded in the server instance
host: Host address of the server instance
port: Port number of the server instance
timeout: Timeout limit during communication between the client and the server instance
tls_config: Configuration required if the server instance is in the secure mode
protocol_type: Communication protocol type with the server instance
"""

def __init__(
self,
model_name: str,
model_interpreter_path: str,
model_version: int = 0,
host: str = "localhost",
port: int = 9000,
timeout: float = 10.0,
tls_config: Optional[TLSConfig] = None,
protocol_type: ProtocolType = ProtocolType.grpc,
):
super().__init__(model_interpreter_path=model_interpreter_path)

self.model_name = model_name
self.model_version = model_version
self.url = f"{host}:{port}"
self.timeout = timeout
self.tls_config = tls_config
self.protocol_type = protocol_type

try:
self._client = self._init_client()
self._check_server_health()
self._init_metadata()
except Exception as e:
raise DatumaroError(
f"Health check failed for model_name={self.model_name}, "
f"model_version={self.model_version}, url={self.url} and tls_config={self.tls_config}"
) from e

def _init_client(self) -> TClient:
raise NotImplementedError()

def _check_server_health(self) -> None:
raise NotImplementedError()

def _init_metadata(self) -> None:
raise NotImplementedError()

def type_check(self, item):
if not isinstance(item.media, Image):
raise MediaTypeError(f"Media type should be Image, but current type={type(item.media)}")
return True
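For context, a brief sketch of how the configuration objects defined in this base module are meant to be used; the certificate paths below are placeholders, not values from the diff:

```python
from datumaro.plugins.inference_server_plugin.base import ProtocolType, TLSConfig

# Placeholder certificate paths; real ones come from the server deployment.
tls_config = TLSConfig(
    client_key_path="certs/client.key",
    client_cert_path="certs/client.crt",
    server_cert_path="certs/server.crt",
)

tls_dict = tls_config.as_dict()          # plain dict form, consumed by the OVMS client
grpc_creds = tls_config.as_grpc_creds()  # grpc.ChannelCredentials for the Triton gRPC client

protocol = ProtocolType.grpc  # or ProtocolType.http
```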
126 changes: 126 additions & 0 deletions src/datumaro/plugins/inference_server_plugin/ovms.py
@@ -0,0 +1,126 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import logging as log
from typing import List, Union

import numpy as np
from ovmsclient import make_grpc_client, make_http_client
from ovmsclient.tfs_compat.grpc.serving_client import GrpcClient
from ovmsclient.tfs_compat.http.serving_client import HttpClient

from datumaro.components.abstracts.model_interpreter import ModelPred
from datumaro.components.errors import DatumaroError
from datumaro.plugins.inference_server_plugin.base import (
LauncherForDedicatedInferenceServer,
ProtocolType,
)

__all__ = ["OVMSLauncher"]

TClient = Union[GrpcClient, HttpClient]


class OVMSLauncher(LauncherForDedicatedInferenceServer[TClient]):
"""Inference launcher for OVMS (OpenVINO™ Model Server) (https://github.com/openvinotoolkit/model_server)
Parameters:
model_name: Name of the model. It should match the model name loaded in the server instance.
model_interpreter_path: Python source code path which implements a model interpreter.
The model interpreter implements pre-processing of the model input and post-processing of the model output.
model_version: Version of the model loaded in the server instance
host: Host address of the server instance
port: Port number of the server instance
timeout: Timeout limit during communication between the client and the server instance
tls_config: Configuration required if the server instance is in the secure mode
protocol_type: Communication protocol type with the server instance
"""

def _init_client(self) -> TClient:
tls_config = self.tls_config.as_dict() if self.tls_config is not None else None

if self.protocol_type == ProtocolType.grpc:
return make_grpc_client(self.url, tls_config)
if self.protocol_type == ProtocolType.http:
return make_http_client(self.url, tls_config)

raise NotImplementedError(self.protocol_type)

def _check_server_health(self) -> None:
status = self._client.get_model_status(
model_name=self.model_name,
model_version=self.model_version,
timeout=self.timeout,
)
log.info(f"Health check succeeded: {status}")

def _init_metadata(self):
self._metadata = self._client.get_model_metadata(
model_name=self.model_name,
model_version=self.model_version,
timeout=self.timeout,
)
log.info(f"Received metadata: {self._metadata}")

def infer(self, inputs: np.ndarray) -> List[ModelPred]:
# Please see the following link for the input and output type of self._client.predict()
# https://github.com/openvinotoolkit/model_server/blob/releases/2022/3/client/python/ovmsclient/lib/docs/grpc_client.md#method-predict
# The input is Dict[str, np.ndarray].
# The output is Dict[str, np.ndarray] (If the model has multiple outputs),
# or np.ndarray (If the model has one single output).
results = self._client.predict(
inputs={self._input_key: inputs},
model_name=self.model_name,
model_version=self.model_version,
timeout=self.timeout,
)

# If there is only one output key,
# it returns `np.ndarray` rather than `Dict[str, np.ndarray]`.
# Please see ovmsclient.tfs_compat.grpc.responses.GrpcPredictResponse
if isinstance(results, np.ndarray):
results = {self._output_key: results}

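# Regroup {output_key: (N, ...) batch array} into a list of N per-item {output_key: array} dicts.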
outputs_group_by_item = [
{key: output for key, output in zip(results.keys(), outputs)}
for outputs in zip(*results.values())
]

return outputs_group_by_item

@property
def _input_key(self):
# Cache the resolved input key after the first metadata lookup.
if hasattr(self, "_cached_input_key"):
return self._cached_input_key

metadata_inputs = self._metadata.get("inputs")

if metadata_inputs is None:
raise DatumaroError("Cannot get metadata of the model inputs.")

if len(metadata_inputs.keys()) > 1:
raise DatumaroError(
f"Multiple model inputs are not supported: {metadata_inputs.keys()}."
)

self._cached_input_key = next(iter(metadata_inputs.keys()))
return self._cached_input_key

@property
def _output_key(self):
# Cache the resolved output key after the first metadata lookup.
if hasattr(self, "_cached_output_key"):
return self._cached_output_key

metadata_outputs = self._metadata.get("outputs")

if metadata_outputs is None:
raise DatumaroError("Cannot get metadata of the model outputs.")

if len(metadata_outputs.keys()) > 1:
raise DatumaroError(
f"Multiple model outputs are not supported: {metadata_outputs.keys()}."
)

self._cached_output_key = next(iter(metadata_outputs.keys()))
return self._cached_output_key
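A minimal usage sketch for the new launcher, assuming an OVMS instance is already serving a model; the model name, interpreter path, and dataset path/format below are placeholders:

```python
from datumaro.components.dataset import Dataset
from datumaro.plugins.inference_server_plugin import OVMSLauncher

launcher = OVMSLauncher(
    model_name="sample_model",                 # placeholder: name loaded in OVMS
    model_interpreter_path="model_interp.py",  # placeholder: user-provided interpreter
    model_version=1,
    host="localhost",
    port=9000,  # gRPC port the OVMS instance listens on
)

# Annotate the dataset with the server's predictions; pre- and post-processing
# are delegated to the model interpreter.
dataset = Dataset.import_from("path/to/dataset", "coco")
dataset = dataset.run_model(launcher)
```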
110 changes: 110 additions & 0 deletions src/datumaro/plugins/inference_server_plugin/triton.py
@@ -0,0 +1,110 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import logging as log
from typing import Dict, List, Type, Union

import numpy as np
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient

from datumaro.components.abstracts.model_interpreter import ModelPred
from datumaro.components.errors import DatumaroError
from datumaro.plugins.inference_server_plugin.base import (
LauncherForDedicatedInferenceServer,
ProtocolType,
)

__all__ = ["TritonLauncher"]

TClient = Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient]
TInferInput = Union[grpcclient.InferInput, httpclient.InferInput]
TInferOutput = Union[grpcclient.InferResult, httpclient.InferResult]


class TritonLauncher(LauncherForDedicatedInferenceServer[TClient]):
"""Inference launcher for Triton Inference Server (https://github.com/triton-inference-server)
Parameters:
model_name: Name of the model. It should match the model name loaded in the server instance.
model_interpreter_path: Python source code path which implements a model interpreter.
The model interpreter implements pre-processing of the model input and post-processing of the model output.
model_version: Version of the model loaded in the server instance
host: Host address of the server instance
port: Port number of the server instance
timeout: Timeout limit during communication between the client and the server instance
tls_config: Configuration required if the server instance is in the secure mode
protocol_type: Communication protocol type with the server instance
"""

def _init_client(self) -> TClient:
creds = self.tls_config.as_grpc_creds() if self.tls_config is not None else None

if self.protocol_type == ProtocolType.grpc:
return grpcclient.InferenceServerClient(url=self.url, creds=creds)
if self.protocol_type == ProtocolType.http:
return httpclient.InferenceServerClient(url=self.url)

raise NotImplementedError(self.protocol_type)

def _check_server_health(self) -> None:
status = self._client.is_model_ready(
model_name=self.model_name,
model_version=str(self.model_version),
)
if not status:
raise DatumaroError("Model is not ready.")
log.info(f"Health check succeeded: {status}")

def _init_metadata(self) -> None:
self._metadata = self._client.get_model_metadata(
model_name=self.model_name,
model_version=str(self.model_version),
)
log.info(f"Received metadata: {self._metadata}")

def _get_infer_input(self, inputs: np.ndarray) -> List[TInferInput]:
def _fix_dynamic_batch_dim(shape):
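# Replace a dynamic batch dimension (-1) reported by the server with the actual batch size.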
if shape[0] == -1:
shape[0] = inputs.shape[0]
return shape

def _create(infer_input_cls: Type[TInferInput]) -> List[TInferInput]:
infer_inputs = [
infer_input_cls(
name=inp.name,
shape=_fix_dynamic_batch_dim(inp.shape),
datatype=inp.datatype,
)
for inp in self._metadata.inputs
]
for infer_input in infer_inputs:
infer_input.set_data_from_numpy(inputs)
return infer_inputs

if self.protocol_type == ProtocolType.grpc:
return _create(grpcclient.InferInput)
if self.protocol_type == ProtocolType.http:
return _create(httpclient.InferInput)

raise NotImplementedError(self.protocol_type)

def infer(self, inputs: np.ndarray) -> List[ModelPred]:
infer_outputs: TInferOutput = self._client.infer(
inputs=self._get_infer_input(inputs),
model_name=self.model_name,
model_version=str(self.model_version),
)

results: Dict[str, np.ndarray] = {
output.name: infer_outputs.as_numpy(name=output.name)
for output in self._metadata.outputs
}

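# Regroup {output_name: (N, ...) batch array} into a list of N per-item {output_name: array} dicts.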
outputs_group_by_item = [
{key: output for key, output in zip(results.keys(), outputs)}
for outputs in zip(*results.values())
]

return outputs_group_by_item
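An analogous sketch for the Triton launcher, again with placeholder names and Triton's default gRPC port; the TLS line applies only when the server runs in secure mode:

```python
import numpy as np

from datumaro.plugins.inference_server_plugin import TritonLauncher
from datumaro.plugins.inference_server_plugin.base import ProtocolType, TLSConfig

launcher = TritonLauncher(
    model_name="sample_model",                 # placeholder: name deployed on the server
    model_interpreter_path="model_interp.py",  # placeholder: user-provided interpreter
    model_version=1,
    host="localhost",
    port=8001,  # Triton's default gRPC port
    protocol_type=ProtocolType.grpc,
    # Only needed when the server runs in secure mode (paths are placeholders):
    # tls_config=TLSConfig("certs/client.key", "certs/client.crt", "certs/server.crt"),
)

# Direct single-batch inference; the shape must match the served model's input.
batch = np.zeros((1, 3, 224, 224), dtype=np.float32)
predictions = launcher.infer(batch)
```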