diff --git a/docs/assets/anomalydetection.png b/docs/assets/anomalydetection.png
new file mode 100644
index 00000000..f8407386
Binary files /dev/null and b/docs/assets/anomalydetection.png differ
diff --git a/numalogic/config/_conn.py b/numalogic/config/_conn.py
new file mode 100644
index 00000000..e8f1623f
--- /dev/null
+++ b/numalogic/config/_conn.py
@@ -0,0 +1,47 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class PrometheusConf:
+    server: str
+    pushgateway: str
+
+
+@dataclass
+class RedisConf:
+    host: str
+    port: int
+    expiry: int = 300
+    master_name: str = "mymaster"
+
+
+@dataclass
+class DruidConf:
+    url: str
+    endpoint: str
+
+
+@dataclass
+class Pivot:
+    index: str = "timestamp"
+    columns: list[str] = field(default_factory=list)
+    value: list[str] = field(default_factory=lambda: ["count"])
+
+
+@dataclass
+class DruidFetcherConf:
+    from pydruid.utils.aggregators import doublesum
+
+    datasource: str
+    dimensions: list[str] = field(default_factory=list)
+    aggregations: dict = field(default_factory=dict)
+    group_by: list[str] = field(default_factory=list)
+    pivot: Pivot = field(default_factory=lambda: Pivot())
+    granularity: str = "minute"
+    hours: float = 36
+
+    def __post_init__(self):
+        if not self.aggregations:
+            # Call via the class: a function imported into the class body binds
+            # as a method, so self.doublesum("count") would pass self as raw_metric.
+            self.aggregations = {"count": DruidFetcherConf.doublesum("count")}
diff --git a/numalogic/config/_metric.py b/numalogic/config/_metric.py
new file mode 100644
index 00000000..063d9ce7
--- /dev/null
+++ b/numalogic/config/_metric.py
@@ -0,0 +1,65 @@
+from dataclasses import dataclass, field
+from enum import Enum
+
+from omegaconf import MISSING
+
+from numalogic.config._conn import DruidFetcherConf, RedisConf, DruidConf, PrometheusConf
+from numalogic.config._numa import NumalogicConf
+
+
+@dataclass
+class UnifiedConf:
+    strategy: str = "max"
+    weights: list[float] = field(default_factory=list)
+
+
+@dataclass
+class ReTrainConf:
+    train_hours: int = 36
+    min_train_size: int = 2000
+    retrain_freq_hr: int = 8
+    resume_training: bool = False
+
+
+@dataclass
+class StaticThresholdConf:
+    upper_limit: int = 3
+    weight: float = 0.0
+
+
+@dataclass
+class MetricConf:
+    metric: str
+    retrain_conf: ReTrainConf = field(default_factory=lambda: ReTrainConf())
+    static_threshold: StaticThresholdConf = field(default_factory=lambda: StaticThresholdConf())
+    numalogic_conf: NumalogicConf = MISSING
+
+
+class DataSource(str, Enum):
+    PROMETHEUS = "prometheus"
+    DRUID = "druid"
+
+
+@dataclass
+class DataStreamConf:
+    name: str = "default"
+    source: str = DataSource.PROMETHEUS.value
+    window_size: int = 12
+    composite_keys: list[str] = field(default_factory=list)
+    metrics: list[str] = field(default_factory=list)
+    metric_configs: list[MetricConf] = field(default_factory=list)
+    unified_config: UnifiedConf = field(default_factory=lambda: UnifiedConf())
+    druid_fetcher: DruidFetcherConf = MISSING
+
+
+@dataclass
+class Configs:
+    configs: list[DataStreamConf]
+
+
+@dataclass
+class PipelineConf:
+    redis_conf: RedisConf
+    # registry_conf: RegistryConf
+    prometheus_conf: PrometheusConf = MISSING
+    druid_conf: DruidConf = MISSING
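The schemas in `_conn.py` and `_metric.py` are OmegaConf structured configs, so they can be merged with user-supplied YAML to get type validation at load time. A minimal sketch of that flow (the `config.yaml` path and its contents are hypothetical):

```python
from omegaconf import OmegaConf

from numalogic.config._metric import Configs

# Build the typed schema, then overlay user values from a hypothetical
# config.yaml; OmegaConf rejects unknown keys and mistyped values here.
schema = OmegaConf.structured(Configs)
user_conf = OmegaConf.load("config.yaml")
conf = OmegaConf.merge(schema, user_conf)

for ds in conf.configs:
    print(ds.name, ds.window_size, list(ds.metrics))
```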
diff --git a/numalogic/config/_numa.py b/numalogic/config/_numa.py
new file mode 100644
index 00000000..54523bd5
--- /dev/null
+++ b/numalogic/config/_numa.py
@@ -0,0 +1,79 @@
+# Copyright 2022 The Numaproj Authors.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass, field
+from typing import Optional, Any
+
+from omegaconf import MISSING
+
+
+@dataclass
+class ModelInfo:
+    """Schema for defining the model/estimator.
+
+    Args:
+    ----
+    name: name of the model; this should map to a supported list of models
+        mentioned in the factory file
+    conf: kwargs for instantiating the model class
+    stateful: flag indicating if the model is stateful or not
+    """
+
+    name: str = MISSING
+    conf: dict[str, Any] = field(default_factory=dict)
+    stateful: bool = True
+
+
+@dataclass
+class RegistryInfo:
+    """Registry config base class.
+
+    Args:
+    ----
+    name: name of the registry
+    conf: kwargs for instantiating the registry class
+    """
+
+    name: str = MISSING
+    conf: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class LightningTrainerConf:
+    """Schema for defining the PyTorch Lightning trainer behavior.
+
+    More details on the arguments are provided here:
+    https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-class-api
+    """
+
+    accelerator: str = "auto"
+    max_epochs: int = 100
+    logger: bool = False
+    check_val_every_n_epoch: int = 5
+    log_every_n_steps: int = 20
+    enable_checkpointing: bool = False
+    enable_progress_bar: bool = True
+    enable_model_summary: bool = True
+    limit_val_batches: float = 0
+    callbacks: Optional[Any] = None
+
+
+@dataclass
+class NumalogicConf:
+    """Top level config schema for numalogic."""
+
+    model: ModelInfo = field(default_factory=ModelInfo)
+    trainer: LightningTrainerConf = field(default_factory=LightningTrainerConf)
+    registry: RegistryInfo = field(default_factory=RegistryInfo)
+    preprocess: list[ModelInfo] = field(default_factory=list)
+    threshold: ModelInfo = field(default_factory=ModelInfo)
+    postprocess: ModelInfo = field(default_factory=ModelInfo)
diff --git a/numalogic/connectors/__init__.py b/numalogic/connectors/__init__.py
new file mode 100644
index 00000000..3cb0e26d
--- /dev/null
+++ b/numalogic/connectors/__init__.py
@@ -0,0 +1,4 @@
+from numalogic.connectors.druid import DruidFetcher
+from numalogic.connectors.prometheus import PrometheusDataFetcher
+
+__all__ = ["DruidFetcher", "PrometheusDataFetcher"]
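`NumalogicConf` above is the schema the UDFs below consume; it hydrates from plain dicts or YAML the same way. A small sketch (the model name and kwargs are illustrative values only):

```python
from omegaconf import OmegaConf

from numalogic.config import NumalogicConf

raw = {
    "model": {"name": "SparseVanillaAE", "conf": {"seq_len": 12}},
    "trainer": {"max_epochs": 30},
}
# Overlay the raw values on the typed schema; unspecified fields keep defaults.
conf = OmegaConf.merge(OmegaConf.structured(NumalogicConf), raw)
assert conf.trainer.accelerator == "auto"
assert conf.model.conf["seq_len"] == 12
```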
diff --git a/numalogic/numaflow/entities.py b/numalogic/numaflow/entities.py
new file mode 100644
index 00000000..d78dbb33
--- /dev/null
+++ b/numalogic/numaflow/entities.py
@@ -0,0 +1,149 @@
+from copy import copy
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Union, TypeVar
+
+import numpy as np
+import numpy.typing as npt
+import orjson
+import pandas as pd
+
+Vector = list[float]
+Matrix = Union[Vector, list[Vector], npt.NDArray[float]]
+
+
+class Status(str, Enum):
+    """Status is the enum that is used to identify the status of the payload."""
+
+    RAW = "raw"
+    EXTRACTED = "extracted"
+    PRE_PROCESSED = "pre_processed"
+    INFERRED = "inferred"
+    THRESHOLD = "threshold_complete"
+    POST_PROCESSED = "post_processed"
+    ARTIFACT_NOT_FOUND = "artifact_not_found"
+    ARTIFACT_STALE = "artifact_is_stale"
+    RUNTIME_ERROR = "runtime_error"
+
+
+class Header(str, Enum):
+    """Header is the enum that is used to identify the type of payload."""
+
+    STATIC_INFERENCE = "static_threshold"
+    MODEL_INFERENCE = "model_inference"
+    TRAIN_REQUEST = "request_training"
+    MODEL_STALE = "model_stale"
+
+
+@dataclass
+class _BasePayload:
+    """_BasePayload is the base data structure that is passed around in the system."""
+
+    uuid: str
+    composite_keys: list[str]
+
+
+PayloadType = TypeVar("PayloadType", bound=_BasePayload)
+
+
+@dataclass
+class TrainerPayload(_BasePayload):
+    """
+    TrainerPayload is the data structure that is passed
+    around in the system when a training request is made.
+    """
+
+    metrics: list[str]
+    header: Header = Header.TRAIN_REQUEST
+
+    def to_json(self):
+        return orjson.dumps(self)
+
+
+@dataclass(repr=False)
+class StreamPayload(_BasePayload):
+    """StreamPayload is the main data structure that is passed around in the system."""
+
+    data: Matrix
+    raw_data: Matrix
+    metrics: list[str]
+    timestamps: list[int]
+    status: dict[str, Status] = field(default_factory=dict)
+    header: dict[str, Header] = field(default_factory=dict)
+    metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
+
+    def get_df(self, original=False) -> pd.DataFrame:
+        return pd.DataFrame(self.get_data(original), columns=self.metrics)
+
+    def set_data(self, arr: Matrix) -> None:
+        self.data = arr
+
+    def set_metric_data(self, metric: str, arr: Matrix) -> None:
+        _df = self.get_df().copy()
+        _df[metric] = arr
+        self.set_data(np.asarray(_df.values.tolist()))
+
+    def get_metric_arr(self, metric: str) -> npt.NDArray[float]:
+        return self.get_df()[metric].values
+
+    def get_data(self, original=False) -> npt.NDArray[float]:
+        if original:
+            return np.asarray(self.raw_data)
+        return np.asarray(self.data)
+
+    def set_status(self, metric: str, status: Status) -> None:
+        self.status[metric] = status
+
+    def set_header(self, metric: str, header: Header) -> None:
+        self.header[metric] = header
+
+    def set_metric_metadata(self, metric: str, key: str, value) -> None:
+        if metric in self.metadata:
+            self.metadata[metric][key] = value
+        else:
+            self.metadata[metric] = {key: value}
+
+    def set_metadata(self, key: str, value) -> None:
+        self.metadata[key] = value
+
+    def get_metadata(self, key: str) -> dict[str, Any]:
+        return copy(self.metadata[key])
+
+    def __repr__(self) -> str:
+        """Return a string representation of the object."""
+        return "header: {}, status: {}, composite_keys: {}, data: {}, metadata: {}".format(
+            self.header,
+            self.status,
+            self.composite_keys,
+            list(self.data),
+            self.metadata,
+        )
+
+    def to_json(self):
+        return orjson.dumps(self, option=orjson.OPT_SERIALIZE_NUMPY)
+
+
+@dataclass
+class InputPayload:
+    """Input payload."""
+
+    start_time: int
+    end_time: int
+    data: list[dict[str, Any]]
+    metadata: dict[str, Any]
+
+    def to_json(self):
+        return orjson.dumps(self, option=orjson.OPT_SERIALIZE_NUMPY)
+
+
+@dataclass
+class OutputPayload:
+    """Output payload."""
+
+    timestamp: int
+    unified_anomaly: float
+    data: dict[str, Any]
+    metadata: dict[str, Any]
+
+    def to_json(self):
+        return orjson.dumps(self, option=orjson.OPT_SERIALIZE_NUMPY)
diff --git a/numalogic/numaflow/server.py b/numalogic/numaflow/server.py
new file mode 100644
index 00000000..e69de29b
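The payload classes above are plain dataclasses serialized with orjson, so they can be exercised standalone. A quick sketch with made-up values:

```python
import numpy as np

from numalogic.numaflow.entities import Header, Status, StreamPayload

# Illustrative values only; in the pipeline these come from the input vertex.
payload = StreamPayload(
    uuid="b7f1c2",
    composite_keys=["PluginAsset", "1804835142986662399"],
    data=np.zeros((12, 3)),
    raw_data=np.zeros((12, 3)),
    metrics=["degraded", "failed", "success"],
    timestamps=list(range(12)),
)
payload.set_status("failed", Status.INFERRED)
payload.set_header("failed", Header.MODEL_INFERENCE)
print(payload.get_df().shape)  # (12, 3); columns are the metric names
print(payload.to_json())       # orjson bytes; numpy arrays serialized natively
```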
"TrainBlockUDF"] diff --git a/numalogic/numaflow/udf/block_train.py b/numalogic/numaflow/udf/block_train.py new file mode 100644 index 00000000..0bb77fea --- /dev/null +++ b/numalogic/numaflow/udf/block_train.py @@ -0,0 +1,136 @@ +import logging +import os + +import pandas as pd +from omegaconf import OmegaConf +from orjson import orjson +from pynumaflow.function import Datum, Messages, Message +from sklearn.preprocessing import StandardScaler + +from numalogic.blocks import ( + PreprocessBlock, + NNBlock, + ThresholdBlock, + PostprocessBlock, + BlockPipeline, +) +from numalogic.config import ( + NumalogicConf, + RedisConf, + DataStreamConf, + LightningTrainerConf, + DruidConf, +) +from numalogic.connectors import DruidFetcher +from numalogic.models.autoencoder.variants import SparseVanillaAE +from numalogic.models.threshold import StdDevThreshold +from numalogic.numaflow import NumalogicUDF +from numalogic.numaflow.entities import TrainerPayload +from numalogic.registry import RedisRegistry +from numalogic.registry.redis_registry import get_redis_client_from_conf +from numalogic.transforms import TanhNorm + +_LOGGER = logging.getLogger(__name__) +REQUEST_EXPIRY = int(os.getenv("REQUEST_EXPIRY") or 300) +MIN_RECORDS = 1000 + + +class TrainBlockUDF(NumalogicUDF): + """UDF to train the ML model.""" + + def __init__( + self, + numalogic_conf: NumalogicConf, + redis_conf: RedisConf, + stream_conf: DataStreamConf, + trainer_conf: LightningTrainerConf, + druid_conf: DruidConf, + ): + super().__init__() + self.conf = numalogic_conf + self.druid_conf = druid_conf + self.stream_conf = stream_conf + self.trainer_conf = trainer_conf + self._rclient = self._get_redis_client(redis_conf) + self.model_registry = RedisRegistry(client=self._rclient) + self._blocks_args = ( + PreprocessBlock(StandardScaler()), + NNBlock( + SparseVanillaAE( + seq_len=self.stream_conf.window_size, n_features=len(self.stream_conf.metrics) + ), + self.stream_conf.window_size, + ), + ThresholdBlock(StdDevThreshold()), + PostprocessBlock(TanhNorm()), + ) + self._dkeys = ("stdscaler", "sparsevanillae", "stddevthreshold") + + @staticmethod + def _get_redis_client(redis_conf: RedisConf): + return get_redis_client_from_conf(redis_conf) + + @staticmethod + def _drop_message(): + return Messages(Message.to_drop()) + + def fetch_data(self, payload: TrainerPayload) -> pd.DataFrame: + """Fetch the data from the data source.""" + fetcher_conf = self.stream_conf.druid_fetcher + data_fetcher = DruidFetcher(url=self.druid_conf.url, endpoint=self.druid_conf.endpoint) + + return data_fetcher.fetch_data( + datasource=fetcher_conf.datasource, + filter_keys=self.stream_conf.composite_keys + fetcher_conf.dimensions, + filter_values=payload.composite_keys[1:] + payload.metrics, + dimensions=OmegaConf.to_container(fetcher_conf.dimensions), + granularity=fetcher_conf.granularity, + aggregations=OmegaConf.to_container(fetcher_conf.aggregations), + group_by=OmegaConf.to_container(fetcher_conf.group_by), + pivot=fetcher_conf.pivot, + hours=fetcher_conf.hours, + ) + + def _is_new_request(self, payload: TrainerPayload) -> bool: + _ckeys = ":".join(*payload.composite_keys, *payload.metrics) + r_key = f"train::{_ckeys}" + value = self._rclient.get(r_key) + if value: + return False + + self._rclient.setex(r_key, time=REQUEST_EXPIRY, value=1) + return True + + def exec(self, keys: list[str], datum: Datum) -> Messages: + """Execute the UDF.""" + payload = TrainerPayload(**orjson.loads(datum.value)) + is_new = self._is_new_request(payload) + + if not is_new: + 
diff --git a/numalogic/numaflow/udf/blockinfer.py b/numalogic/numaflow/udf/blockinfer.py
new file mode 100644
index 00000000..bbe5263d
--- /dev/null
+++ b/numalogic/numaflow/udf/blockinfer.py
@@ -0,0 +1,144 @@
+import logging
+from typing import Optional
+from uuid import uuid4
+
+import numpy as np
+import numpy.typing as npt
+import orjson
+from pynumaflow.function import Datum, Messages, Message
+from sklearn.preprocessing import StandardScaler
+
+from numalogic.blocks import (
+    BlockPipeline,
+    PreprocessBlock,
+    NNBlock,
+    ThresholdBlock,
+    PostprocessBlock,
+)
+from numalogic.config import NumalogicConf, RedisConf, DataStreamConf
+from numalogic.models.autoencoder.variants import SparseVanillaAE
+from numalogic.models.threshold import StdDevThreshold
+from numalogic.numaflow import NumalogicUDF
+from numalogic.numaflow.entities import OutputPayload, TrainerPayload
+from numalogic.registry import RedisRegistry
+from numalogic.registry.redis_registry import get_redis_client_from_conf
+from numalogic.tools.exceptions import RedisRegistryError, ModelKeyNotFound, ConfigError
+from numalogic.transforms import TanhNorm
+
+_LOGGER = logging.getLogger(__name__)
+TRAIN_VTX_KEY = "train"
+
+
+class InferenceBlockUDF(NumalogicUDF):
+    """UDF to run the inference block pipeline on the input data."""
+
+    def __init__(
+        self, numalogic_conf: NumalogicConf, redis_conf: RedisConf, stream_conf: DataStreamConf
+    ):
+        super().__init__()
+        self.conf = numalogic_conf
+        self.stream_conf = stream_conf
+        self.model_registry = RedisRegistry(client=self._get_redis_client(redis_conf))
+        self._blocks_args = (
+            PreprocessBlock(StandardScaler()),
+            NNBlock(
+                SparseVanillaAE(
+                    seq_len=self.stream_conf.window_size, n_features=len(self.stream_conf.metrics)
+                ),
+                self.stream_conf.window_size,
+            ),
+            ThresholdBlock(StdDevThreshold()),
+            PostprocessBlock(TanhNorm()),
+        )
+        self._dkeys = ("stdscaler", "sparsevanillae", "stddevthreshold")
+
+    @staticmethod
+    def _get_redis_client(redis_conf: RedisConf):
+        return get_redis_client_from_conf(redis_conf)
+
+    @staticmethod
+    def read_datum(datum: Datum, uuid: str = "") -> dict:
+        """Read the input datum and return the data payload."""
+        try:
+            data_payload = orjson.loads(datum.value)
+        except orjson.JSONDecodeError as e:
+            _LOGGER.exception("%s - Error while reading input json %r", uuid, e)
+            return {}
+        _LOGGER.info("%s - Data payload: %s", uuid, data_payload)
+        return data_payload
+
+    def extract_data(self, input_data: dict, uuid: str = "") -> Optional[npt.NDArray]:
+        """Extract the data from the input payload."""
+        stream_array = []
+        metrics = self.stream_conf.metrics
+        if not metrics:
+            raise ConfigError("No metrics found in stream config")
+        for row in input_data["data"]:
+            stream_array.append([row[metric] for metric in metrics])
+        if not stream_array:
+            _LOGGER.warning("%s - No data found in input payload", uuid)
+            return None
+        return np.array(stream_array)
+
+    def _construct_output(
+        self, anomaly_scores: npt.NDArray, input_payload: dict, model_version: str
+    ) -> OutputPayload:
+        """Construct the output payload."""
+        unified_score = np.max(anomaly_scores)
+        metric_data = {
+            metric: {
+                "anomaly_score": anomaly_scores[idx],
+            }
+            for idx, metric in enumerate(self.stream_conf.metrics)
+        }
+        metric_data["model_version"] = model_version
+
+        return OutputPayload(
+            timestamp=input_payload["end_time"],
+            unified_anomaly=unified_score,
+            data=metric_data,
+            metadata=input_payload.get("metadata", {}),
+        )
+
+    def _send_training_req(self, keys: list[str], uuid: str = "") -> Messages:
+        """Send a training request to the training UDF."""
+        train_payload = TrainerPayload(
+            uuid=uuid, composite_keys=keys, metrics=self.stream_conf.metrics
+        )
+        return Messages(Message(keys=keys, value=train_payload.to_json(), tags=[TRAIN_VTX_KEY]))
+
+    def exec(self, keys: list[str], datum: Datum) -> Messages:
+        """Runs the UDF."""
+        _uuid = uuid4().hex
+        input_payload = self.read_datum(datum, uuid=_uuid)
+        if not input_payload:
+            return Messages(Message.to_drop())
+
+        block_pl = BlockPipeline(*self._blocks_args, registry=self.model_registry)
+        stream_array = self.extract_data(input_payload, uuid=_uuid)
+        if stream_array is None:
+            return Messages(Message.to_drop())
+
+        # Load the model from the registry
+        try:
+            artifact_data = block_pl.load(skeys=keys, dkeys=self._dkeys)
+        except ModelKeyNotFound:
+            _LOGGER.info("%s - Model not found for skeys: %s, dkeys: %s", _uuid, keys, self._dkeys)
+            return self._send_training_req(keys, _uuid)
+        except RedisRegistryError as redis_err:
+            _LOGGER.exception("%s - Error loading block pipeline: %r", _uuid, redis_err)
+            return Messages(Message.to_drop())
+
+        # Run inference
+        try:
+            raw_scores = block_pl(stream_array)
+        except Exception as err:
+            _LOGGER.exception("%s - Error running block pipeline: %r", _uuid, err)
+            return Messages(Message.to_drop())
+
+        # Find the final anomaly score per metric
+        anomaly_scores = np.mean(raw_scores, axis=0)
+        out_payload = self._construct_output(
+            anomaly_scores, input_payload, artifact_data.extras.get("version")
+        )
+        return Messages(Message(keys=keys, value=out_payload.to_json()))
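`server.py` lands empty in this diff, so there is no serving entry point yet; in the meantime the UDFs are directly callable with `(keys, datum)`, exactly as the tests below exercise them. A local smoke-run sketch (connection values and keys are placeholders):

```python
from datetime import datetime

import orjson
from pynumaflow.function import Datum, DatumMetadata

from numalogic.config import DataStreamConf, NumalogicConf, RedisConf
from numalogic.numaflow.udf import InferenceBlockUDF

udf = InferenceBlockUDF(
    NumalogicConf(),
    RedisConf(host="localhost", port=6379),  # placeholder connection values
    DataStreamConf(metrics=["degraded", "failed", "success"]),
)
keys = ["PluginAsset", "1804835142986662399"]
datum = Datum(
    keys=keys,
    value=orjson.dumps({"start_time": 0, "end_time": 0, "data": [], "metadata": {}}),
    event_time=datetime.now(),
    watermark=datetime.now(),
    metadata=DatumMetadata(msg_id="", num_delivered=0),
)
# With an empty "data" list the UDF drops the message; with real rows it
# either emits scores or routes a training request to the "train" vertex.
msgs = udf(keys, datum)
```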
diff --git a/tests/numaflow/__init__.py b/tests/numaflow/__init__.py
new file mode 100644
index 00000000..e69de29b
"timestamp": 1689189540000}, + {"degraded": 0, "failed": 0, "success": 8, "timestamp": 1689189600000}, + {"degraded": 0, "failed": 0, "success": 6, "timestamp": 1689189660000}, + {"degraded": 0, "failed": 4, "success": 2, "timestamp": 1689189720000}, + {"degraded": 0, "failed": 0, "success": 2, "timestamp": 1689189780000}, + {"degraded": 2, "failed": 0, "success": 4, "timestamp": 1689189840000}, + {"degraded": 0, "failed": 0, "success": 6, "timestamp": 1689190080000}, + {"degraded": 0, "failed": 0, "success": 6, "timestamp": 1689190140000}, + {"degraded": 1, "failed": 0, "success": 2, "timestamp": 1689190500000}, + {"degraded": 0, "failed": 0, "success": 2, "timestamp": 1689190560000}, + {"degraded": 0, "failed": 0, "success": 2, "timestamp": 1689190620000}, + {"degraded": 0, "failed": 0, "success": 2, "timestamp": 1689190680000}, + ], + "start_time": 1689189540000, + "end_time": 1689190740000, + "metadata": { + "current_pipeline_name": "PluginAsset", + "druid": {"source": "numalogic"}, + "wavefront": {"source": "numalogic"}, + }, +} +MOCK_REDIS_CLIENT = FakeStrictRedis(server=FakeServer()) +logging.basicConfig(level=logging.DEBUG) + + +class TestInferenceBlockUDF(unittest.TestCase): + def setUp(self) -> None: + self.skeys = ["PluginAsset", "1804835142986662399"] + self.dkeys = ["stdscaler", "sparsevanillae", "stddevthreshold"] + self.datum = Datum( + keys=self.skeys, + value=orjson.dumps(PAYLOAD), + event_time=datetime.now(), + watermark=datetime.now(), + metadata=DatumMetadata(msg_id="", num_delivered=0), + ) + + def tearDown(self) -> None: + MOCK_REDIS_CLIENT.flushall() + + def train_block(self) -> None: + block_pl = BlockPipeline( + PreprocessBlock(StandardScaler()), + NNBlock(SparseVanillaAE(seq_len=12, n_features=3), 12), + ThresholdBlock(StdDevThreshold()), + PostprocessBlock(TanhNorm()), + registry=RedisRegistry(client=MOCK_REDIS_CLIENT), + ) + rng = np.random.default_rng() + block_pl.fit(rng.random((1000, 3)), nn__max_epochs=5) + block_pl.save(skeys=self.skeys, dkeys=self.dkeys) + + @patch.object(InferenceBlockUDF, "_get_redis_client", Mock(return_value=MOCK_REDIS_CLIENT)) + def test_exec_success(self) -> None: + self.train_block() + udf = InferenceBlockUDF( + NumalogicConf(), + RedisConf(host="localhost", port=6379), + DataStreamConf(metrics=["degraded", "failed", "success"]), + ) + out_msgs = udf(self.skeys, self.datum) + self.assertEqual(1, len(out_msgs)) + self.assertIsInstance(OutputPayload(**orjson.loads(out_msgs[0].value)), OutputPayload) + + @patch.object(InferenceBlockUDF, "_get_redis_client", Mock(return_value=MOCK_REDIS_CLIENT)) + def test_exec_no_model(self) -> None: + udf = InferenceBlockUDF( + NumalogicConf(), + RedisConf(host="localhost", port=6379), + DataStreamConf(metrics=["degraded", "failed", "success"]), + ) + out_msgs = udf(self.skeys, self.datum) + self.assertEqual(1, len(out_msgs)) + self.assertIsInstance(TrainerPayload(**orjson.loads(out_msgs[0].value)), TrainerPayload) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_numaflow.py b/tests/numaflow/test_numaflow.py similarity index 100% rename from tests/test_numaflow.py rename to tests/numaflow/test_numaflow.py