Skip to content

Commit

Permalink
chore: upgrade pynumaflow version to v0.5 (#291)
Browse files Browse the repository at this point in the history
- upgrade pynumaflow version
- improve test coverage
- downgrade moto library to fix test failures

---------

Signed-off-by: Avik Basu <ab93@users.noreply.github.com>
  • Loading branch information
ab93 authored Sep 19, 2023
1 parent bc6137a commit 8ec5781
Show file tree
Hide file tree
Showing 15 changed files with 355 additions and 283 deletions.
2 changes: 1 addition & 1 deletion numalogic/udfs/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Union
from collections.abc import Coroutine
import numpy.typing as npt
from pynumaflow.function import Datum, Messages
from pynumaflow.mapper import Datum, Messages

from numalogic.tools.types import artifact_t

Expand Down
88 changes: 80 additions & 8 deletions numalogic/udfs/factory.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
# Copyright 2022 The Numaproj Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import ClassVar

from pynumaflow.mapper import Mapper, MultiProcMapper, AsyncMapper

from numalogic.udfs import NumalogicUDF
from numalogic.udfs.inference import InferenceUDF
from numalogic.udfs.trainer import TrainerUDF
from numalogic.udfs.preprocess import PreprocessUDF
from numalogic.udfs.postprocess import PostprocessUDF
from pynumaflow.function import Server, AsyncServer, MultiProcServer
from numalogic.udfs.preprocess import PreprocessUDF
from numalogic.udfs.trainer import TrainerUDF

_LOGGER = logging.getLogger(__name__)


class UDFFactory:
"""Factory class to fetch the right UDF."""

_UDF_MAP: ClassVar[dict] = {
_UDF_MAP: ClassVar[dict[str, type[NumalogicUDF]]] = {
"preprocess": PreprocessUDF,
"inference": InferenceUDF,
"postprocess": PostprocessUDF,
Expand All @@ -23,6 +35,21 @@ class UDFFactory:

@classmethod
def get_udf_cls(cls, udf_name: str) -> type[NumalogicUDF]:
"""
Get the UDF class.
Args:
udf_name: Name of the UDF;
possible values: preprocess, inference, postprocess, trainer
Returns
-------
UDF class
Raises
------
ValueError: If the UDF name is invalid
"""
try:
return cls._UDF_MAP[udf_name]
except KeyError as err:
Expand All @@ -32,21 +59,51 @@ def get_udf_cls(cls, udf_name: str) -> type[NumalogicUDF]:

@classmethod
def get_udf_instance(cls, udf_name: str, **kwargs) -> NumalogicUDF:
"""
Get the UDF instance.
The UDF class is resolved through ``get_udf_cls`` and then instantiated
with the supplied keyword arguments.
Args:
udf_name: Name of the UDF;
possible values: preprocess, inference, postprocess, trainer
**kwargs: Keyword arguments passed through unchanged to the
UDF class constructor
Returns
-------
UDF instance
Raises
------
ValueError: If the UDF name is invalid (as documented for ``get_udf_cls``)
"""
# Resolve the class first so an invalid name fails before any construction.
udf_cls = cls.get_udf_cls(udf_name)
return udf_cls(**kwargs)


class ServerFactory:
"""Factory class to fetch the right pynumaflow function server."""
"""Factory class to fetch the right pynumaflow function server/mapper."""

_SERVER_MAP: ClassVar[dict] = {
"sync": Server,
"async": AsyncServer,
"multiproc": MultiProcServer,
"sync": Mapper,
"async": AsyncMapper,
"multiproc": MultiProcMapper,
}

@classmethod
def get_server_cls(cls, server_name: str):
"""
Get the server class.
Args:
server_name: Name of the server;
possible values: sync, async, multiproc
Returns
-------
Server class
Raises
------
ValueError: If the server name is invalid
"""
try:
return cls._SERVER_MAP[server_name]
except KeyError as err:
Expand All @@ -56,5 +113,20 @@ def get_server_cls(cls, server_name: str):

@classmethod
def get_server_instance(cls, server_name: str, **kwargs):
"""
Get the server/mapper instance.
The server class is resolved through ``get_server_cls`` and then
instantiated with the supplied keyword arguments.
Args:
server_name: Name of the server;
possible values: sync, async, multiproc
**kwargs: Keyword arguments passed through unchanged to the
server class constructor
Returns
-------
Server instance
Raises
------
ValueError: If the server name is invalid (as documented for ``get_server_cls``)
"""
# Resolve the class first so an invalid name fails before any construction.
server_cls = cls.get_server_cls(server_name)
return server_cls(**kwargs)
2 changes: 1 addition & 1 deletion numalogic/udfs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from numpy import typing as npt
from orjson import orjson
from pynumaflow.function import Messages, Datum, Message
from pynumaflow.mapper import Messages, Datum, Message

from numalogic.config import RegistryFactory
from numalogic.registry import LocalLRUCache, ArtifactData
Expand Down
2 changes: 1 addition & 1 deletion numalogic/udfs/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from numpy.typing import NDArray
from orjson import orjson
from pynumaflow.function import Messages, Datum, Message
from pynumaflow.mapper import Messages, Datum, Message

from numalogic.config import PostprocessFactory, RegistryFactory
from numalogic.registry import LocalLRUCache
Expand Down
2 changes: 1 addition & 1 deletion numalogic/udfs/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import orjson
from numpy._typing import NDArray
from pynumaflow.function import Datum, Messages, Message
from pynumaflow.mapper import Datum, Messages, Message
from sklearn.pipeline import make_pipeline

from numalogic.config import PreprocessFactory, RegistryFactory
Expand Down
2 changes: 1 addition & 1 deletion numalogic/udfs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy.typing as npt
import orjson
import pandas as pd
from pynumaflow.function import Datum, Messages, Message
from pynumaflow.mapper import Datum, Messages, Message
from sklearn.pipeline import make_pipeline
from torch.utils.data import DataLoader

Expand Down
Loading

0 comments on commit 8ec5781

Please sign in to comment.