fix: typos
Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>
s0nicboOm committed Sep 12, 2023
1 parent 7c3c8d3 commit d7c3e59
Showing 14 changed files with 82 additions and 5,211 deletions.
Binary file removed dump.rdb
2 changes: 1 addition & 1 deletion examples/multi_udf/src/udf/inference.py
@@ -61,7 +61,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
             LOGGER.info("%s - Inference complete", payload.uuid)
         else:
             # If model not found, set status as not found
-            LOGGER.exception("%s - Model not found", payload.uuid)
+            LOGGER.warning("%s - Model not found", payload.uuid)
             payload.is_artifact_valid = False
 
         # Convert Payload back to bytes and conditional forward to threshold vertex
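The switch from LOGGER.exception to LOGGER.warning is more than a typo fix: logging.exception logs at ERROR level and attaches the active traceback, which only makes sense inside an except block, while a missing model here is an expected condition. A minimal sketch of the intended pattern (illustrative names, standard library only):

    import logging

    logging.basicConfig(level=logging.INFO)
    LOGGER = logging.getLogger(__name__)

    def load_model(registry: dict, key: str):
        model = registry.get(key)
        if model is None:
            # Expected miss: log at WARNING, no traceback attached.
            LOGGER.warning("%s - Model not found", key)
        return model

    load_model({}, "vanilla-ae")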
54 changes: 29 additions & 25 deletions numalogic/registry/redis_registry.py
@@ -159,27 +159,33 @@ def __load_latest_artifact(self, key: str) -> tuple[ArtifactData, bool]:
         ------
         ModelKeyNotFound: If the model key is not found in the registry.
         """
-        cached_artifact = self._load_from_cache(key)
+        latest_key = self.__construct_latest_key(key)
+        cached_artifact = self._load_from_cache(latest_key)
         if cached_artifact:
-            _LOGGER.debug("Found cached artifact for key: %s", key)
+            _LOGGER.debug("Found cached artifact for key: %s", latest_key)
             return cached_artifact, True
-        latest_key = self.__construct_latest_key(key)
         if not self.client.exists(latest_key):
             raise ModelKeyNotFound(f"latest key: {latest_key}, Not Found !!!")
         model_key = self.client.get(latest_key)
         _LOGGER.debug("latest key, %s, is pointing to the key : %s", latest_key, model_key)
-        return (
-            self.__load_version_artifact(version=self.get_version(model_key.decode()), key=key),
-            False,
+        artifact, _ = self.__load_version_artifact(
+            version=self.get_version(model_key.decode()), key=key
         )
+        return artifact, False
 
-    def __load_version_artifact(self, version: str, key: str) -> ArtifactData:
-        model_key = self.__construct_version_key(key, version)
-        print(model_key)
-        if not self.client.exists(model_key):
-            raise ModelKeyNotFound("Could not find model key with key: %s" % model_key)
-        return self.__get_artifact_data(
-            model_key=model_key,
+    def __load_version_artifact(self, version: str, key: str) -> tuple[ArtifactData, bool]:
+        version_key = self.__construct_version_key(key, version)
+        cached_artifact = self._load_from_cache(version_key)
+        if cached_artifact:
+            _LOGGER.debug("Found cached version artifact for key: %s", version_key)
+            return cached_artifact, True
+        if not self.client.exists(version_key):
+            raise ModelKeyNotFound("Could not find model key with key: %s" % version_key)
+        return (
+            self.__get_artifact_data(
+                model_key=version_key,
+            ),
+            False,
         )
 
     def __save_artifact(
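Both loaders now follow the same cache-aside shape: probe the local cache under the exact Redis key, fall back to the registry on a miss, and report is_cached so the caller can decide whether to warm the cache. A simplified sketch of that shape, with plain dicts standing in for the cache and Redis (names are illustrative):

    from typing import Any, Optional

    def load_with_cache(cache: dict, store: dict, key: str) -> tuple[Optional[Any], bool]:
        # Cache-aside read: local cache first, then the backing store.
        # Returns (value, is_cached) so the caller knows whether to warm the cache.
        if key in cache:
            return cache[key], True
        if key not in store:
            raise KeyError(f"key not found: {key}")
        return store[key], False

    cache, store = {}, {"model::1": b"weights"}
    value, is_cached = load_with_cache(cache, store, "model::1")
    if not is_cached:
        cache["model::1"] = value  # warm the cache for the next call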
@@ -237,25 +243,25 @@ def load(
         if (latest and version) or (not latest and not version):
             raise ValueError("Either One of 'latest' or 'version' needed in load method call")
         key = self.construct_key(skeys, dkeys)
-        is_cached = False
         try:
             if latest:
                 artifact_data, is_cached = self.__load_latest_artifact(key)
             else:
-                artifact_data = self.__load_version_artifact(version, key)
+                artifact_data, is_cached = self.__load_version_artifact(version, key)
         except RedisError as err:
             raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
         else:
             if not is_cached:
+                if latest:
+                    _LOGGER.info("Saving %s, in cache as %s", self.__construct_latest_key(key), key)
+                    self._save_in_cache(self.__construct_latest_key(key), artifact_data)
+                else:
                     _LOGGER.info(
                         "Saving %s, in cache as %s",
                         self.__construct_version_key(key, version),
                         key,
                     )
-                self._save_in_cache(key, artifact_data)
+                    self._save_in_cache(self.__construct_version_key(key, version), artifact_data)
             return artifact_data
 
     def save(
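For reference, a hypothetical caller of load(); the constructor wiring below is abbreviated and may differ across numalogic versions, so treat it as a sketch rather than the canonical API:

    from redis import Redis
    from numalogic.registry import RedisRegistry

    registry = RedisRegistry(client=Redis())  # wiring abbreviated; check your version's signature

    skeys, dkeys = ["service"], ["model"]
    latest = registry.load(skeys=skeys, dkeys=dkeys, latest=True)                 # cached after first hit
    pinned = registry.load(skeys=skeys, dkeys=dkeys, latest=False, version="3")   # pinned version, now also cached
    # Passing both latest=True and a version (or neither) raises ValueError per the guard above.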
@@ -298,9 +304,8 @@ def save(
                     pipe=redis_pipe, artifact=artifact, key=key, version=str(version), **metadata
                 )
                 redis_pipe.expire(name=new_version_key, time=self.ttl)
-            # if pipe is None:
-            #     print(redis_pipe.command_stack)
-            #     redis_pipe.execute()
+            if pipe is None:
+                redis_pipe.execute()
         except RedisError as err:
             raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
         else:
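The restored lines implement a caller-owned-pipeline convention: save() flushes the pipeline only when it created it, so a caller batching several saves (as save_multiple does below) keeps control of the single execute(). A sketch of the convention, assuming redis-py and illustrative names:

    from typing import Optional

    import redis

    def save_blob(
        client: redis.Redis,
        key: str,
        value: bytes,
        ttl: int = 3600,
        pipe: Optional[redis.client.Pipeline] = None,
    ) -> None:
        redis_pipe = client.pipeline() if pipe is None else pipe
        redis_pipe.set(key, value)
        redis_pipe.expire(name=key, time=ttl)
        if pipe is None:
            # This function created the pipeline, so it must flush it;
            # a caller-supplied pipeline is executed by the caller.
            redis_pipe.execute()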
@@ -384,6 +389,9 @@ def save_multiple(
             with self.client.pipeline(transaction=self.transactional) as pipe:
                 pipe.multi()
                 for count, (key, artifact) in enumerate(zip(list_dkeys, list_artifacts)):
+                    dict_model_ver[":".join(key)] = self.save(
+                        skeys=skeys, dkeys=key, artifact=artifact, pipe=pipe, **metadata
+                    )
                     if count == len(list_artifacts) - 1:
                         self.save(
                             skeys=skeys,
@@ -393,15 +401,11 @@ def save_multiple(
                             artifact_versions=dict_model_ver,
                             **metadata,
                         )
+                        break
-                    dict_model_ver = {
-                        ":".join(key): self.save(
-                            skeys=skeys, dkeys=key, artifact=artifact, pipe=pipe, **metadata
-                        )
-                    }
                 pipe.execute()
                 _LOGGER.info("Successfully saved all the artifacts with: %s", dict_model_ver)
         except RedisError as err:
             raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
+        except IndexError as index_err:
+            raise RedisRegistryError(f"{index_err.__class__.__name__} raised") from index_err
         else:
             return dict_model_ver
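The reworked loop accumulates a version map as it saves and hands the full map to the final save, so reading one artifact back recovers every sibling version; the removed block rebuilt dict_model_ver from scratch each iteration, discarding earlier entries. A stripped-down sketch of the fixed control flow (save stands in for RedisRegistry.save; names are illustrative):

    from typing import Callable

    def save_all(items: dict[str, bytes], save: Callable[..., str]) -> dict[str, str]:
        versions: dict[str, str] = {}
        keys = list(items)
        for count, key in enumerate(keys):
            # Record every saved version as we go...
            versions[key] = save(key, items[key])
            if count == len(keys) - 1:
                # ...and attach the accumulated map alongside the last artifact.
                save(key, items[key], artifact_versions=versions)
                break
        return versions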
5 changes: 5 additions & 0 deletions numalogic/udfs/__main__.py
@@ -12,8 +12,13 @@
 CONF_FILE_PATH = os.getenv(
     "CONF_PATH", default=os.path.join(BASE_CONF_DIR, "default-configs", "config.yaml")
 )
 pipeline_conf = load_pipeline_conf(CONF_FILE_PATH)
 logging.info("Pipeline config: %s", pipeline_conf)
 
+redis_client = get_redis_client_from_conf(pipeline_conf.redis_conf)
+
+for key in redis_client.scan_iter("*"):
+    redis_client.delete(key)
+
 if __name__ == "__main__":
     set_logger()
     step = sys.argv[1]
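The new startup block walks every key with scan_iter and deletes them one round trip at a time. Two hedged alternatives with the same effect, assuming redis-py and a reachable server (the key pattern below is illustrative):

    import redis

    client = redis.Redis()  # assumes a reachable local Redis

    # Single-command equivalent of the scan/delete loop (drops every key in the DB):
    client.flushdb()

    # Or keep the selective scan pattern but batch the deletes in one round trip:
    with client.pipeline() as pipe:
        for key in client.scan_iter("numalogic:*"):  # pattern is illustrative
            pipe.delete(key)
        pipe.execute()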
6 changes: 4 additions & 2 deletions numalogic/udfs/inference.py
@@ -16,13 +16,14 @@
 from numalogic.udfs._base import NumalogicUDF
 from numalogic.udfs._config import StreamConf, PipelineConf
 from numalogic.udfs.entities import StreamPayload, Header, Status
-from numalogic.udfs.tools import _load_model
+from numalogic.udfs.tools import _load_artifact
 
 _LOGGER = logging.getLogger(__name__)
 
 # TODO: move to config
 LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600"))
 LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000"))
+LOAD_LATEST = os.getenv("LOAD_LATEST", "false").lower() == "true"
 
 
 class InferenceUDF(NumalogicUDF):
@@ -112,11 +113,12 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
         # Forward payload if a training request is tagged
         if payload.header == Header.TRAIN_REQUEST:
             return Messages(Message(keys=keys, value=payload.to_json()))
-        artifact_data = _load_model(
+        artifact_data, payload = _load_artifact(
             skeys=keys,
             dkeys=[self.get_conf(payload.config_id).numalogic_conf.model.name],
             payload=payload,
             model_registry=self.model_registry,
+            load_latest=LOAD_LATEST,
         )
 
         # TODO: revisit retraining logic
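LOAD_LATEST uses the usual string-to-bool idiom for environment flags, and it is worth spelling out why: bool() on any non-empty string is truthy, so the explicit lowered-string comparison is the safe form. A quick demonstration:

    import os

    os.environ["LOAD_LATEST"] = "False"

    naive = bool(os.getenv("LOAD_LATEST"))                        # True: any non-empty string
    parsed = os.getenv("LOAD_LATEST", "false").lower() == "true"  # False, as intended
    print(naive, parsed)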
11 changes: 8 additions & 3 deletions numalogic/udfs/postprocess.py
@@ -16,11 +16,12 @@
 from numalogic.udfs import NumalogicUDF
 from numalogic.udfs._config import StreamConf, PipelineConf
 from numalogic.udfs.entities import StreamPayload, Header, Status, TrainerPayload, OutputPayload
-from numalogic.udfs.tools import _load_model
+from numalogic.udfs.tools import _load_artifact
 
 # TODO: move to config
 LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600"))
 LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000"))
+LOAD_LATEST = os.getenv("LOAD_LATEST", "false").lower() == "true"
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -88,8 +89,12 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
         postprocess_cfg = self.get_conf(payload.config_id).numalogic_conf.postprocess
 
         # load artifact
-        thresh_artifact = _load_model(
-            skeys=keys, dkeys=[thresh_cfg.name], payload=payload, model_registry=self.model_registry
+        thresh_artifact, payload = _load_artifact(
+            skeys=keys,
+            dkeys=[thresh_cfg.name],
+            payload=payload,
+            model_registry=self.model_registry,
+            load_latest=LOAD_LATEST,
         )
         postproc_clf = self.postproc_factory.get_instance(postprocess_cfg)
 
6 changes: 4 additions & 2 deletions numalogic/udfs/preprocess.py
@@ -16,11 +16,12 @@
 from numalogic.udfs import NumalogicUDF
 from numalogic.udfs._config import StreamConf, PipelineConf
 from numalogic.udfs.entities import Status, Header
-from numalogic.udfs.tools import make_stream_payload, get_df, _load_model
+from numalogic.udfs.tools import make_stream_payload, get_df, _load_artifact
 
 # TODO: move to config
 LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600"))
 LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000"))
+LOAD_LATEST = os.getenv("LOAD_LATEST", "false").lower() == "true"
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -108,14 +109,15 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
         if any(
             [_conf.stateful for _conf in self.get_conf(payload.config_id).numalogic_conf.preprocess]
         ):
-            preproc_artifact = _load_model(
+            preproc_artifact, payload = _load_artifact(
                 skeys=keys,
                 dkeys=[
                     _conf.name
                     for _conf in self.get_conf(payload.config_id).numalogic_conf.preprocess
                 ],
                 payload=payload,
                 model_registry=self.model_registry,
+                load_latest=LOAD_LATEST,
             )
             if preproc_artifact:
                 preproc_clf = preproc_artifact.artifact
34 changes: 23 additions & 11 deletions numalogic/udfs/tools.py
@@ -66,9 +66,13 @@ def make_stream_payload(


 # TODO: move to base NumalogicUDF class
-def _load_model(
-    skeys: KEYS, dkeys: KEYS, payload: StreamPayload, model_registry: ArtifactManager
-) -> Optional[ArtifactData]:
+def _load_artifact(
+    skeys: KEYS,
+    dkeys: KEYS,
+    payload: StreamPayload,
+    model_registry: ArtifactManager,
+    load_latest: bool,
+) -> tuple[Optional[ArtifactData], StreamPayload]:
     """
     Load artifact from redis
     Args:
Expand All @@ -80,27 +84,35 @@ def _load_model(
     Returns
     -------
         artifact_t object
+        StreamPayload object
     """
-    version_to_load = "-1"
-    str_dkeys = ":".join(dkeys)
     if payload.metadata and "artifact_versions" in payload.metadata:
-        version_to_load = payload.metadata["artifact_versions"][str_dkeys]
+        version_to_load = payload.metadata["artifact_versions"][":".join(dkeys)]
+        _LOGGER.info("%s - Found version info for keys: %s, %s", payload.uuid, skeys, dkeys)
+    else:
+        _LOGGER.info(
+            "%s - No version info passed on! Loading latest artifact version for Keys: %s",
+            payload.uuid,
+            skeys,
+        )
+        load_latest = True
     try:
-        if version_to_load == -1:
+        if load_latest:
             artifact = model_registry.load(skeys=skeys, dkeys=dkeys)
         else:
             artifact = model_registry.load(
                 skeys=skeys, dkeys=dkeys, latest=False, version=version_to_load
             )
     except RedisRegistryError:
-        _LOGGER.exception(
+        _LOGGER.warning(
             "%s - Error while fetching artifact, Keys: %s, Metrics: %s",
             payload.uuid,
             skeys,
             payload.metrics,
         )
-        return None
+        return None, payload
 
     except Exception:
         _LOGGER.exception(
@@ -109,7 +121,7 @@ def _load_model(
             payload.composite_keys,
             payload.metrics,
         )
-        return None
+        return None, payload
     else:
         _LOGGER.info(
             "%s - Loaded Model. Source: %s , version: %s, Keys: %s, %s",
@@ -120,11 +132,11 @@ def _load_model(
             dkeys,
         )
     if artifact.metadata and "artifact_versions" in artifact.metadata:
-        replace(
+        payload = replace(
             payload,
             metadata={
                 "artifact_versions": artifact.metadata["artifact_versions"],
                 **payload.metadata,
             },
         )
-    return artifact
+    return artifact, payload
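The replace( to payload = replace( change fixes a silent no-op: dataclasses.replace returns a new instance rather than mutating in place, so the unassigned call never propagated the version metadata. A self-contained demonstration (Payload is a stand-in for StreamPayload):

    from dataclasses import dataclass, field, replace

    @dataclass(frozen=True)
    class Payload:
        uuid: str
        metadata: dict = field(default_factory=dict)

    p = Payload(uuid="abc")

    # Bug pattern: replace() returns a new instance; discarding it is a no-op.
    replace(p, metadata={"artifact_versions": {"model": "3"}})
    assert p.metadata == {}

    # Fixed pattern, as in the diff: bind the result back to payload.
    p = replace(p, metadata={"artifact_versions": {"model": "3"}})
    assert p.metadata["artifact_versions"]["model"] == "3"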
2 changes: 1 addition & 1 deletion numalogic/udfs/trainer.py
@@ -201,7 +201,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
         # Save in reverse order so that version info is stored for other models in the
         # metadata of first model
         list_artifacts = [artifacts["threshold_clf"], artifacts["model"], artifacts["preproc_clf"]]
-        list_dkeys = [thresh_dkeys, preproc_dkeys, model_dkeys]
+        list_dkeys = [thresh_dkeys, model_dkeys, preproc_dkeys]
         list_dkeys_save, list_artifacts_save = self.artifacts_to_save(
             list_dkeys=list_dkeys, list_artifacts=list_artifacts
         )
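The reordered list matters because save_multiple pairs the two lists positionally via zip; with the old order, the model artifact was saved under the preproc keys and vice versa. A quick illustration (names shortened):

    artifacts = ["threshold_clf", "model", "preproc_clf"]  # save order used above
    old_dkeys = ["thresh", "preproc", "model"]             # misaligned
    new_dkeys = ["thresh", "model", "preproc"]             # aligned

    print(list(zip(old_dkeys, artifacts)))
    # [('thresh', 'threshold_clf'), ('preproc', 'model'), ('model', 'preproc_clf')]
    # -> the model lands under the preproc keys, and vice versa
    print(list(zip(new_dkeys, artifacts)))
    # [('thresh', 'threshold_clf'), ('model', 'model'), ('preproc', 'preproc_clf')]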