diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 457be12f..77e83e74 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -27,3 +27,7 @@ repos:
   - id: check-ast
   - id: check-case-conflict
   - id: check-docstring-first
+- repo: https://github.com/python-poetry/poetry
+  rev: "1.6"
+  hooks:
+  - id: poetry-check
diff --git a/Dockerfile b/Dockerfile
index 8e461b13..bc83b9fa 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -41,7 +41,7 @@ WORKDIR $PYSETUP_PATH
 COPY ./pyproject.toml ./poetry.lock ./
 
 # TODO install cpu/gpu based on args/arch
-RUN poetry install --without dev --no-cache --no-root -E numaflow --extras "${INSTALL_EXTRAS}" && \
+RUN poetry install --without dev --no-cache --no-root --extras "${INSTALL_EXTRAS}" && \
     poetry run pip install --no-cache "torch>=2.0,<3.0" --index-url https://download.pytorch.org/whl/cpu && \
     poetry run pip install --no-cache "pytorch-lightning>=2.0<3.0" && \
     rm -rf ~/.cache/pypoetry/
@@ -51,4 +51,4 @@ WORKDIR /app
 
 ENTRYPOINT ["/usr/bin/dumb-init", "--"]
 
-EXPOSE 5000
\ No newline at end of file
+EXPOSE 5000
diff --git a/Makefile b/Makefile
index ad3fc65f..5ddb4b6a 100644
--- a/Makefile
+++ b/Makefile
@@ -23,7 +23,7 @@ lint: format
 
 # install all dependencies
 setup:
-	poetry install --with dev,torch --all-extras --no-root
+	poetry install --with dev,torch --all-extras
 
 # test your application (tests in the tests/ directory)
 test:
diff --git a/numalogic/backtest/__init__.py b/numalogic/backtest/__init__.py
new file mode 100644
index 00000000..b6b7ef56
--- /dev/null
+++ b/numalogic/backtest/__init__.py
@@ -0,0 +1,11 @@
+from importlib.util import find_spec
+
+
+def _validate_req_pkgs():
+    if (not find_spec("torch")) or (not find_spec("pytorch_lightning")):
+        raise ModuleNotFoundError(
+            "PyTorch and/or PyTorch Lightning are not installed. Please install them first."
+        )
+
+
+_validate_req_pkgs()
diff --git a/numalogic/backtest/__main__.py b/numalogic/backtest/__main__.py
new file mode 100644
index 00000000..bdbe508b
--- /dev/null
+++ b/numalogic/backtest/__main__.py
@@ -0,0 +1,102 @@
+# Copyright 2022 The Numaproj Authors.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+from pathlib import Path
+from typing import Annotated
+from typing import Optional
+
+import typer
+
+import numalogic.backtest._bt as bt
+from numalogic.backtest._clifunc import clear_outputs, train_models, generate_scores
+
+logging.basicConfig(level=logging.INFO)
+
+
+app = typer.Typer()
+app.add_typer(bt.app, name="backtest")
+
+
+@app.command()
+def clear(
+    appname: Optional[str] = None,
+    metric: Optional[str] = None,
+    output_dir: Annotated[Optional[Path], typer.Option()] = None,
+    all_: Annotated[bool, typer.Option("--all")] = False,
+):
+    """CLI entrypoint for clearing the backtest output files."""
+    if not output_dir:
+        output_dir = os.path.join(os.getcwd(), ".btoutput")
+
+    if all_:
+        print(f"Clearing all the backtest output files in {output_dir}")
+        try:
+            shutil.rmtree(output_dir, ignore_errors=False, onerror=None)
+        except FileNotFoundError:
+            pass
+        return
+
+    if not (appname and metric):
+        _msg = "Both appname and metric need to be provided!"
+        print(_msg)
+        return
+
+    clear_outputs(appname=appname, metric=metric, output_dir=output_dir)
+
+
+@app.command()
+def train(
+    data_file: Annotated[Optional[Path], typer.Option()] = None,
+    col_name: Annotated[Optional[str], typer.Option()] = None,
+    ts_col_name: Annotated[str, typer.Option()] = "timestamp",
+    train_ratio: Annotated[float, typer.Option()] = 0.9,
+    output_dir: Annotated[Optional[Path], typer.Option()] = None,
+):
+    """CLI entrypoint for training models for the given data."""
+    if (data_file is None) or (col_name is None):
+        print("No data file or column name provided!")
+        raise typer.Abort()
+
+    if not output_dir:
+        output_dir = os.path.join(os.getcwd(), ".btoutput")
+
+    train_models(
+        data_file=data_file,
+        col_name=col_name,
+        ts_col_name=ts_col_name,
+        train_ratio=train_ratio,
+        output_dir=output_dir,
+    )
+
+
+@app.command()
+def score(
+    data_file: Annotated[Optional[Path], typer.Option()] = None,
+    col_name: Annotated[Optional[str], typer.Option()] = None,
+    ts_col_name: Annotated[str, typer.Option()] = "timestamp",
+    model_path: Annotated[Optional[Path], typer.Option()] = None,
+    test_ratio: Annotated[float, typer.Option()] = 1.0,
+):
+    """CLI entrypoint for generating scores for the given data."""
+    if (data_file is None) or (col_name is None):
+        print("No data file or column name provided!")
+        raise typer.Abort()
+
+    generate_scores(
+        data_file=data_file,
+        col_name=col_name,
+        ts_col_name=ts_col_name,
+        model_path=model_path,
+        test_ratio=test_ratio,
+    )
diff --git a/numalogic/backtest/_bt.py b/numalogic/backtest/_bt.py
new file mode 100644
index 00000000..988a22a9
--- /dev/null
+++ b/numalogic/backtest/_bt.py
@@ -0,0 +1,50 @@
+# Copyright 2022 The Numaproj Authors.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Annotated, Optional
+
+import typer
+
+from numalogic.backtest._clifunc import univar_backtest, multivar_backtest
+from numalogic.backtest._constants import DEFAULT_PROM_LOCALHOST
+
+app = typer.Typer()
+
+
+@app.command()
+def univariate(
+    namespace: Annotated[str, typer.Argument(help="Namespace name")],
+    appname: Annotated[str, typer.Argument(help="Application name")],
+    metric: Annotated[str, typer.Argument(help="The timeseries metric to analyze")],
+    url: Annotated[
+        str, typer.Option(envvar="PROM_URL", help="Endpoint URL for data fetching")
+    ] = DEFAULT_PROM_LOCALHOST,
+    lookback_days: Annotated[int, typer.Option(help="Number of days of data to fetch")] = 8,
+    output_dir: Annotated[Optional[str], typer.Option(help="Output directory")] = None,
+):
+    """CLI entry point for a backtest run on a single metric."""
+    if not output_dir:
+        output_dir = os.path.join(os.getcwd(), ".btoutput")
+    univar_backtest(
+        namespace=namespace,
+        appname=appname,
+        metric=metric,
+        url=url,
+        lookback_days=lookback_days,
+        output_dir=output_dir,
+    )
+
+
+@app.command()
+def multivariate():
+    """CLI entry point for a backtest run on multiple metrics in a multivariate fashion."""
+    multivar_backtest()
diff --git a/numalogic/backtest/_clifunc.py b/numalogic/backtest/_clifunc.py
new file mode 100644
index 00000000..9cb0301a
--- /dev/null
+++ b/numalogic/backtest/_clifunc.py
@@ -0,0 +1,90 @@
+# Copyright 2022 The Numaproj Authors.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import shutil
+from pathlib import Path
+from typing import Union
+
+import pandas as pd
+
+from numalogic.backtest._prom import PromUnivarBacktester
+from numalogic.tools.exceptions import DataFormatError
+
+
+logger = logging.getLogger(__name__)
+
+
+def univar_backtest(
+    namespace: str, appname: str, metric: str, url: str, lookback_days: int, output_dir: str
+):
+    """Run backtest for a single metric."""
+    backtester = PromUnivarBacktester(
+        url, namespace, appname, metric, lookback_days=lookback_days, output_dir=output_dir
+    )
+    df = backtester.read_data()
+    backtester.train_models(df)
+    out_df = backtester.generate_scores(df)
+    backtester.save_plots(out_df)
+
+
+def multivar_backtest(*_, **__):
+    """Run backtest for multiple metrics in a multivariate fashion."""
+    raise NotImplementedError
+
+
+def clear_outputs(appname: str, metric: str, output_dir: str) -> None:
+    """Clear the backtest output files."""
+    _dir = PromUnivarBacktester.get_outdir(appname, metric, outdir=output_dir)
+    logger.info("Clearing backtest output files in %s", _dir)
+    shutil.rmtree(_dir, ignore_errors=False, onerror=None)
+
+
+def train_models(
+    data_file: Union[Path, str],
+    col_name: str,
+    ts_col_name: str,
+    train_ratio: float,
+    output_dir: Union[Path, str],
+):
+    """Train models for the given data."""
+    backtester = PromUnivarBacktester(
+        "", "", "", col_name, test_ratio=(1 - train_ratio), output_dir=output_dir
+    )
+
+    df = pd.read_csv(data_file)
+    try:
+        df.set_index([ts_col_name], inplace=True)
+    except KeyError:
+        raise DataFormatError(f"Timestamp column {ts_col_name} not found in the data!") from None
+
+    df.index = pd.to_datetime(df.index)
+    backtester.train_models(df)
+
+
+def generate_scores(
+    data_file: Union[Path, str],
+    col_name: str,
+    ts_col_name: str,
+    model_path: Union[Path, str],
+    test_ratio: float,
+):
+    """Generate scores for the given data."""
+    backtester = PromUnivarBacktester("", "", "", col_name, test_ratio=test_ratio)
+
+    df = pd.read_csv(data_file)
+    try:
+        df.set_index([ts_col_name], inplace=True)
+    except KeyError:
+        raise DataFormatError(f"Timestamp column {ts_col_name} not found in the data!") from None
+
+    df.index = pd.to_datetime(df.index)
+    backtester.generate_scores(df, model_path=model_path)
diff --git a/numalogic/backtest/_constants.py b/numalogic/backtest/_constants.py
new file mode 100644
index 00000000..ed6738db
--- /dev/null
+++ b/numalogic/backtest/_constants.py
@@ -0,0 +1,8 @@
+import os
+from typing import Final
+
+from numalogic._constants import BASE_DIR
+
+DEFAULT_OUTPUT_DIR: Final[str] = os.path.join(BASE_DIR, ".btoutput")
+DEFAULT_SEQUENCE_LEN: Final[int] = 12
+DEFAULT_PROM_LOCALHOST: Final[str] = "http://localhost:9090/"
diff --git a/numalogic/backtest/_prom.py b/numalogic/backtest/_prom.py
new file mode 100644
index 00000000..016f637d
--- /dev/null
+++ b/numalogic/backtest/_prom.py
@@ -0,0 +1,335 @@
+# Copyright 2022 The Numaproj Authors.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import logging +import os.path +from datetime import datetime, timedelta +from pathlib import Path +from typing import Optional, Union + +import numpy as np +import pandas as pd +import torch +from matplotlib import pyplot as plt +from numpy.typing import NDArray + +from numalogic._constants import BASE_DIR +from numalogic.backtest._constants import DEFAULT_SEQUENCE_LEN +from numalogic.config import ( + TrainerConf, + LightningTrainerConf, + NumalogicConf, + ModelInfo, + ModelFactory, + PreprocessFactory, + PostprocessFactory, + ThresholdFactory, +) +from numalogic.connectors import ConnectorType +from numalogic.connectors.prometheus import PrometheusFetcher +from numalogic.tools.data import StreamingDataset, inverse_window +from numalogic.tools.types import artifact_t +from numalogic.udfs import UDFFactory, StreamConf + +DEFAULT_OUTPUT_DIR = os.path.join(BASE_DIR, ".btoutput") +LOGGER = logging.getLogger(__name__) + + +def _init_default_streamconf(metrics: list[str]) -> StreamConf: + numalogic_cfg = NumalogicConf( + model=ModelInfo( + "VanillaAE", conf={"seq_len": DEFAULT_SEQUENCE_LEN, "n_features": len(metrics)} + ), + preprocess=[ModelInfo("StandardScaler")], + trainer=TrainerConf(pltrainer_conf=LightningTrainerConf(accelerator="cpu")), + ) + return StreamConf( + source=ConnectorType.prometheus, + window_size=DEFAULT_SEQUENCE_LEN, + metrics=metrics, + numalogic_conf=numalogic_cfg, + ) + + +class PromUnivarBacktester: + """ + Class for running backtest for a single metric on data from Prometheus or Thanos. + + Args: + url: Prometheus/Thanos URL + namespace: Namespace of the metric + appname: Application name + metric: Metric name + return_labels: Prometheus label names as columns to return + lookback_days: Number of days of data to fetch + output_dir: Output directory + test_ratio: Ratio of test data to total data + stream_conf: Stream configuration + """ + + def __init__( + self, + url: str, + namespace: str, + appname: str, + metric: str, + return_labels: Optional[list[str]] = None, + lookback_days: int = 8, + output_dir: Union[str, Path] = DEFAULT_OUTPUT_DIR, + test_ratio: float = 0.25, + stream_conf: Optional[StreamConf] = None, + ): + self._url = url + self.namespace = namespace + self.appname = appname + self.metric = metric + self.conf = stream_conf or _init_default_streamconf([metric]) + self.test_ratio = test_ratio + self.lookback_days = lookback_days + self.return_labels = return_labels + + self._seq_len = self.conf.window_size + self._n_features = len(self.conf.metrics) + + self.out_dir = self.get_outdir(appname, metric, outdir=output_dir) + self._datapath = os.path.join(self.out_dir, "data.csv") + self._modelpath = os.path.join(self.out_dir, "models.pt") + self._outpath = os.path.join(self.out_dir, "output.csv") + + if not os.path.exists(self.out_dir): + os.makedirs(self.out_dir) + + @classmethod + def get_outdir(cls, appname: str, metric: str, outdir=DEFAULT_OUTPUT_DIR) -> str: + """Get the output directory for the given metric.""" + if not appname: + return os.path.join(outdir, metric) + _key = ":".join([appname, metric]) + return os.path.join(outdir, _key) + + def read_data(self, fill_na_value: float = 0.0) -> pd.DataFrame: + """ + Reads data from Prometheus/Thanos and returns a dataframe. 
+ + Args: + fill_na_value: Value to fill NaNs with + + Returns + ------- + Dataframe with timestamp and metric values + """ + datafetcher = PrometheusFetcher(self._url) + df = datafetcher.fetch( + metric_name=self.metric, + start=(datetime.now() - timedelta(days=self.lookback_days)), + end=datetime.now(), + filters={"namespace": self.namespace, "app": self.appname}, + return_labels=self.return_labels, + aggregate=False, + ) + LOGGER.info( + "Fetched dataframe with lookback days: %s with shape: %s", self.lookback_days, df.shape + ) + + df.set_index(["timestamp"], inplace=True) + df.index = pd.to_datetime(df.index) + + df.replace([np.inf, -np.inf], np.nan, inplace=True) + df = df.fillna(fill_na_value) + + df.to_csv(self._datapath, index=True) + return df + + def train_models( + self, + df: Optional[pd.DataFrame] = None, + ) -> dict[str, artifact_t]: + """ + Train models for the given data. + + Args: + df: Dataframe with timestamp and metric values + + Returns + ------- + Dictionary of trained models + """ + if df is None: + df = self._read_or_fetch_data() + + df_train, _ = self._split_data(df[[self.metric]]) + x_train = df_train.to_numpy(dtype=np.float32) + LOGGER.info("Training data shape: %s", x_train.shape) + + artifacts = UDFFactory.get_udf_cls("trainer").compute( + model=ModelFactory().get_instance(self.conf.numalogic_conf.model), + input_=x_train, + preproc_clf=PreprocessFactory().get_pipeline_instance( + self.conf.numalogic_conf.preprocess + ), + threshold_clf=ThresholdFactory().get_instance(self.conf.numalogic_conf.threshold), + trainer_cfg=self.conf.numalogic_conf.trainer, + ) + with open(self._modelpath, "wb") as f: + torch.save(artifacts, f) + LOGGER.info("Models saved in %s", self._modelpath) + return artifacts + + def generate_scores( + self, df: Optional[pd.DataFrame] = None, model_path: Optional[str] = None + ) -> pd.DataFrame: + """ + Generate scores for the given data. 
+ + Args: + df: Dataframe with timestamp and metric values + model_path: Path to the saved models + + Returns + ------- + Dataframe with timestamp and metric values + """ + if df is None: + df = self._read_or_fetch_data() + + artifacts = self._load_or_train_model(df, model_path) + + _, df_test = self._split_data(df[[self.metric]]) + x_test = df_test.to_numpy(dtype=np.float32) + LOGGER.info("Test data shape: %s", df_test.shape) + + preproc_udf = UDFFactory.get_udf_cls("preprocess") + nn_udf = UDFFactory.get_udf_cls("inference") + postproc_udf = UDFFactory.get_udf_cls("postprocess") + + x_scaled = preproc_udf.compute(model=artifacts["preproc_clf"], input_=x_test) + + ds = StreamingDataset(x_scaled, seq_len=self.conf.window_size) + anomaly_scores = np.zeros( + (len(ds), self.conf.window_size, len(self.conf.metrics)), dtype=np.float32 + ) + x_recon = np.zeros_like(anomaly_scores, dtype=np.float32) + postproc_func = PostprocessFactory().get_instance(self.conf.numalogic_conf.postprocess) + + for idx, arr in enumerate(ds): + x_recon[idx] = nn_udf.compute(model=artifacts["model"], input_=arr) + anomaly_scores[idx] = postproc_udf.compute( + model=artifacts["threshold_clf"], + input_=x_recon[idx], + postproc_clf=postproc_func, + ) + x_recon = inverse_window(torch.from_numpy(x_recon)).numpy() + final_scores = np.mean(anomaly_scores, axis=1) + output_df = self._construct_output_df( + timestamps=df_test.index, + test_data=x_test, + preproc_out=x_scaled, + nn_out=x_recon, + postproc_out=final_scores, + ) + output_df.to_csv(self._outpath, index=True, index_label="timestamp") + LOGGER.info("Results saved in: %s", self._outpath) + return output_df + + def save_plots( + self, output_df: Optional[pd.DataFrame] = None, plotname: str = "plot.png" + ) -> None: + """ + Save plots for the given data. + + Args: + output_df: Dataframe with timestamp, and anomaly scores + plotname: Name of the plot file + """ + if output_df is None: + output_df = pd.read_csv(self._outpath, index_col="timestamp", parse_dates=True) + + fig, axs = plt.subplots(4, 1, sharex="col", figsize=(15, 8)) + + axs[0].plot(output_df["metric"], color="b") + axs[0].set_ylabel("Original metric") + axs[0].grid(True) + axs[0].set_title( + f"TEST SET RESULTS\nMetric: {self.metric}\n" + f"namespace: {self.namespace}\napp: {self.appname}" + ) + + axs[1].plot(output_df["preprocessed"], color="g") + axs[1].grid(True) + axs[1].set_ylabel("Preprocessed metric") + + axs[2].plot(output_df["model_out"], color="black") + axs[2].grid(True) + axs[2].set_ylabel("NN model output") + + axs[3].plot(output_df["scores"], color="r") + axs[3].grid(True) + axs[3].set_ylabel("Anomaly Score") + axs[3].set_xlabel("Time") + axs[3].set_ylim(0, 10) + + fig.tight_layout() + _fname = os.path.join(self.out_dir, plotname) + fig.savefig(_fname) + LOGGER.info("Plot file: %s saved in %s", plotname, self.out_dir) + + def _split_data(self, df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]: + test_size = int(df.shape[0] * self.test_ratio) + return df.iloc[:-test_size], df.iloc[-test_size:] + + def _read_or_fetch_data(self) -> pd.DataFrame: + try: + df = pd.read_csv(self._datapath, dtype=np.float32) + except FileNotFoundError: + LOGGER.info("No saved data found! Fetching data...") + df = self.read_data() + else: + LOGGER.info("Saved data found! 
Reading from %s", self._datapath) + df.set_index(["timestamp"], inplace=True) + df.index = pd.to_datetime(df.index) + return df + + def _load_or_train_model(self, df: pd.DataFrame, model_path: str) -> dict[str, artifact_t]: + _modelpath = model_path or self._modelpath + try: + with open(_modelpath, "rb") as f: + artifacts = torch.load(f) + except FileNotFoundError: + LOGGER.info("No saved models found! Training models...") + artifacts = self.train_models(df) + else: + LOGGER.info("Loaded models from %s", _modelpath) + return artifacts + + def _construct_output_df( + self, + timestamps: pd.Index, + test_data: NDArray[float], + preproc_out: NDArray[float], + nn_out: NDArray[float], + postproc_out: NDArray[float], + ) -> pd.DataFrame: + scores = np.vstack( + [ + np.full((self._seq_len - 1, self._n_features), fill_value=np.nan), + postproc_out, + ] + ) + + return pd.DataFrame( + { + "metric": test_data.squeeze(), + "preprocessed": preproc_out.squeeze(), + "model_out": nn_out.squeeze(), + "scores": scores.squeeze(), + }, + index=timestamps, + ) diff --git a/numalogic/config/_config.py b/numalogic/config/_config.py index cd33d484..862c95f5 100644 --- a/numalogic/config/_config.py +++ b/numalogic/config/_config.py @@ -75,7 +75,7 @@ class TrainerConf: retrain_freq_hr: int = 24 model_expiry_sec: int = 86400 # 24 hrs # TODO: revisit this dedup_expiry_sec: int = 1800 # 30 days # TODO: revisit this - batch_size: int = 32 + batch_size: int = 64 pltrainer_conf: LightningTrainerConf = field(default_factory=LightningTrainerConf) diff --git a/numalogic/config/factory.py b/numalogic/config/factory.py index 8b039ec8..a82ca89b 100644 --- a/numalogic/config/factory.py +++ b/numalogic/config/factory.py @@ -8,12 +8,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Union, ClassVar +from typing import Union, ClassVar, TypeVar + +from sklearn.pipeline import make_pipeline from numalogic.config._config import ModelInfo, RegistryInfo from numalogic.tools.exceptions import UnknownConfigArgsError +conf_t = TypeVar("conf_t", bound=Union[ModelInfo, RegistryInfo], covariant=True) + + class _ObjectFactory: _CLS_MAP: ClassVar[dict] = {} @@ -26,13 +31,12 @@ def get_instance(self, object_info: Union[ModelInfo, RegistryInfo]): ) from err return _cls(**object_info.conf) - def get_cls(self, object_info: Union[ModelInfo, RegistryInfo]): + @classmethod + def get_cls(cls, name: str): try: - return self._CLS_MAP[object_info.name] - except KeyError as err: - raise UnknownConfigArgsError( - f"Invalid model info instance provided: {object_info}" - ) from err + return cls._CLS_MAP[name] + except KeyError: + raise UnknownConfigArgsError(f"Invalid name provided for factory: {name}") from None class PreprocessFactory(_ObjectFactory): @@ -51,6 +55,17 @@ class PreprocessFactory(_ObjectFactory): "TanhScaler": TanhScaler, } + def get_pipeline_instance(self, objs_info: list[ModelInfo]): + preproc_clfs = [] + for obj_info in objs_info: + _clf = self.get_instance(obj_info) + preproc_clfs.append(_clf) + if not preproc_clfs: + return None + if len(preproc_clfs) == 1: + return preproc_clfs[0] + return make_pipeline(*preproc_clfs) + class PostprocessFactory(_ObjectFactory): """Factory class to create postprocess instances.""" @@ -107,7 +122,7 @@ class ModelFactory(_ObjectFactory): class RegistryFactory(_ObjectFactory): """Factory class to create registry instances.""" - _CLS_SET: ClassVar[frozenset] = {"RedisRegistry", "MLflowRegistry"} + _CLS_SET: ClassVar[frozenset] = frozenset({"RedisRegistry", "MLflowRegistry"}) def get_instance(self, object_info: Union[ModelInfo, RegistryInfo]): import numalogic.registry as reg @@ -124,16 +139,38 @@ def get_instance(self, object_info: Union[ModelInfo, RegistryInfo]): ) from err return _cls(**object_info.conf) - def get_cls(self, object_info: Union[ModelInfo, RegistryInfo]): + @classmethod + def get_cls(cls, name: str): import numalogic.registry as reg try: - return getattr(reg, object_info.name) + return getattr(reg, name) except AttributeError as err: - if object_info.name in self._CLS_SET: + if name in cls._CLS_SET: raise ImportError( "Please install the required dependencies for the registry you want to use." ) from err raise UnknownConfigArgsError( - f"Invalid model info instance provided: {object_info}" - ) from err + f"Invalid name provided for RegistryFactory: {name}" + ) from None + + +class ConnectorFactory(_ObjectFactory): + """Factory class for data connectors.""" + + _CLS_SET: ClassVar[frozenset] = frozenset({"PrometheusFetcher", "DruidFetcher"}) + + @classmethod + def get_cls(cls, name: str): + import numalogic.connectors as conn + + try: + return getattr(conn, name) + except AttributeError as err: + if name in cls._CLS_SET: + raise ImportError( + "Please install the required dependencies for the connector you want to use." 
+ ) from err + raise UnknownConfigArgsError( + f"Invalid name provided for ConnectorFactory: {name}" + ) from None diff --git a/numalogic/connectors/__init__.py b/numalogic/connectors/__init__.py index 5b44f069..26f93267 100644 --- a/numalogic/connectors/__init__.py +++ b/numalogic/connectors/__init__.py @@ -1,3 +1,5 @@ +from importlib.util import find_spec + from numalogic.connectors._config import ( RedisConf, PrometheusConf, @@ -6,6 +8,7 @@ DruidFetcherConf, ConnectorType, ) +from numalogic.connectors.prometheus import PrometheusFetcher __all__ = [ "RedisConf", @@ -14,4 +17,14 @@ "DruidConf", "DruidFetcherConf", "ConnectorType", + "PrometheusFetcher", ] + +try: + find_spec("pydruid") +except (ImportError, ModuleNotFoundError): + pass +else: + from numalogic.connectors.druid import DruidFetcher # noqa: F401 + + __all__.append("DruidFetcher") diff --git a/numalogic/connectors/_base.py b/numalogic/connectors/_base.py new file mode 100644 index 00000000..1c549556 --- /dev/null +++ b/numalogic/connectors/_base.py @@ -0,0 +1,33 @@ +from abc import ABCMeta, abstractmethod + +import pandas as pd + + +class DataFetcher(metaclass=ABCMeta): + __slots__ = ("url",) + + def __init__(self, url: str): + self.url = url + + @abstractmethod + def fetch(self, *args, **kwargs) -> pd.DataFrame: + pass + + @abstractmethod + def raw_fetch(self, *args, **kwargs) -> pd.DataFrame: + pass + + +class AsyncDataFetcher(metaclass=ABCMeta): + __slots__ = ("url",) + + def __init__(self, url: str): + self.url = url + + @abstractmethod + async def fetch_data(self, *args, **kwargs) -> pd.DataFrame: + pass + + @abstractmethod + async def raw_fetch(self, *args, **kwargs) -> pd.DataFrame: + pass diff --git a/numalogic/connectors/druid.py b/numalogic/connectors/druid.py index 7bf3ea1d..5a528138 100644 --- a/numalogic/connectors/druid.py +++ b/numalogic/connectors/druid.py @@ -5,14 +5,17 @@ import pandas as pd import pytz from pydruid.client import PyDruid +from pydruid.utils.dimensions import DimensionSpec from pydruid.utils.filters import Filter + +from numalogic.connectors._base import DataFetcher from numalogic.connectors._config import Pivot from typing import Optional _LOGGER = logging.getLogger(__name__) -class DruidFetcher: +class DruidFetcher(DataFetcher): """ Class for fetching data as a dataframe from Druid. 
@@ -22,9 +25,10 @@ class DruidFetcher: """ def __init__(self, url: str, endpoint: str): + super().__init__(url) self.client = PyDruid(url, endpoint) - def fetch_data( + def fetch( self, datasource: str, filter_keys: list[str], @@ -50,20 +54,22 @@ def fetch_data( start_dt = end_dt - timedelta(hours=hours) intervals = [f"{start_dt.isoformat()}/{end_dt.isoformat()}"] + dimension_specs = map(lambda d: DimensionSpec(dimension=d, output_name=d), dimensions) + params = { "datasource": datasource, "granularity": granularity, "intervals": intervals, "aggregations": aggregations, "filter": _filter, - "dimensions": dimensions, + "dimensions": dimension_specs, } _LOGGER.debug( "Fetching data with params: %s", params, ) - + self.client.sub_query(**params) response = self.client.groupby(**params) df = response.export_pandas() @@ -85,3 +91,6 @@ def fetch_data( df.columns = df.columns.map("{0[1]}".format) df.reset_index(inplace=True) return df + + def raw_fetch(self, *args, **kwargs) -> pd.DataFrame: + raise NotImplementedError diff --git a/numalogic/connectors/prometheus.py b/numalogic/connectors/prometheus.py new file mode 100644 index 00000000..626c4723 --- /dev/null +++ b/numalogic/connectors/prometheus.py @@ -0,0 +1,260 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from functools import reduce +from operator import iconcat +from typing import Optional, Final + +import orjson +import pandas as pd +import requests +from orjson import JSONDecodeError + +from numalogic.connectors._base import DataFetcher +from numalogic.tools.exceptions import PrometheusFetcherError, PrometheusInvalidResponseError + +LOGGER = logging.getLogger(__name__) +MAX_DATA_POINTS: Final[int] = 11000 +_MAX_RECURSION_DEPTH: Final[int] = 10 +_API_ENDPOINT: Final[str] = "/api/v1/query_range" + + +class PrometheusFetcher(DataFetcher): + """ + Class for fetching data as a dataframe from Prometheus or Thanos. + + Args: + prometheus_server: Prometheus/Thanos URL + scrape_interval_secs: Prometheus scrape interval in seconds + """ + + __slots__ = ("_endpoint", "_step_secs") + + def __init__(self, prometheus_server: str, scrape_interval_secs: int = 30): + super().__init__(prometheus_server) + self._endpoint = f"{self.url}{_API_ENDPOINT}" + self._step_secs = scrape_interval_secs + + @staticmethod + def build_query(metric: str, filters: dict[str, str]) -> str: + """Builds a Prometheus query string from metric name and filters.""" + query = metric + if filters: + label_filters = ",".join(f"{k}='{v}'" for k, v in filters.items()) + return f"{metric}{{{label_filters}}}" + return query + + def fetch( + self, + metric_name: str, + start: datetime, + end: Optional[datetime] = None, + filters: Optional[dict[str, str]] = None, + return_labels: Optional[list[str]] = None, + aggregate: bool = True, + ) -> pd.DataFrame: + """ + Fetches data from Prometheus/Thanos and returns a dataframe. 
+ + Args: + ------- + metric_name: Prometheus metric name + start: Start time + end: End time + filters: Prometheus label filters + return_labels: Prometheus label names as columns to return + aggregate: Whether to aggregate the data + + Returns + ------- + Dataframe with timestamp and metric values + + Raises + ------ + ValueError: If end time is before start time + PrometheusFetcherError: If there is an error while fetching data + PrometheusInvalidResponseError: If the response from Prometheus is invalid + RecursionError: If the recursive depth exceeds the max depth + """ + query = self.build_query(metric_name, filters) + + LOGGER.info("Constructed Prometheus Query: %s", query) + + end_ts, start_ts = self._init_startend_ts(end, start) + results = self.query_range(query, start_ts, end_ts) + + df = pd.json_normalize(results) + return_labels = [f"metric.{label}" for label in return_labels or []] + if df.empty: + LOGGER.warning("Query returned no results") + return df + + df = self._consolidate_df(df, metric_name, return_labels) + + if aggregate and return_labels: + df = self._agg_df(df, metric_name) + + df.sort_values(by=["timestamp"], inplace=True) + + return df + + def raw_fetch( + self, + query: str, + start: datetime, + end: Optional[datetime] = None, + return_labels: Optional[list[str]] = None, + aggregate: bool = True, + ): + """ + Fetches data from Prometheus/Thanos using the provided raw query and returns a dataframe. + + Args: + ------- + query: Raw prometheus query + start: Start time + end: End time + return_labels: Prometheus label names as columns to return + aggregate: Whether to aggregate the data + + Returns + ------- + Dataframe with timestamp and metric values + + Raises + ------ + ValueError: If end time is before start time + PrometheusFetcherError: If there is an error while fetching data + PrometheusInvalidResponseError: If the response from Prometheus is invalid + RecursionError: If the recursive depth exceeds the max depth + """ + end_ts, start_ts = self._init_startend_ts(end, start) + results = self.query_range(query, start_ts, end_ts) + + df = pd.json_normalize(results) + if df.empty: + LOGGER.warning("Query returned no results") + return df + + return_labels = [f"metric.{label}" for label in return_labels or []] + metric_name = self._extract_metric_name(df) or "metric" + + df = self._consolidate_df(df, metric_name, return_labels) + + if aggregate and return_labels: + df = self._agg_df(df, metric_name) + + df.sort_values(by=["timestamp"], inplace=True) + return df + + @staticmethod + def _agg_df(df, metric_name: str): + df = df.groupby(by=["timestamp"]).apply(lambda x: x[[metric_name]].mean()) + df.reset_index(inplace=True) + return df + + @staticmethod + def _consolidate_df(df: pd.DataFrame, metric_name: str, return_labels: list[str]): + df = df[["values", *return_labels]] + df = df.explode("values", ignore_index=True) + df[["timestamp", metric_name]] = df["values"].to_list() + df.drop(columns=["values"], inplace=True) + df = df.astype({metric_name: float}) + df["timestamp"] = pd.to_datetime(df["timestamp"], unit="s") + return df + + @staticmethod + def _init_startend_ts(end: datetime, start: datetime) -> tuple[int, int]: + if not end: + end = datetime.now() + end_ts = int(end.timestamp()) + start_ts = int(start.timestamp()) + if end_ts < start_ts: + raise ValueError(f"end_time: {end} must not be before start_time: {start}") + return end_ts, start_ts + + @staticmethod + def _extract_metric_name(df: pd.DataFrame) -> Optional[str]: + try: + metric_name = 
df["metric.__name__"].item() + except ValueError as err: + metric_names = df["metric.__name__"].unique() + if len(metric_names) > 1: + raise PrometheusInvalidResponseError( + f"More than 1 metric names " f"were extracted in the query: {metric_names}" + ) from err + return metric_names[0] + except KeyError: + LOGGER.warning("Could not infer metric name from results") + return None + return metric_name + + def _api_query_range(self, query: str, start_ts: int, end_ts: int) -> list[dict]: + """Queries Prometheus API for data.""" "" + try: + response = requests.get( + self._endpoint, + params={ + "query": query, + "start": start_ts, + "end": end_ts, + "step": f"{self._step_secs}s", + }, + ) + except Exception as err: + raise PrometheusFetcherError("Error while fetching data from Prometheus") from err + + try: + results = orjson.loads(response.text)["data"]["result"] + except (KeyError, JSONDecodeError) as err: + LOGGER.exception("Invalid response from Prometheus: %s", response.text) + raise PrometheusInvalidResponseError("Invalid response from Prometheus") from err + return results + + def query_range(self, query: str, start_ts: int, end_ts: int, _depth=0) -> list[dict]: + """ + Queries Prometheus API recursively for data. + + Args: + ------- + query: Prometheus query string + start_ts: Start timestamp + end_ts: End timestamp + _depth: Current recursion depth + + Returns + ------- + List of Prometheus results + + Raises + ------ + RecursionError: If the recursive depth exceeds the max depth + """ + if _depth > _MAX_RECURSION_DEPTH: + raise RecursionError(f"Max recursive depth of {_depth} reached") + datapoints = (end_ts - start_ts) // self._step_secs + + if datapoints > MAX_DATA_POINTS: + max_end_ts = start_ts + (MAX_DATA_POINTS * self._step_secs) + + with ThreadPoolExecutor(max_workers=2) as _executor: + results = _executor.map( + self.query_range, + (query, query), + (start_ts, max_end_ts + self._step_secs), + (max_end_ts, end_ts), + (_depth + 1, _depth + 1), + ) + return reduce(iconcat, results, []) + return self._api_query_range(query, start_ts, end_ts) diff --git a/numalogic/tools/exceptions.py b/numalogic/tools/exceptions.py index b8f2b05b..0b22f563 100644 --- a/numalogic/tools/exceptions.py +++ b/numalogic/tools/exceptions.py @@ -80,3 +80,21 @@ class EnvVarNotFoundError(LookupError): """Raised when an environment variable is not found.""" pass + + +class PrometheusFetcherError(Exception): + """Base class for all exceptions raised by the PrometheusFetcher class.""" + + pass + + +class PrometheusInvalidResponseError(PrometheusFetcherError): + """Raised when the Prometheus response is not a success.""" + + pass + + +class DataFormatError(Exception): + """Raised when the data format is not valid.""" + + pass diff --git a/numalogic/tools/types.py b/numalogic/tools/types.py index 22214b3c..3e55587d 100644 --- a/numalogic/tools/types.py +++ b/numalogic/tools/types.py @@ -19,7 +19,7 @@ try: from redis.client import AbstractRedis except ImportError: - pass + redis_client_t = TypeVar("redis_client_t") else: redis_client_t = TypeVar("redis_client_t", bound=AbstractRedis, covariant=True) @@ -38,6 +38,7 @@ KEYS = TypeVar("KEYS", bound=Sequence[str], covariant=True) + class KeyedArtifact(NamedTuple): r"""namedtuple for artifacts.""" diff --git a/numalogic/udfs/__main__.py b/numalogic/udfs/__main__.py index 9e030ea1..6751b7e1 100644 --- a/numalogic/udfs/__main__.py +++ b/numalogic/udfs/__main__.py @@ -13,7 +13,9 @@ "CONF_PATH", default=os.path.join(BASE_CONF_DIR, "default-configs", "config.yaml") ) -if 
__name__ == "__main__": + +def start_server(): + """Starts the pynumaflow server.""" set_logger() step = sys.argv[1] @@ -32,3 +34,7 @@ server = ServerFactory.get_server_instance(server_type, map_handler=udf) server.start() + + +if __name__ == "__main__": + start_server() diff --git a/numalogic/udfs/_base.py b/numalogic/udfs/_base.py index 2fe69cae..fd121ab6 100644 --- a/numalogic/udfs/_base.py +++ b/numalogic/udfs/_base.py @@ -65,8 +65,9 @@ async def aexec(self, keys: list[str], datum: Datum) -> Messages: """ raise NotImplementedError("aexec method not implemented") + @classmethod @abstractmethod - def compute(self, model: artifact_t, input_: npt.NDArray[float], **kwargs): + def compute(cls, model: artifact_t, input_: npt.NDArray[float], **kwargs): """ Abstract method to be implemented by subclasses. diff --git a/numalogic/udfs/factory.py b/numalogic/udfs/factory.py index 4de03612..056e4689 100644 --- a/numalogic/udfs/factory.py +++ b/numalogic/udfs/factory.py @@ -1,6 +1,7 @@ import logging from typing import ClassVar +from numalogic.udfs import NumalogicUDF from numalogic.udfs.inference import InferenceUDF from numalogic.udfs.trainer import TrainerUDF from numalogic.udfs.preprocess import PreprocessUDF @@ -21,7 +22,7 @@ class UDFFactory: } @classmethod - def get_udf_cls(cls, udf_name: str): + def get_udf_cls(cls, udf_name: str) -> type[NumalogicUDF]: try: return cls._UDF_MAP[udf_name] except KeyError as err: @@ -30,7 +31,7 @@ def get_udf_cls(cls, udf_name: str): raise ValueError(_msg) from err @classmethod - def get_udf_instance(cls, udf_name: str, **kwargs): + def get_udf_instance(cls, udf_name: str, **kwargs) -> NumalogicUDF: udf_cls = cls.get_udf_cls(udf_name) return udf_cls(**kwargs) diff --git a/numalogic/udfs/inference.py b/numalogic/udfs/inference.py index 46daadee..285ca539 100644 --- a/numalogic/udfs/inference.py +++ b/numalogic/udfs/inference.py @@ -10,7 +10,8 @@ from orjson import orjson from pynumaflow.function import Messages, Datum, Message -from numalogic.registry import RedisRegistry, LocalLRUCache, ArtifactData +from numalogic.config import RegistryFactory +from numalogic.registry import LocalLRUCache, ArtifactData from numalogic.tools.exceptions import ConfigNotFoundError from numalogic.tools.types import artifact_t, redis_client_t from numalogic.udfs._base import NumalogicUDF @@ -37,7 +38,8 @@ class InferenceUDF(NumalogicUDF): def __init__(self, r_client: redis_client_t, pl_conf: Optional[PipelineConf] = None): super().__init__(is_async=False) - self.model_registry = RedisRegistry( + model_registry_cls = RegistryFactory.get_cls("RedisRegistry") + self.model_registry = model_registry_cls( client=r_client, cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL, cachesize=LOCAL_CACHE_SIZE), ) @@ -60,7 +62,8 @@ def get_conf(self, config_id: str) -> StreamConf: except KeyError as err: raise ConfigNotFoundError(f"Config with ID {config_id} not found!") from err - def compute(self, model: artifact_t, input_: npt.NDArray[float], **_) -> npt.NDArray[float]: + @classmethod + def compute(cls, model: artifact_t, input_: npt.NDArray[float], **_) -> npt.NDArray[float]: """ Perform inference on the input data. 
@@ -81,7 +84,7 @@ def compute(self, model: artifact_t, input_: npt.NDArray[float], **_) -> npt.NDA try: with torch.no_grad(): _, out = model.forward(x) - recon_err = model.criterion(out, x, reduction="none") + recon_err = model.criterion(out, x, reduction="none") except Exception as err: raise RuntimeError("Model forward pass failed!") from err return np.ascontiguousarray(recon_err).squeeze(0) diff --git a/numalogic/udfs/postprocess.py b/numalogic/udfs/postprocess.py index 338e6c46..f365039b 100644 --- a/numalogic/udfs/postprocess.py +++ b/numalogic/udfs/postprocess.py @@ -9,8 +9,8 @@ from orjson import orjson from pynumaflow.function import Messages, Datum, Message -from numalogic.config import PostprocessFactory -from numalogic.registry import LocalLRUCache, RedisRegistry +from numalogic.config import PostprocessFactory, RegistryFactory +from numalogic.registry import LocalLRUCache from numalogic.tools.exceptions import ConfigNotFoundError from numalogic.tools.types import redis_client_t, artifact_t from numalogic.udfs import NumalogicUDF @@ -41,9 +41,9 @@ def __init__( pl_conf: Optional[PipelineConf] = None, ): super().__init__() - self.model_registry = RedisRegistry( - client=r_client, - cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL, cachesize=LOCAL_CACHE_SIZE), + model_registry_cls = RegistryFactory.get_cls("RedisRegistry") + self.model_registry = model_registry_cls( + client=r_client, cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL) ) self.pl_conf = pl_conf or PipelineConf() self.postproc_factory = PostprocessFactory() @@ -161,8 +161,9 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: ) return messages + @classmethod def compute( - self, model: artifact_t, input_: NDArray[float], postproc_clf=None, **_ + cls, model: artifact_t, input_: NDArray[float], postproc_clf=None, **_ ) -> NDArray[float]: """ Compute the postprocess function. diff --git a/numalogic/udfs/preprocess.py b/numalogic/udfs/preprocess.py index b84d91f8..1b0e3dd5 100644 --- a/numalogic/udfs/preprocess.py +++ b/numalogic/udfs/preprocess.py @@ -9,8 +9,8 @@ from pynumaflow.function import Datum, Messages, Message from sklearn.pipeline import make_pipeline -from numalogic.config import PreprocessFactory -from numalogic.registry import LocalLRUCache, RedisRegistry +from numalogic.config import PreprocessFactory, RegistryFactory +from numalogic.registry import LocalLRUCache from numalogic.tools.exceptions import ConfigNotFoundError from numalogic.tools.types import redis_client_t, artifact_t from numalogic.udfs import NumalogicUDF @@ -37,9 +37,9 @@ class PreprocessUDF(NumalogicUDF): def __init__(self, r_client: redis_client_t, pl_conf: Optional[PipelineConf] = None): super().__init__() - self.model_registry = RedisRegistry( - client=r_client, - cache_registry=LocalLRUCache(cachesize=LOCAL_CACHE_SIZE, ttl=LOCAL_CACHE_TTL), + model_registry_cls = RegistryFactory.get_cls("RedisRegistry") + self.model_registry = model_registry_cls( + client=r_client, cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL) ) self.pl_conf = pl_conf or PipelineConf() self.preproc_factory = PreprocessFactory() @@ -171,8 +171,9 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: ) return Messages(Message(keys=keys, value=payload.to_json())) + @classmethod def compute( - self, model: artifact_t, input_: Optional[NDArray[float]] = None, **_ + cls, model: artifact_t, input_: Optional[NDArray[float]] = None, **_ ) -> NDArray[float]: """ Perform inference on the input data. 
diff --git a/numalogic/udfs/trainer.py b/numalogic/udfs/trainer.py index f35b6ca9..2b34d099 100644 --- a/numalogic/udfs/trainer.py +++ b/numalogic/udfs/trainer.py @@ -12,18 +12,16 @@ from torch.utils.data import DataLoader from numalogic.base import StatelessTransformer -from numalogic.config import PreprocessFactory, ModelFactory, ThresholdFactory +from numalogic.config import PreprocessFactory, ModelFactory, ThresholdFactory, RegistryFactory from numalogic.config._config import TrainerConf -from numalogic.connectors.druid import DruidFetcher +from numalogic.config.factory import ConnectorFactory from numalogic.models.autoencoder import AutoencoderTrainer -from numalogic.registry import RedisRegistry from numalogic.tools.data import StreamingDataset from numalogic.tools.exceptions import ConfigNotFoundError, RedisRegistryError from numalogic.tools.types import redis_client_t, artifact_t, KEYS, KeyedArtifact from numalogic.udfs import NumalogicUDF from numalogic.udfs._config import StreamConf, PipelineConf from numalogic.udfs.entities import TrainerPayload, StreamPayload - _LOGGER = logging.getLogger(__name__) @@ -43,12 +41,14 @@ def __init__( ): super().__init__(is_async=False) self.r_client = r_client - self.model_registry = RedisRegistry(client=r_client) + model_registry_cls = RegistryFactory.get_cls("RedisRegistry") + self.model_registry = model_registry_cls(client=r_client) self.pl_conf = pl_conf or PipelineConf() self.druid_conf = self.pl_conf.druid_conf + data_fetcher_cls = ConnectorFactory.get_cls("DruidFetcher") try: - self.data_fetcher = DruidFetcher( + self.data_fetcher = data_fetcher_cls( url=self.druid_conf.url, endpoint=self.druid_conf.endpoint ) except AttributeError: @@ -89,8 +89,9 @@ def get_conf(self, config_id: str) -> StreamConf: except KeyError as err: raise ConfigNotFoundError(f"Config with ID {config_id} not found!") from err + @classmethod def compute( - self, + cls, model: artifact_t, input_: npt.NDArray[float], preproc_clf: Optional[artifact_t] = None, @@ -217,12 +218,10 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: ) return Messages(Message.to_drop()) - @staticmethod - def _construct_preproc_clf(_conf: StreamConf) -> Optional[artifact_t]: - preproc_factory = PreprocessFactory() + def _construct_preproc_clf(self, _conf: StreamConf) -> Optional[artifact_t]: preproc_clfs = [] for _cfg in _conf.numalogic_conf.preprocess: - _clf = preproc_factory.get_instance(_cfg) + _clf = self._preproc_factory.get_instance(_cfg) preproc_clfs.append(_clf) if not preproc_clfs: return None @@ -234,7 +233,7 @@ def _construct_preproc_clf(_conf: StreamConf) -> Optional[artifact_t]: def artifacts_to_save( skeys: KEYS, dict_artifacts: dict[str, KeyedArtifact], - model_registry: RedisRegistry, + model_registry, payload: StreamPayload, ) -> None: """ @@ -308,7 +307,7 @@ def fetch_data(self, payload: TrainerPayload) -> pd.DataFrame: _conf = self.get_conf(payload.config_id) try: - _df = self.data_fetcher.fetch_data( + _df = self.data_fetcher.fetch( datasource=self.druid_conf.fetcher.datasource, filter_keys=_conf.composite_keys, filter_values=payload.composite_keys, diff --git a/poetry.lock b/poetry.lock index c4f89292..45e0604c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -112,7 +112,7 @@ speedups = ["Brotli", "aiodns", "cchardet"] name = "aiorun" version = "2022.11.1" description = "Boilerplate for asyncio applications" -optional = true +optional = false python-versions = ">=3.5" files = [ {file = "aiorun-2022.11.1-py3-none-any.whl", hash = 
"sha256:8fbfc2aab258021deef2b1f38284c652af9fd3710e94c7b0e736a55d161fa0cb"}, @@ -415,17 +415,17 @@ css = ["tinycss2 (>=1.1.0,<1.2)"] [[package]] name = "boto3" -version = "1.28.47" +version = "1.28.48" description = "The AWS SDK for Python" optional = false python-versions = ">= 3.7" files = [ - {file = "boto3-1.28.47-py3-none-any.whl", hash = "sha256:27560da44099e7e2ee961d3971d8ea659de2e0dc24e78043d1c3027d89b2d8a2"}, - {file = "boto3-1.28.47.tar.gz", hash = "sha256:be69cd28e3732b63ad61f6d2429b1eac92428588911a5a7367faa4e129a4738d"}, + {file = "boto3-1.28.48-py3-none-any.whl", hash = "sha256:ec7895504e3b2dd35fbdb7397bc3c48daaba8e6f37bc436aa928ff4e745f0f1c"}, + {file = "boto3-1.28.48.tar.gz", hash = "sha256:fed2d673fce33384697baa0028edfd18b06aa17af5c3ef82da75e9254a8ffb07"}, ] [package.dependencies] -botocore = ">=1.31.47,<1.32.0" +botocore = ">=1.31.48,<1.32.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.6.0,<0.7.0" @@ -434,13 +434,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.31.47" +version = "1.31.48" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">= 3.7" files = [ - {file = "botocore-1.31.47-py3-none-any.whl", hash = "sha256:6a60f9601270458102529b17fdcba5551b918f9eedc32bbc2f467e63edfb2662"}, - {file = "botocore-1.31.47.tar.gz", hash = "sha256:a0ba5629eb17a37bf449bccda9df6ae652d5755f73145519d5eb244f6963b31b"}, + {file = "botocore-1.31.48-py3-none-any.whl", hash = "sha256:9618c06f7e08ed590dae6613b8b2511055f7d6c07517382143ef8563169d4ef1"}, + {file = "botocore-1.31.48.tar.gz", hash = "sha256:6ed16f66aa6ed6070fed26d69764cb14c7759e4cc0b1c191283cc48b05d65de9"}, ] [package.dependencies] @@ -1332,7 +1332,7 @@ test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mypy", "pre-commit" name = "google-api-core" version = "2.11.1" description = "Google API client core library" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "google-api-core-2.11.1.tar.gz", hash = "sha256:25d29e05a0058ed5f19c61c0a78b1b53adea4d9364b464d014fbda941f6d1c9a"}, @@ -1354,7 +1354,7 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] name = "google-auth" version = "2.23.0" description = "Google Authentication Library" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "google-auth-2.23.0.tar.gz", hash = "sha256:753a26312e6f1eaeec20bc6f2644a10926697da93446e1f8e24d6d32d45a922a"}, @@ -1378,7 +1378,7 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"] name = "google-cloud" version = "0.34.0" description = "API Client library for Google Cloud" -optional = true +optional = false python-versions = "*" files = [ {file = "google-cloud-0.34.0.tar.gz", hash = "sha256:01430187cf56df10a9ba775dd547393185d4b40741db0ea5889301f8e7a9d5d3"}, @@ -1389,7 +1389,7 @@ files = [ name = "googleapis-common-protos" version = "1.60.0" description = "Common protobufs used in Google APIs" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "googleapis-common-protos-1.60.0.tar.gz", hash = "sha256:e73ebb404098db405ba95d1e1ae0aa91c3e15a71da031a2eeb6b2e23e7bc3708"}, @@ -1406,7 +1406,7 @@ grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] name = "grpcio" version = "1.58.0" description = "HTTP/2-based RPC framework" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "grpcio-1.58.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:3e6bebf1dfdbeb22afd95650e4f019219fef3ab86d3fca8ebade52e4bc39389a"}, @@ -1463,7 +1463,7 @@ protobuf = ["grpcio-tools (>=1.58.0)"] name = 
"grpcio-tools" version = "1.58.0" description = "Protobuf code generator for gRPC" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "grpcio-tools-1.58.0.tar.gz", hash = "sha256:6f4d80ceb591e31ca4dceec747dbe56132e1392a0a9bb1c8fe001d1b5cac898a"}, @@ -1925,7 +1925,6 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, - {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -2426,16 +2425,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -2564,13 +2553,13 @@ files = [ [[package]] name = "mlflow-skinny" -version = "2.7.0" +version = "2.6.0" description = "MLflow: A Platform for ML Development and Productionization" optional = true python-versions = ">=3.8" files = [ - {file = "mlflow-skinny-2.7.0.tar.gz", hash = "sha256:c4ce24817227219ac03f674bd3d0f371b274807e7e3845b3d46ba79ca86dad08"}, - {file = 
"mlflow_skinny-2.7.0-py3-none-any.whl", hash = "sha256:0dd3855b07bfce5f1bc403b6505e84f0e165780aba5d7edcca2024cff7ca7309"}, + {file = "mlflow-skinny-2.6.0.tar.gz", hash = "sha256:c1a71bc4abb83169d4cff012c1c050c203775d4a7462d059c4523a481f0a9688"}, + {file = "mlflow_skinny-2.6.0-py3-none-any.whl", hash = "sha256:b63c6555645af0196b3c92da849968189dfd72d2c1397f7b149593db6b404d66"}, ] [package.dependencies] @@ -2591,7 +2580,7 @@ sqlparse = ">=0.4.0,<1" aliyun-oss = ["aliyunstoreplugin"] databricks = ["azure-storage-file-datalake (>12)", "boto3 (>1)", "google-cloud-storage (>=1.30.0)"] extras = ["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,!=1.3.1)", "mlserver-mlflow (>=1.2.0,!=1.3.1)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"] -gateway = ["aiohttp (<4)", "fastapi (<1)", "psutil (<6)", "pydantic (>=1.0,<3)", "uvicorn[standard] (<1)", "watchfiles (<1)"] +gateway = ["aiohttp (<4)", "fastapi (<1)", "psutil (<6)", "pydantic (>=1.0,<2)", "uvicorn[standard] (<1)", "watchfiles (<1)"] sqlserver = ["mlflow-dbstore"] [[package]] @@ -3313,7 +3302,6 @@ files = [ {file = "Pillow-10.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3b08d4cc24f471b2c8ca24ec060abf4bebc6b144cb89cba638c720546b1cf538"}, {file = "Pillow-10.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d737a602fbd82afd892ca746392401b634e278cb65d55c4b7a8f48e9ef8d008d"}, {file = "Pillow-10.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:3a82c40d706d9aa9734289740ce26460a11aeec2d9c79b7af87bb35f0073c12f"}, - {file = "Pillow-10.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:bc2ec7c7b5d66b8ec9ce9f720dbb5fa4bace0f545acd34870eff4a369b44bf37"}, {file = "Pillow-10.0.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:d80cf684b541685fccdd84c485b31ce73fc5c9b5d7523bf1394ce134a60c6883"}, {file = "Pillow-10.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:76de421f9c326da8f43d690110f0e79fe3ad1e54be811545d7d91898b4c8493e"}, {file = "Pillow-10.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81ff539a12457809666fef6624684c008e00ff6bf455b4b89fd00a140eecd640"}, @@ -3323,7 +3311,6 @@ files = [ {file = "Pillow-10.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d50b6aec14bc737742ca96e85d6d0a5f9bfbded018264b3b70ff9d8c33485551"}, {file = "Pillow-10.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:00e65f5e822decd501e374b0650146063fbb30a7264b4d2744bdd7b913e0cab5"}, {file = "Pillow-10.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:f31f9fdbfecb042d046f9d91270a0ba28368a723302786c0009ee9b9f1f60199"}, - {file = "Pillow-10.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:1ce91b6ec08d866b14413d3f0bbdea7e24dfdc8e59f562bb77bc3fe60b6144ca"}, {file = "Pillow-10.0.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:349930d6e9c685c089284b013478d6f76e3a534e36ddfa912cde493f235372f3"}, {file = "Pillow-10.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3a684105f7c32488f7153905a4e3015a3b6c7182e106fe3c37fbb5ef3e6994c3"}, {file = "Pillow-10.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4f69b3700201b80bb82c3a97d5e9254084f6dd5fb5b16fc1a7b974260f89f43"}, @@ -3437,7 +3424,7 @@ wcwidth = "*" name = "protobuf" version = "4.24.3" description = "" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "protobuf-4.24.3-cp310-abi3-win32.whl", hash = "sha256:20651f11b6adc70c0f29efbe8f4a94a74caf61b6200472a9aea6e19898f9fcf4"}, @@ -3510,7 +3497,7 @@ tests = 
["pytest"] name = "pyasn1" version = "0.5.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" -optional = true +optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ {file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"}, @@ -3521,7 +3508,7 @@ files = [ name = "pyasn1-modules" version = "0.3.0" description = "A collection of ASN.1-based protocols modules" -optional = true +optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"}, @@ -3544,12 +3531,13 @@ files = [ [[package]] name = "pydruid" -version = "0.999.0dev" -description = "" +version = "0.6.5" +description = "A Python connector for Druid." optional = true python-versions = "*" -files = [] -develop = false +files = [ + {file = "pydruid-0.6.5.tar.gz", hash = "sha256:30b6efa722234183239296bf8e6d41aba7eea0eec3d68233ce6c413986864fea"}, +] [package.dependencies] requests = "*" @@ -3560,12 +3548,6 @@ cli = ["prompt_toolkit (>=2.0.0)", "pygments", "tabulate"] pandas = ["pandas"] sqlalchemy = ["sqlalchemy"] -[package.source] -type = "git" -url = "https://github.com/cosmic-chichu/pydruid.git" -reference = "master" -resolved_reference = "0ce16bcfd2dd666ff5294469cc47263acafca1d7" - [[package]] name = "pygments" version = "2.16.1" @@ -3612,7 +3594,7 @@ files = [ name = "pynumaflow" version = "0.4.2" description = "Provides the interfaces of writing Python User Defined Functions and Sinks for NumaFlow." -optional = true +optional = false python-versions = ">=3.9,<3.12" files = [ {file = "pynumaflow-0.4.2-py3-none-any.whl", hash = "sha256:354e44bff87578bb43df43e70d326d82be1f75280dfdaad6f8ae3d46608e886b"}, @@ -3708,13 +3690,13 @@ files = [ [[package]] name = "pytorch-lightning" -version = "2.0.8" +version = "2.0.9" description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate." 
optional = false python-versions = ">=3.8" files = [ - {file = "pytorch-lightning-2.0.8.tar.gz", hash = "sha256:fb7e8fbe473158b2c1666b6b31fb996c8aa2c3f5e8e2a54a3f50a7b5b2d00a20"}, - {file = "pytorch_lightning-2.0.8-py3-none-any.whl", hash = "sha256:718d11f22551d95ef38614b4727433553c95ea2b50cf843938fb13baf34325a6"}, + {file = "pytorch-lightning-2.0.9.tar.gz", hash = "sha256:a4cbfe5256192b315c76f864d3f9d60bd04aace361950e83326b29565b212387"}, + {file = "pytorch_lightning-2.0.9-py3-none-any.whl", hash = "sha256:719adb50dfe5fe0a51f04be040c36f29aefc9845b62e2c6bc7bb6db05348c31a"}, ] [package.dependencies] @@ -3729,13 +3711,13 @@ tqdm = ">=4.57.0" typing-extensions = ">=4.0.0" [package.extras] -all = ["deepspeed (>=0.8.2)", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.14.1)", "jsonargparse[signatures] (>=4.18.0,<4.23.0)", "lightning-utilities (>=0.7.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=12.3.0)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.12.0)"] +all = ["deepspeed (>=0.8.2)", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.18.0,<4.25.0)", "lightning-utilities (>=0.7.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=12.3.0)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.12.0)"] deepspeed = ["deepspeed (>=0.8.2)"] -dev = ["cloudpickle (>=1.3)", "coverage (==7.3.0)", "deepspeed (>=0.8.2)", "fastapi (<0.100.0)", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.14.1)", "jsonargparse[signatures] (>=4.18.0,<4.23.0)", "lightning-utilities (>=0.7.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "onnx (<1.15.0)", "onnxruntime (<1.16.0)", "pandas (>1.0)", "protobuf (<=3.20.1)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "rich (>=12.3.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.12.0)", "uvicorn (<0.23.3)"] -examples = ["gym[classic-control] (>=0.17.0)", "ipython[all] (<8.14.1)", "lightning-utilities (>=0.7.0)", "torchmetrics (>=0.10.0)", "torchvision (>=0.12.0)"] -extra = ["hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0,<4.23.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=12.3.0)", "tensorboardX (>=2.2)"] +dev = ["cloudpickle (>=1.3)", "coverage (==7.3.0)", "deepspeed (>=0.8.2)", "fastapi (<0.100.0)", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.18.0,<4.25.0)", "lightning-utilities (>=0.7.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "onnx (<1.15.0)", "onnxruntime (<1.16.0)", "pandas (>1.0)", "protobuf (<3.21.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "rich (>=12.3.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.12.0)", "uvicorn (<0.23.3)"] +examples = ["gym[classic-control] (>=0.17.0)", "ipython[all] (<8.15.0)", "lightning-utilities (>=0.7.0)", "torchmetrics (>=0.10.0)", "torchvision (>=0.12.0)"] +extra = ["hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0,<4.25.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=12.3.0)", "tensorboardX (>=2.2)"] strategies = ["deepspeed (>=0.8.2)"] -test = ["cloudpickle (>=1.3)", "coverage (==7.3.0)", "fastapi (<0.100.0)", "onnx (<1.15.0)", 
"onnxruntime (<1.16.0)", "pandas (>1.0)", "protobuf (<=3.20.1)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn (<0.23.3)"] +test = ["cloudpickle (>=1.3)", "coverage (==7.3.0)", "fastapi (<0.100.0)", "onnx (<1.15.0)", "onnxruntime (<1.16.0)", "pandas (>1.0)", "protobuf (<3.21.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn (<0.23.3)"] [[package]] name = "pytz" @@ -3797,7 +3779,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -3805,15 +3786,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = 
"PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -3830,7 +3804,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -3838,7 +3811,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -4202,7 +4174,7 @@ files = [ name = "rsa" version = "4.9" description = "Pure-Python RSA implementation" -optional = true +optional = false python-versions = ">=3.6,<4" files = [ {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, @@ -4690,18 +4662,39 @@ telegram = ["requests"] [[package]] name = "traitlets" -version = "5.9.0" +version = "5.10.0" description = "Traitlets Python configuration system" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "traitlets-5.9.0-py3-none-any.whl", hash = "sha256:9e6ec080259b9a5940c797d58b613b5e31441c2257b87c2e795c5228ae80d2d8"}, - {file = "traitlets-5.9.0.tar.gz", hash = "sha256:f6cde21a9c68cf756af02035f72d5a723bf607e862e7be33ece505abf4a3bad9"}, + {file = "traitlets-5.10.0-py3-none-any.whl", hash = "sha256:417745a96681fbb358e723d5346a547521f36e9bd0d50ba7ab368fff5d67aa54"}, + {file = "traitlets-5.10.0.tar.gz", hash = 
"sha256:f584ea209240466e66e91f3c81aa7d004ba4cf794990b0c775938a1544217cd1"}, ] [package.extras] docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] -test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] +test = ["argcomplete (>=3.0.3)", "mypy (>=1.5.1)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] + +[[package]] +name = "typer" +version = "0.9.0" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." +optional = false +python-versions = ">=3.6" +files = [ + {file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"}, + {file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"}, +] + +[package.dependencies] +click = ">=7.1.1,<9.0.0" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] +doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] [[package]] name = "types-pyyaml" @@ -4984,10 +4977,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p druid = ["pydruid"] dynamodb = ["boto3"] mlflow = ["mlflow-skinny"] -numaflow = ["pynumaflow"] redis = ["redis"] [metadata] lock-version = "2.0" python-versions = ">=3.9, <3.12" -content-hash = "88c108c2c19c66d5fe5042c2ba5c996ba9057e4be5cda949e4a176efd516e0a0" +content-hash = "292d29c73f389c027915552f126fb3dc3ebeaca536ca7244ffa83d94841483a6" diff --git a/pyproject.toml b/pyproject.toml index d0975a7d..c709b900 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "numalogic" -version = "0.6.0a3" +version = "0.6.0a1" description = "Collection of operational Machine Learning models and tools." 
authors = ["Numalogic Developers"] packages = [{ include = "numalogic" }] @@ -23,6 +23,9 @@ repository = "https://github.com/numaproj/numalogic" documentation = "https://numalogic.numaproj.io/" homepage = "https://numalogic.numaproj.io/" +[tool.poetry.scripts] +numalogic ="numalogic.backtest.__main__:app" + [tool.poetry.dependencies] python = ">=3.9, <3.12" numpy = "^1.23" @@ -31,18 +34,20 @@ scikit-learn = "^1.2" omegaconf = "^2.3.0" cachetools = "^5.3.0" orjson = "^3.9.5" +typer = {version = "^0.9", extras = ["rich"]} +matplotlib = "^3.4.2" +pynumaflow = "~0.4" + # extras -mlflow-skinny = { version = ">2.0, <3.0", optional = true } +mlflow-skinny = { version = "~2.6", optional = true } redis = {extras = ["hiredis"], version = "^5.0", optional = true} -pynumaflow = { version = "~0.4", optional = true } boto3 = { version = "^1.24.64", optional = true } -pydruid = {git = "https://github.com/cosmic-chichu/pydruid.git", rev = "master", optional =true} +pydruid = {version = "^0.6", optional = true} [tool.poetry.extras] mlflow = ["mlflow-skinny"] redis = ["redis"] -numaflow = ["pynumaflow"] dynamodb = ["boto3"] druid = ["pydruid"] @@ -56,7 +61,6 @@ pytorch-lightning = "^2.0" optional = true [tool.poetry.group.dev.dependencies] -matplotlib = "^3.4.2" black = "^23.0" pytest = "^7.1" pytest-cov = "^4.0" diff --git a/tests/config/test_config.py b/tests/config/test_config.py index a33b1e4b..25e115ab 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -90,7 +90,8 @@ def test_instance(self): def test_cls(self): factory = ModelFactory() - model_cls = factory.get_cls(ModelInfo(name="LSTMAE")) + model_cls = factory.get_cls("LSTMAE") + print(model_cls) self.assertEqual(model_cls.__class__, LSTMAE.__class__) def test_instance_err(self): @@ -101,7 +102,7 @@ def test_instance_err(self): def test_cls_err(self): factory = ModelFactory() with self.assertRaises(UnknownConfigArgsError): - factory.get_cls(ModelInfo(name="Random")) + factory.get_cls("Random") if __name__ == "__main__": diff --git a/tests/config/test_optdeps.py b/tests/config/test_optdeps.py index 6b50ebfa..4d8115c5 100644 --- a/tests/config/test_optdeps.py +++ b/tests/config/test_optdeps.py @@ -30,7 +30,7 @@ def test_not_installed_dep_01(self, _): server = fakeredis.FakeServer() redis_cli = fakeredis.FakeStrictRedis(server=server, decode_responses=False) with self.assertRaises(ImportError): - model_factory.get_cls(self.regconf)(redis_cli, **self.regconf.conf) + model_factory.get_cls("RedisRegistry")(redis_cli, **self.regconf.conf) @patch("numalogic.config.factory.getattr", side_effect=AttributeError) def test_not_installed_dep_02(self, _): @@ -48,7 +48,7 @@ def test_unknown_registry(self): model_factory = RegistryFactory() reg_conf = RegistryInfo(name="UnknownRegistry") with self.assertRaises(UnknownConfigArgsError): - model_factory.get_cls(reg_conf) + model_factory.get_cls("UnknownRegistry") with self.assertRaises(UnknownConfigArgsError): model_factory.get_instance(reg_conf) diff --git a/tests/connectors/test_druid.py b/tests/connectors/test_druid.py index f7831950..4f2060f3 100644 --- a/tests/connectors/test_druid.py +++ b/tests/connectors/test_druid.py @@ -44,7 +44,7 @@ def setUpClass(cls) -> None: @patch.object(PyDruid, "groupby", Mock(return_value=mock_group_by())) def test_fetch_data(self): - _out = self.druid.fetch_data( + _out = self.druid.fetch( filter_keys=["assetId"], filter_values=["5984175597303660107"], dimensions=["ciStatus"], diff --git a/tests/connectors/test_prometheus.py 
b/tests/connectors/test_prometheus.py new file mode 100644 index 00000000..9405d3db --- /dev/null +++ b/tests/connectors/test_prometheus.py @@ -0,0 +1,180 @@ +import logging +import unittest +from datetime import datetime, timedelta +from unittest.mock import patch, Mock, MagicMock + +from orjson import orjson +from requests import Response + +from numalogic.connectors import PrometheusFetcher +from numalogic.tools.exceptions import PrometheusFetcherError, PrometheusInvalidResponseError + +logging.basicConfig(level=logging.DEBUG) + + +def _mock_response(): + response = MagicMock() + response.status_code = 200 + response.text = orjson.dumps( + { + "status": "success", + "data": { + "resultType": "vector", + "result": [ + { + "metric": { + "__name__": "namespace_asset_pod_cpu_utilization", + "numalogic": "true", + "namespace": "sandbox-numalogic-demo", + }, + "values": [ + [1656334767.73, "14.744611739611193"], + [1656334797.73, "14.73040822323633"], + ], + } + ], + }, + } + ) + return response + + +def _mock_query_range(): + return [ + { + "metric": { + "__name__": "namespace_asset_pod_cpu_utilization", + "assetAlias": "sandbox.numalogic.demo", + "numalogic": "true", + "namespace": "sandbox-numalogic-demo", + }, + "values": [[1656334767.73, "14.744611739611193"], [1656334797.73, "14.73040822323633"]], + } + ] + + +def _mock_w_return_labels(): + return [ + { + "metric": { + "__name__": "namespace_app_rollouts_http_request_error_rate", + "assetAlias": "sandbox.numalogic.demo", + "numalogic": "true", + "namespace": "sandbox-numalogic-demo", + "rollouts_pod_template_hash": "7b4b4f9f9d", + }, + "values": [[1656334767.73, "10.0"], [1656334797.73, "12.0"]], + }, + { + "metric": { + "__name__": "namespace_app_rollouts_http_request_error_rate", + "assetAlias": "sandbox.numalogic.demo", + "numalogic": "true", + "namespace": "sandbox-numalogic-demo", + "rollouts_pod_template_hash": "5b4b4f9f9d", + }, + "values": [[1656334767.73, "11.0"], [1656334797.73, "13.0"]], + }, + ] + + +class TestPrometheusFetcher(unittest.TestCase): + def setUp(self) -> None: + self.fetcher = PrometheusFetcher(prometheus_server="http://localhost:9090") + + @patch("requests.get", Mock(return_value=_mock_response())) + def test_fetch(self): + df = self.fetcher.fetch( + metric_name="namespace_asset_pod_cpu_utilization", + start=datetime.now() - timedelta(hours=1), + filters={"namespace": "sandbox-numalogic-demo"}, + ) + self.assertEqual(df.shape, (2, 2)) + self.assertListEqual( + df.columns.to_list(), ["timestamp", "namespace_asset_pod_cpu_utilization"] + ) + + @patch.object(PrometheusFetcher, "_api_query_range", Mock(return_value=_mock_w_return_labels())) + def test_fetch_return_labels(self): + metric = "namespace_app_rollouts_http_request_error_rate" + df = self.fetcher.fetch( + metric_name=metric, + start=datetime.now() - timedelta(hours=1), + filters={"namespace": "sandbox-numalogic-demo"}, + return_labels=["rollouts_pod_template_hash"], + aggregate=True, + ) + self.assertEqual(df.shape, (2, 2)) + self.assertListEqual(df.columns.to_list(), ["timestamp", metric]) + self.assertListEqual([10.5, 12.5], df[metric].to_list()) + + @patch.object(PrometheusFetcher, "_api_query_range", Mock(return_value=[])) + def test_fetch_no_data(self): + df = self.fetcher.fetch( + metric_name="namespace_asset_pod_cpu_utilization", + start=datetime.now() - timedelta(hours=1), + filters={"namespace": "sandbox-numalogic-demo"}, + ) + self.assertTrue(df.empty) + + @patch("requests.get", Mock(side_effect=Exception("Test exception"))) + def 
test_fetch_url_err(self): + with self.assertRaises(PrometheusFetcherError): + self.fetcher.fetch( + metric_name="namespace_asset_pod_cpu_utilization", + start=datetime.now() - timedelta(hours=1), + filters={"namespace": "sandbox-numalogic-demo"}, + ) + + @patch("requests.get", Mock(return_value=Response())) + def test_fetch_response_err(self): + with self.assertRaises(PrometheusInvalidResponseError): + self.fetcher.fetch( + metric_name="namespace_asset_pod_cpu_utilization", + start=datetime.now() - timedelta(hours=1), + filters={"namespace": "sandbox-numalogic-demo"}, + ) + + @patch.object(PrometheusFetcher, "_api_query_range", Mock(return_value=_mock_query_range())) + def test_fetch_raw(self): + df = self.fetcher.raw_fetch( + query='namespace_asset_pod_cpu_utilization{namespace="sandbox-numalogic-demo"}', + start=datetime.now() - timedelta(hours=1), + ) + self.assertEqual(df.shape, (2, 2)) + self.assertListEqual( + df.columns.to_list(), ["timestamp", "namespace_asset_pod_cpu_utilization"] + ) + + @patch.object(PrometheusFetcher, "_api_query_range", Mock(return_value=_mock_w_return_labels())) + def test_fetch_raw_return_labels(self): + metric = "namespace_app_rollouts_http_request_error_rate" + df = self.fetcher.raw_fetch( + query="namespace_app_rollouts_http_request_error_rate{namespace='sandbox-numalogic-demo'}", + start=datetime.now() - timedelta(hours=1), + return_labels=["rollouts_pod_template_hash"], + aggregate=True, + ) + self.assertEqual(df.shape, (2, 2)) + self.assertListEqual(df.columns.to_list(), ["timestamp", metric]) + self.assertListEqual([10.5, 12.5], df[metric].to_list()) + + @patch.object(PrometheusFetcher, "_api_query_range", Mock(return_value=[])) + def test_fetch_raw_no_data(self): + df = self.fetcher.raw_fetch( + query='namespace_asset_pod_cpu_utilization{namespace="sandbox-numalogic-demo"}', + start=datetime.now() - timedelta(hours=1), + ) + self.assertTrue(df.empty) + + def test_start_end_err(self): + with self.assertRaises(ValueError): + self.fetcher.fetch( + metric_name="namespace_asset_pod_cpu_utilization", + start=datetime.now() - timedelta(hours=1), + end=datetime.now() - timedelta(hours=2), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/synthetic/test_anomalies.py b/tests/synthetic/test_anomalies.py index 5955a4b0..1e089c93 100644 --- a/tests/synthetic/test_anomalies.py +++ b/tests/synthetic/test_anomalies.py @@ -19,11 +19,7 @@ def test_inject_global_anomalies(self, plot=False): outlier_df[cols].plot(ax=ax1, title="Outlier") ax2 = plt.subplot(212) test_df[cols].plot(ax=ax2, title="Original") - plt.show() - # self.assertNotEqual(test_df.shape, outlier_df.shape) - # other_cols = test_df.columns.difference(cols) - # self.assertTrue(test_df[other_cols].equals(outlier_df[other_cols])) - # self.assertFalse(test_df.equals(outlier_df)) + plt.show() def test_inject_contextual_anomalies(self, plot=False): ts_generator = SyntheticTSGenerator(7200, 5, seasonal_ts_prob=1.0) @@ -40,7 +36,7 @@ def test_inject_contextual_anomalies(self, plot=False): outlier_df[cols].plot(ax=ax1, title="Outlier") ax2 = plt.subplot(212) test_df[cols].plot(ax=ax2, title="Original") - plt.show() + plt.show() self.assertNotEqual(test_df.shape, outlier_df.shape) self.assertFalse(test_df.equals(outlier_df)) other_cols = test_df.columns.difference(cols) @@ -59,7 +55,7 @@ def test_inject_collective_anomalies(self, plot=False): outlier_df[cols].plot(ax=ax1, title="Outlier") ax2 = plt.subplot(212) test_df[cols].plot(ax=ax2, title="Original") - plt.show() + plt.show() 
self.assertNotEqual(test_df.shape, outlier_df.shape) self.assertFalse(test_df.equals(outlier_df)) other_cols = test_df.columns.difference(cols) @@ -78,7 +74,7 @@ def test_inject_causal_anomalies(self, plot=False): outlier_df[cols].plot(ax=ax1, title="Outlier") ax2 = plt.subplot(212) test_df[cols].plot(ax=ax2, title="Original") - plt.show() + plt.show() self.assertNotEqual(test_df.shape, outlier_df.shape) self.assertFalse(test_df.equals(outlier_df)) other_cols = test_df.columns.difference(cols) diff --git a/tests/test_backtest.py b/tests/test_backtest.py new file mode 100644 index 00000000..f9d47f11 --- /dev/null +++ b/tests/test_backtest.py @@ -0,0 +1,173 @@ +import os.path +import shutil +import unittest +from unittest.mock import patch, Mock + +import pandas as pd +from typer.testing import CliRunner +from numalogic.backtest._bt import app as btapp +from numalogic.backtest.__main__ import app as rootapp +from numalogic._constants import TESTS_DIR +from numalogic.connectors import PrometheusFetcher +from numalogic.tools.exceptions import DataFormatError + +runner = CliRunner() + + +def _mock_datafetch(): + df = pd.read_csv( + os.path.join(TESTS_DIR, "resources", "data", "interactionstatus.csv"), + usecols=["ts", "failure"], + nrows=1000, + ) + df.rename(columns={"ts": "timestamp"}, inplace=True, errors="raise") + return df + + +class TestBacktest(unittest.TestCase): + def tearDown(self) -> None: + _dir = os.path.join(TESTS_DIR, ".btoutput") + if os.path.exists(_dir): + shutil.rmtree(_dir, ignore_errors=False, onerror=None) + + @patch.object(PrometheusFetcher, "fetch", Mock(return_value=_mock_datafetch())) + def test_univar(self): + res = runner.invoke(btapp, ["univariate", "ns1", "app1", "failure"], catch_exceptions=False) + self.assertEqual(0, res.exit_code) + self.assertIsNone(res.exception) + + def test_multivar(self): + with self.assertRaises(NotImplementedError): + runner.invoke(btapp, ["multivariate"], catch_exceptions=False) + + +class TestRoot(unittest.TestCase): + def tearDown(self) -> None: + _dir = os.path.join(TESTS_DIR, ".btoutput") + if os.path.exists(_dir): + shutil.rmtree(_dir, ignore_errors=False, onerror=None) + + def test_train(self): + data_path = os.path.join(TESTS_DIR, "resources", "data", "interactionstatus.csv") + res = runner.invoke( + rootapp, + [ + "train", + "--data-file", + data_path, + "--col-name", + "failure", + "--ts-col-name", + "ts", + "--train-ratio", + "0.1", + ], + catch_exceptions=False, + ) + self.assertEqual(0, res.exit_code) + self.assertIsNone(res.exception) + + def test_train_err(self): + data_path = os.path.join(TESTS_DIR, "resources", "data", "interactionstatus.csv") + with self.assertRaises(DataFormatError): + runner.invoke( + rootapp, + [ + "train", + "--data-file", + data_path, + "--col-name", + "failure", + "--ts-col-name", + "timestamp", + "--train-ratio", + "0.1", + ], + catch_exceptions=False, + ) + + def test_train_arg_err(self): + res = runner.invoke( + rootapp, + [ + "train", + "--col-name", + "failure", + "--ts-col-name", + "timestamp", + "--train-ratio", + "0.1", + ], + catch_exceptions=False, + ) + self.assertEqual(1, res.exit_code) + self.assertIsNotNone(res.exception) + + def test_score(self): + data_path = os.path.join(TESTS_DIR, "resources", "data", "interactionstatus.csv") + res = runner.invoke( + rootapp, + [ + "score", + "--data-file", + data_path, + "--col-name", + "failure", + "--ts-col-name", + "ts", + "--test-ratio", + "0.5", + ], + catch_exceptions=False, + ) + self.assertEqual(0, res.exit_code) + 
self.assertIsNone(res.exception) + + def test_score_err(self): + data_path = os.path.join(TESTS_DIR, "resources", "data", "interactionstatus.csv") + with self.assertRaises(DataFormatError): + runner.invoke( + rootapp, + [ + "score", + "--data-file", + data_path, + "--col-name", + "failure", + "--ts-col-name", + "timestamp", + "--test-ratio", + "0.1", + ], + catch_exceptions=False, + ) + + def test_score_arg_err(self): + res = runner.invoke( + rootapp, + [ + "score", + "--col-name", + "failure", + "--ts-col-name", + "timestamp", + "--test-ratio", + "0.1", + ], + catch_exceptions=False, + ) + self.assertEqual(1, res.exit_code) + self.assertIsNotNone(res.exception) + + def test_clear(self): + res = runner.invoke( + rootapp, + ["clear", "--output-dir", os.path.join(TESTS_DIR, ".btoutput"), "--all"], + catch_exceptions=False, + ) + self.assertEqual(0, res.exit_code) + self.assertIsNone(res.exception) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/udfs/test_numaflow.py b/tests/udfs/test_numaflow.py index 6c5f0bdc..9eba47b7 100644 --- a/tests/udfs/test_numaflow.py +++ b/tests/udfs/test_numaflow.py @@ -17,7 +17,8 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: val = datum.value return Messages(Message(value=val, keys=keys)) - def compute(self, model: artifact_t, input_: npt.NDArray[float], **kwargs): + @classmethod + def compute(cls, model: artifact_t, input_: npt.NDArray[float], **kwargs): pass @@ -29,7 +30,8 @@ async def aexec(self, keys: list[str], datum: Datum) -> Messages: val = datum.value return Messages(Message(value=val, keys=keys)) - def compute(self, model: artifact_t, input_: npt.NDArray[float], **kwargs): + @classmethod + def compute(cls, model: artifact_t, input_: npt.NDArray[float], **kwargs): pass diff --git a/tests/udfs/test_trainer.py b/tests/udfs/test_trainer.py index 3e7ffb9e..828fc6bb 100644 --- a/tests/udfs/test_trainer.py +++ b/tests/udfs/test_trainer.py @@ -58,7 +58,7 @@ def setUp(self): def tearDown(self) -> None: REDIS_CLIENT.flushall() - @patch.object(DruidFetcher, "fetch_data", Mock(return_value=mock_druid_fetch_data())) + @patch.object(DruidFetcher, "fetch", Mock(return_value=mock_druid_fetch_data())) def test_trainer_01(self): self.udf.register_conf( "druid-config", @@ -78,7 +78,7 @@ def test_trainer_01(self): ), ) - @patch.object(DruidFetcher, "fetch_data", Mock(return_value=mock_druid_fetch_data())) + @patch.object(DruidFetcher, "fetch", Mock(return_value=mock_druid_fetch_data())) def test_trainer_02(self): self.udf.register_conf( "druid-config", @@ -100,7 +100,7 @@ def test_trainer_02(self): ), ) - @patch.object(DruidFetcher, "fetch_data", Mock(return_value=mock_druid_fetch_data())) + @patch.object(DruidFetcher, "fetch", Mock(return_value=mock_druid_fetch_data())) def test_trainer_03(self): self.udf.register_conf( "druid-config", @@ -127,7 +127,7 @@ def test_trainer_conf_err(self): with self.assertRaises(ConfigNotFoundError): udf(KEYS, self.datum) - @patch.object(DruidFetcher, "fetch_data", Mock(return_value=mock_druid_fetch_data(nrows=10))) + @patch.object(DruidFetcher, "fetch", Mock(return_value=mock_druid_fetch_data(nrows=10))) def test_trainer_data_insufficient(self): self.udf.register_conf( "druid-config", @@ -148,7 +148,7 @@ def test_trainer_data_insufficient(self): ) ) - @patch.object(DruidFetcher, "fetch_data", Mock(side_effect=RuntimeError)) + @patch.object(DruidFetcher, "fetch", Mock(side_effect=RuntimeError)) def test_trainer_datafetcher_err(self): self.udf.register_conf( "druid-config",
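# A hedged sketch of the DruidFetcher "fetch_data" -> "fetch" rename that the trainer
# tests above (and tests/connectors/test_druid.py earlier in this diff) now rely on:
# out-of-tree tests or pipelines that still patch or call fetch_data would be updated
# roughly as below. Only DruidFetcher, fetch, and the keyword arguments shown in
# test_druid.py are taken from this diff; the import path and the fake return frame
# are illustrative assumptions, not the library's documented surface.
from unittest.mock import Mock, patch

import pandas as pd

from numalogic.connectors.druid import DruidFetcher  # import path assumed


def _fake_druid_frame(nrows: int = 100) -> pd.DataFrame:
    # Stand-in for the mock_druid_fetch_data() helper used in test_trainer.py.
    return pd.DataFrame({"ciStatus": [float(i) for i in range(nrows)]})


# Previously: @patch.object(DruidFetcher, "fetch_data", Mock(return_value=_fake_druid_frame()))
@patch.object(DruidFetcher, "fetch", Mock(return_value=_fake_druid_frame()))
def run_renamed_fetch_example() -> None:
    # With the class attribute patched, the call resolves to the Mock and returns the fake frame.
    df = DruidFetcher.fetch(
        filter_keys=["assetId"],
        filter_values=["5984175597303660107"],
        dimensions=["ciStatus"],
    )
    print(df.shape)  # (100, 1) with the fake frame above


if __name__ == "__main__":
    run_renamed_fetch_example()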