feat: trainer as multi proc udf
Signed-off-by: nkoppisetty <nandita.iitkgp@gmail.com>
nkoppisetty committed Jul 25, 2023
1 parent 86dbfc2 commit 4cb3d61
Showing 7 changed files with 67 additions and 63 deletions.
6 changes: 3 additions & 3 deletions udf/anomaly-detection/src/factory.py
@@ -1,5 +1,5 @@
 from src.udf import Preprocess, Inference, Threshold, Postprocess
-from src.udsink import Train
+from src.udf.trainer import Trainer


 class HandlerFactory:
@@ -17,7 +17,7 @@ def get_handler(cls, step: str):
         if step == "postprocess":
             return Postprocess().run

-        if step == "train":
-            return Train().run
+        if step == "trainer":
+            return Trainer().run

         raise NotImplementedError(f"Invalid step provided: {step}")
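For orientation, the factory maps a step name to its handler callable; a minimal usage sketch grounded in the diff above (how the step string is supplied at runtime is not shown in this commit):

# Minimal sketch: resolving the renamed step through the factory.
from src.factory import HandlerFactory

step_handler = HandlerFactory.get_handler("trainer")  # -> Trainer().run
# Any step name outside the known set raises NotImplementedError.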
3 changes: 2 additions & 1 deletion udf/anomaly-detection/src/udf/__init__.py
@@ -2,5 +2,6 @@
 from src.udf.inference import Inference
 from src.udf.threshold import Threshold
 from src.udf.postprocess import Postprocess
+from src.udf.trainer import Trainer

-__all__ = ["Preprocess", "Inference", "Threshold", "Postprocess"]
+__all__ = ["Preprocess", "Inference", "Threshold", "Postprocess", "Trainer"]
udf/anomaly-detection/src/udf/trainer.py
@@ -2,7 +2,7 @@
 import time
 import orjson
 import pandas as pd
-from typing import List, Iterator
+from typing import List
 from numalogic.config import (
     NumalogicConf,
     ThresholdFactory,
@@ -16,7 +16,7 @@
 from numalogic.tools.exceptions import RedisRegistryError
 from numalogic.tools.types import redis_client_t
 from omegaconf import OmegaConf
-from pynumaflow.sink import Datum, Responses, Response
+from pynumaflow.function import Datum, Messages, Message
 from sklearn.pipeline import make_pipeline
 from torch.utils.data import DataLoader

@@ -41,7 +41,7 @@ def get_feature_df(data: pd.DataFrame, metrics: list):
     return data[metrics]


-class Train:
+class Trainer:
     @classmethod
     def fetch_prometheus_data(cls, payload: TrainerPayload) -> pd.DataFrame:
         prometheus_conf = ConfigManager.get_prom_config()
@@ -235,47 +235,46 @@ def _train_and_save(
             version,
         )

-    def run(self, datums: Iterator[Datum]) -> Responses:
-        responses = Responses()
+    def run(self, keys: List[str], datum: Datum) -> Messages:
+        messages = Messages()
         redis_client = get_redis_client_from_conf()

-        for _datum in datums:
-            payload = TrainerPayload(**orjson.loads(_datum.value))
-            is_new = self._is_new_request(redis_client, payload)
-
-            if not is_new:
-                responses.append(Response.as_success(_datum.id))
-                continue
-
-            retrain_config = ConfigManager.get_retrain_config(payload.config_id)
-            numalogic_config = ConfigManager.get_numalogic_config(payload.config_id)
-
-            try:
-                df = self.fetch_data(payload)
-            except Exception as err:
-                _LOGGER.error(
-                    "%s - Error while fetching data for keys: %s, metrics: %s, err: %r",
-                    payload.uuid,
-                    payload.composite_keys,
-                    payload.metrics,
-                    err,
-                )
-                responses.append(Response.as_success(_datum.id))
-                continue
-
-            if len(df) < retrain_config.min_train_size:
-                _LOGGER.warning(
-                    "%s - Skipping training, train data less than minimum required: %s, df shape: %s",
-                    payload.uuid,
-                    retrain_config.min_train_size,
-                    df.shape,
-                )
-                responses.append(Response.as_success(_datum.id))
-                continue
-
-            train_df = get_feature_df(df, payload.metrics)
-            self._train_and_save(numalogic_config, payload, redis_client, train_df)
-
-            responses.append(Response.as_success(_datum.id))
-
-        return responses
+        payload = TrainerPayload(**orjson.loads(datum.value))
+        is_new = self._is_new_request(redis_client, payload)
+
+        if not is_new:
+            messages.append(Message.to_drop())
+            return messages
+
+        retrain_config = ConfigManager.get_retrain_config(payload.config_id)
+        numalogic_config = ConfigManager.get_numalogic_config(payload.config_id)
+
+        try:
+            df = self.fetch_data(payload)
+        except Exception as err:
+            _LOGGER.error(
+                "%s - Error while fetching data for keys: %s, metrics: %s, err: %r",
+                payload.uuid,
+                payload.composite_keys,
+                payload.metrics,
+                err,
+            )
+            messages.append(Message.to_drop())
+            return messages
+
+        if len(df) < retrain_config.min_train_size:
+            _LOGGER.warning(
+                "%s - Skipping training, train data less than minimum required: %s, df shape: %s",
+                payload.uuid,
+                retrain_config.min_train_size,
+                df.shape,
+            )
+            messages.append(Message.to_drop())
+            return messages
+
+        train_df = get_feature_df(df, payload.metrics)
+        self._train_and_save(numalogic_config, payload, redis_client, train_df)
+
+        messages.append(Message(keys=keys, value=train_df.to_json()))
+
+        return messages
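The core of this commit is the contract change above: the trainer moves from pynumaflow's sink interface, where run() consumed an iterator of datums and acked each one with a Response, to the function (map) interface, where run() receives a single keyed datum and returns Messages. Dropped work is now signalled with Message.to_drop(), and a successful run forwards the training data downstream. A minimal side-by-side sketch of the two handler shapes for comparison; the handler bodies are illustrative, with only the imports and the Response/Message calls taken from the diff:

# Sketch: old sink-style handler vs. new map-style handler in pynumaflow.
from typing import Iterator, List

from pynumaflow.function import Datum, Messages, Message  # map (UDF) interface
from pynumaflow.sink import Datum as SinkDatum, Responses, Response  # sink interface


def sink_handler(datums: Iterator[SinkDatum]) -> Responses:
    # Old shape: a batch of datums in, a per-datum ack out.
    responses = Responses()
    for d in datums:
        responses.append(Response.as_success(d.id))
    return responses


def map_handler(keys: List[str], datum: Datum) -> Messages:
    # New shape: one keyed datum in, zero or more messages out.
    # Message.to_drop() acks without forwarding anything downstream.
    messages = Messages()
    messages.append(Message(keys=keys, value=datum.value))
    return messages

Because the map handler processes one payload per call, the per-datum loop and the Response bookkeeping disappear, which is what allows the same handler to be served by the multi-process server added in starter.py below.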
3 changes: 0 additions & 3 deletions udf/anomaly-detection/src/udsink/__init__.py

This file was deleted.

5 changes: 4 additions & 1 deletion udf/anomaly-detection/starter.py
@@ -2,7 +2,7 @@
 import threading

 import aiorun
-from pynumaflow.function import Server, AsyncServer
+from pynumaflow.function import Server, AsyncServer, MultiProcServer
 from pynumaflow.sink import Sink

 from src._constants import CONFIG_PATHS
@@ -29,6 +29,9 @@ def run_watcher():
     elif server_type == "udf":
         server = Server(map_handler=step_handler)
         server.start()
+    elif server_type == "multiproc_udf":
+        server = MultiProcServer(map_handler=step_handler)
+        server.start()
     elif server_type == "async_udf":
         server = AsyncServer(reduce_handler=step_handler)
         aiorun.run(server.start())
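The surrounding starter logic is not shown in this hunk, so here is a hedged reconstruction of how step_handler and server_type plausibly come together; the CLI argument names and parsing are assumptions, and only the branch bodies come from the diff:

# Hypothetical reconstruction of the dispatch in starter.py.
import sys

from pynumaflow.function import Server, MultiProcServer
from src.factory import HandlerFactory

if __name__ == "__main__":
    step, server_type = sys.argv[1], sys.argv[2]  # assumed CLI shape
    step_handler = HandlerFactory.get_handler(step)

    if server_type == "udf":
        Server(map_handler=step_handler).start()
    elif server_type == "multiproc_udf":
        # Runs the same map handler across multiple gRPC server processes.
        MultiProcServer(map_handler=step_handler).start()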
5 changes: 2 additions & 3 deletions udf/anomaly-detection/tests/__init__.py
@@ -41,16 +41,15 @@ def mock_configs():
 mock_get_redis_client_from_conf.return_value = redis_client
 with patch.object(ConfigManager, "load_configs") as mock_confs:
     mock_confs.return_value = mock_configs()
-    from src.udf import Preprocess, Inference, Threshold, Postprocess
-    from src.udsink import Train
+    from src.udf import Preprocess, Inference, Threshold, Postprocess, Trainer


 __all__ = [
     "redis_client",
-    "Train",
     "Preprocess",
     "Inference",
     "Threshold",
     "Postprocess",
+    "Trainer",
     "mock_configs",
 ]
udf/anomaly-detection/tests/test_trainer.py
@@ -14,7 +14,7 @@
     mock_prom_query_metric,
     mock_druid_fetch_data,
 )
-from tests import redis_client, Train, mock_configs
+from tests import redis_client, Trainer, mock_configs

 DATA_DIR = os.path.join(TESTS_DIR, "resources", "data")
 STREAM_DATA_PATH = os.path.join(DATA_DIR, "stream.json")
@@ -56,17 +56,22 @@ def setUp(self) -> None:
     @patch.object(Prometheus, "query_metric", Mock(return_value=mock_prom_query_metric()))
     @patch.object(ConfigManager, "load_configs", Mock(return_value=mock_configs()))
     def test_prometheus_01(self):
-        _out = Train().run(datums=iter([as_datum(self.train_payload)]))
-        self.assertTrue(_out[0].success)
-        self.assertEqual("1", _out[0].id)
+        _out = Trainer().run(
+            keys=[
+                "sandbox_numalogic_demo",
+                "metric_1",
+                "123456789",
+            ],
+            datum=as_datum(self.train_payload),
+        )
+        self.assertTrue(_out[0])

     @patch.object(DruidFetcher, "fetch_data", Mock(return_value=mock_druid_fetch_data()))
     @patch.object(ConfigManager, "load_configs", Mock(return_value=mock_configs()))
     def test_druid_01(self):
-        _out = Train().run(datums=iter([as_datum(self.train_payload2)]))
+        _out = Trainer().run(keys=["5984175597303660107"], datum=as_datum(self.train_payload2))
         print(_out)
-        self.assertTrue(_out[0].success)
-        self.assertEqual("1", _out[0].id)
+        self.assertTrue(_out[0])


 if __name__ == "__main__":
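The as_datum helper used in these tests is not shown in this diff, and under the new function interface it would have to build a pynumaflow.function.Datum rather than a sink Datum. A hedged sketch of what such a helper might look like; the Datum constructor arguments are assumptions about the pynumaflow version in use, and the helper itself is hypothetical:

# Hypothetical sketch of the as_datum test helper (not part of this diff).
from datetime import datetime, timezone

import orjson
from pynumaflow.function import Datum


def as_datum(payload: dict) -> Datum:
    # Trainer.run does orjson.loads(datum.value), so the payload is
    # serialized to JSON bytes; keys are passed separately to run().
    now = datetime.now(timezone.utc)
    return Datum(
        keys=[],  # illustrative; constructor kwargs are assumed
        value=orjson.dumps(payload),
        event_time=now,
        watermark=now,
    )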
