
Commit

chore: address comments
Signed-off-by: Avik Basu <ab93@users.noreply.github.com>
ab93 committed Aug 30, 2023
1 parent be5183a commit 3581a82
Showing 15 changed files with 115 additions and 91 deletions.
4 changes: 2 additions & 2 deletions numalogic/config/_config.py
@@ -73,8 +73,8 @@ class TrainerConf:
     train_hours: int = 24 * 8  # 8 days worth of data
     min_train_size: int = 2000
     retrain_freq_hr: int = 24
-    model_expiry_sec: int = 86400  # 24 hrs
-    dedup_expiry_sec: int = 1800  # 30 days
+    model_expiry_sec: int = 86400  # 24 hrs  # TODO: revisit this
+    dedup_expiry_sec: int = 1800  # 30 days  # TODO: revisit this
     batch_size: int = 32
     pltrainer_conf: LightningTrainerConf = field(default_factory=LightningTrainerConf)

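
Note: TrainerConf is a plain dataclass, so its defaults can be materialized and overridden through OmegaConf structured configs, which is how numalogic configs are typically composed. A minimal sketch (the public import path is an assumption):

from omegaconf import OmegaConf

from numalogic.config import TrainerConf  # assumed re-export of numalogic/config/_config.py

# Build a config from the dataclass defaults, then override a single field.
conf = OmegaConf.structured(TrainerConf)
conf.retrain_freq_hr = 12  # hypothetical override, for illustration only
print(OmegaConf.to_yaml(conf))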
9 changes: 7 additions & 2 deletions numalogic/connectors/_config.py
@@ -2,7 +2,6 @@
 from enum import IntEnum

 from omegaconf import MISSING
-from pydruid.utils.aggregators import doublesum


 class ConnectorType(IntEnum):
@@ -40,11 +39,17 @@ class Pivot:
 class DruidFetcherConf:
     datasource: str
     dimensions: list[str] = field(default_factory=list)
-    aggregations: dict = field(default_factory=lambda: {"count": doublesum("count")})
+    aggregations: dict = field(default_factory=dict)
     group_by: list[str] = field(default_factory=list)
     pivot: Pivot = field(default_factory=lambda: Pivot())
     granularity: str = "minute"
+
+    def __post_init__(self):
+        from pydruid.utils.aggregators import doublesum
+
+        if not self.aggregations:
+            self.aggregations = {"count": doublesum("count")}


 @dataclass
 class DruidConf(ConnectorConf):
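
Note: moving the pydruid import from module scope into __post_init__ means numalogic.connectors._config can be imported even when the optional pydruid dependency is absent; pydruid is only touched once a DruidFetcherConf is actually instantiated. The same pattern in isolation (an illustrative stand-in, not the numalogic class):

from dataclasses import dataclass, field


@dataclass
class FetcherConf:  # stand-in for DruidFetcherConf
    aggregations: dict = field(default_factory=dict)

    def __post_init__(self):
        # Lazy import: only evaluated when an instance is created,
        # so importing this module does not require pydruid.
        from pydruid.utils.aggregators import doublesum

        if not self.aggregations:
            self.aggregations = {"count": doublesum("count")}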
2 changes: 1 addition & 1 deletion numalogic/connectors/druid.py
@@ -76,7 +76,7 @@ def fetch_data(
         if group_by:
             df = df.groupby(by=group_by).sum().reset_index()

-        if pivot:
+        if pivot.columns:
             df = df.pivot(
                 index=pivot.index,
                 columns=pivot.columns,
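
Note: the previous guard `if pivot:` was effectively always true, because a dataclass instance without __bool__ or __len__ is truthy regardless of its field values; testing pivot.columns runs the pivot only when columns are actually configured. A small sketch of the distinction (simplified stand-in; the real Pivot may have more fields):

from dataclasses import dataclass, field


@dataclass
class Pivot:  # simplified stand-in
    index: str = "timestamp"  # assumed default
    columns: list[str] = field(default_factory=list)


p = Pivot()
print(bool(p))          # True: instances are truthy even when "empty"
print(bool(p.columns))  # False: empty list, so the pivot step is skipped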
7 changes: 6 additions & 1 deletion numalogic/connectors/redis.py
@@ -3,6 +3,7 @@
 from typing import Optional

 from numalogic.connectors._config import RedisConf
+from numalogic.tools.exceptions import EnvVarNotFoundError
 from numalogic.tools.types import redis_client_t
 from redis.backoff import ExponentialBackoff
 from redis.exceptions import RedisClusterException, RedisError
@@ -73,7 +74,7 @@ def get_redis_client(
     return SENTINEL_CLIENT


-def get_redis_client_from_conf(redis_conf: Optional[RedisConf] = None, **kwargs) -> redis_client_t:
+def get_redis_client_from_conf(redis_conf: RedisConf, **kwargs) -> redis_client_t:
     """
     Return a master redis client from config for sentinel connections, with retry.
@@ -85,6 +86,10 @@ def get_redis_client_from_conf(redis_conf: Optional[RedisConf] = None, **kwargs)
     -------
     Redis client instance
     """
+    auth = os.getenv("REDIS_AUTH")
+    if not auth:
+        raise EnvVarNotFoundError("REDIS_AUTH not set!")
+
     return get_redis_client(
         redis_conf.url,
         redis_conf.port,
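
Note: the function now fails fast with EnvVarNotFoundError when REDIS_AUTH is unset, instead of accepting an optional None config and deferring the failure. A hypothetical usage sketch (only the url and port field names are taken from the diff; everything else is illustrative):

import os

from numalogic.connectors._config import RedisConf
from numalogic.connectors.redis import get_redis_client_from_conf

os.environ["REDIS_AUTH"] = "s3cr3t"  # hypothetical credential; if unset, EnvVarNotFoundError is raised

conf = RedisConf(url="localhost", port=26379)  # illustrative sentinel endpoint
client = get_redis_client_from_conf(conf)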
2 changes: 1 addition & 1 deletion numalogic/registry/redis_registry.py
@@ -40,7 +40,7 @@ class RedisRegistry(ArtifactManager):
     >>> from numalogic.models.autoencoder.variants import VanillaAE
     >>> from numalogic.registry.redis_registry import RedisRegistry
     >>> ...
-    >>> r = redis.StrictRedis(host='127.0.0.1', port=6379)
+    >>> r = redis.Redis(host='127.0.0.1', port=6379)
     >>> registry = RedisRegistry(client=r)
     >>> skeys, dkeys = ("mymetric", "ae"), ("vanilla", "seq10")
     >>> model = VanillaAE(seq_len=10)
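
Note: in redis-py 3.0 and later, StrictRedis is kept only as a backwards-compatibility alias for Redis, so this docstring update is cosmetic; both names build the same client:

import redis

assert redis.StrictRedis is redis.Redis  # same class in redis-py >= 3.0
r = redis.Redis(host="127.0.0.1", port=6379)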
6 changes: 6 additions & 0 deletions numalogic/tools/exceptions.py
@@ -74,3 +74,9 @@ class ModelKeyNotFound(RedisRegistryError):
     """Raised when a model key is not found in the registry."""

     pass
+
+
+class EnvVarNotFoundError(LookupError):
+    """Raised when an environment variable is not found."""
+
+    pass
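
Note: deriving the new exception from LookupError keeps it catchable by generic handlers that do not know about numalogic-specific types. A quick sketch:

from numalogic.tools.exceptions import EnvVarNotFoundError

try:
    raise EnvVarNotFoundError("REDIS_AUTH not set!")
except LookupError as err:  # EnvVarNotFoundError is a LookupError subclass
    print(f"caught: {err}")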
2 changes: 2 additions & 0 deletions numalogic/udfs/__main__.py
@@ -7,6 +7,8 @@
 from numalogic.udfs import load_pipeline_conf, UDFFactory, ServerFactory, set_logger

+LOGGER = logging.getLogger(__name__)

+# TODO support user config paths
 CONF_FILE_PATH = os.getenv(
"CONF_PATH", default=os.path.join(BASE_CONF_DIR, "default-configs", "config.yaml")
)
2 changes: 1 addition & 1 deletion numalogic/udfs/_config.py
@@ -15,7 +15,7 @@
 @dataclass
 class StreamConf:
     config_id: str = "default"
-    source: ConnectorType = ConnectorType.druid
+    source: ConnectorType = ConnectorType.druid  # TODO: do not allow redis connector here
     window_size: int = 12
     composite_keys: list[str] = field(default_factory=list)
     metrics: list[str] = field(default_factory=list)
2 changes: 1 addition & 1 deletion numalogic/udfs/entities.py
@@ -99,7 +99,7 @@ class OutputPayload(_BasePayload):
     """Payload for output data from the numalogic pipeline."""

     timestamp: int
-    unified_anomaly: float
+    unified_anomaly: float  # TODO: change to a more generic name
     data: dict[str, Any]
     metadata: dict[str, Any]

5 changes: 5 additions & 0 deletions numalogic/udfs/inference.py
@@ -18,6 +18,8 @@
 from numalogic.udfs.entities import StreamPayload, Header, Status

 _LOGGER = logging.getLogger(__name__)
+
+# TODO: move to config
 LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600"))


@@ -37,6 +39,7 @@ def __init__(self, r_client: redis_client_t, pl_conf: Optional[PipelineConf] = N
         )
         self.pl_conf = pl_conf or PipelineConf()

+    # TODO: remove, and have an update config method
     def register_conf(self, config_id: str, conf: StreamConf) -> None:
         """
         Register config with the UDF.
@@ -108,6 +111,8 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
             return Messages(Message(keys=keys, value=payload.to_json()))

         artifact_data = self.load_artifact(keys, payload)
+
+        # TODO: revisit retraining logic
         # Send training request if artifact loading is not successful
         if not artifact_data:
             payload = replace(
2 changes: 2 additions & 0 deletions numalogic/udfs/postprocess.py
@@ -18,6 +18,7 @@
 from numalogic.udfs.entities import StreamPayload, Header, Status, TrainerPayload, OutputPayload
 from numalogic.udfs.tools import _load_model

+# TODO: move to config
 LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600"))
 LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000"))

@@ -128,6 +129,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
             data={
                 _metric: _score for _metric, _score in zip(payload.metrics, anomaly_scores)
             },
+            # TODO: add model version, & emit as ML metrics
             metadata=payload.metadata,
         )
         _LOGGER.info(
9 changes: 6 additions & 3 deletions numalogic/udfs/preprocess.py
@@ -18,6 +18,7 @@
 from numalogic.udfs.entities import Status, Header
 from numalogic.udfs.tools import make_stream_payload, get_df, _load_model

+# TODO: move to config
 LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600"))
 LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000"))

@@ -124,6 +125,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
                 preproc_artifact.extras.get("source"),
             )
         else:
+            # TODO check again what error is causing this and if retraining is required
             payload = replace(
                 payload, status=Status.ARTIFACT_NOT_FOUND, header=Header.TRAIN_REQUEST
             )
@@ -136,10 +138,10 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
                 self.get_conf(payload.config_id).numalogic_conf.preprocess
             )
         try:
-            processed_data = self.compute(model=preproc_clf, input_=payload.get_data())
+            x_scaled = self.compute(model=preproc_clf, input_=payload.get_data())
             payload = replace(
                 payload,
-                data=processed_data,
+                data=x_scaled,
                 status=Status.ARTIFACT_FOUND,
                 header=Header.MODEL_INFERENCE,
             )
@@ -148,7 +150,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
                 payload.uuid,
                 keys,
                 payload.metrics,
-                list(processed_data),
+                x_scaled,
             )
         except RuntimeError:
             _LOGGER.exception(
@@ -157,6 +159,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
                 payload.composite_keys,
                 payload.metrics,
             )
+            # TODO check again what error is causing this and if retraining is required
             payload = replace(payload, status=Status.RUNTIME_ERROR, header=Header.TRAIN_REQUEST)
             return Messages(Message(keys=keys, value=payload.to_json()))
         _LOGGER.debug(
1 change: 1 addition & 0 deletions numalogic/udfs/tools.py
@@ -64,6 +64,7 @@ def make_stream_payload(
     )


+# TODO: move to base NumalogicUDF class
 def _load_model(
     skeys: KEYS, dkeys: KEYS, payload: StreamPayload, model_registry: ArtifactManager
 ) -> Optional[ArtifactData]:
3 changes: 2 additions & 1 deletion numalogic/udfs/trainer.py
@@ -269,6 +269,7 @@ def _is_data_sufficient(self, payload: TrainerPayload, df: pd.DataFrame) -> bool
         _conf = self.get_conf(payload.config_id)
         return len(df) > _conf.numalogic_conf.trainer.min_train_size

+    # TODO: improve the dedup logic; this is too naive
     def _is_new_request(self, payload: TrainerPayload) -> bool:
         _conf = self.get_conf(payload.config_id)
         _ckeys = ":".join(payload.composite_keys)
@@ -288,7 +289,7 @@ def get_feature_arr(
             if col not in raw_df.columns:
                 raw_df[col] = fill_value
         feat_df = raw_df[metrics]
-        feat_df.fillna(fill_value, inplace=True)
+        feat_df = feat_df.fillna(fill_value)
         return feat_df.to_numpy(dtype=np.float32)

     def fetch_data(self, payload: TrainerPayload) -> pd.DataFrame:
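
Note: feat_df = raw_df[metrics] produces a slice of raw_df, and calling fillna(..., inplace=True) on such a slice can emit pandas' SettingWithCopyWarning, since whether the write lands depends on view-versus-copy semantics; reassigning the returned frame is the unambiguous idiom. A minimal sketch:

import numpy as np
import pandas as pd

raw_df = pd.DataFrame({"m1": [1.0, np.nan], "m2": [np.nan, 2.0], "other": [0, 0]})

feat_df = raw_df[["m1", "m2"]]  # a slice of raw_df
feat_df = feat_df.fillna(0.0)   # safe: operate on the returned copy
print(feat_df.to_numpy(dtype=np.float32))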

0 comments on commit 3581a82
