
Commit

Update logger object without passing as parameter
Signed-off-by: Gulshan Bhatia <gulshan_bhatia@intuit.com>
Gulshan Bhatia committed Jun 14, 2024
1 parent 6407e29 commit b511956
Showing 4 changed files with 49 additions and 64 deletions.
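
At its core, the commit stops threading a logger argument through TrainMsgDeduplicator and the trainer UDFs; each method instead binds its context once with _struct_log.bind(uuid=..., key=...), and later log calls inherit those fields. Below is a minimal, self-contained sketch of that pattern, assuming a structlog-style logger comparable to what numalogic.udfs._logger configures; redis_call and ack_example are hypothetical names for illustration.

import structlog

# Module-level logger, mirroring _struct_log in numalogic/udfs/tools.py.
_struct_log = structlog.get_logger()


def redis_call(key: str) -> None:
    """Hypothetical stand-in for a Redis operation that may raise."""
    if not key:
        raise ValueError("empty key")


def ack_example(key: str, uuid: str) -> bool:
    # Bind the request context once; every event logged below carries
    # uuid/key without repeating them as keyword arguments.
    logger = _struct_log.bind(uuid=uuid, key=key)
    try:
        redis_call(key)
    except Exception:
        logger.exception("Problem while updating the key")  # uuid/key already bound
        return False
    logger.debug("Acknowledged the key")
    return True


if __name__ == "__main__":
    ack_example("TRAIN::demo", "some-uuid")

Binding once is what lets the diff below collapse the multi-line logger.exception(...) and logger.debug(...) calls into one-liners.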
44 changes: 15 additions & 29 deletions numalogic/udfs/tools.py
@@ -8,6 +8,7 @@
 from pandas import DataFrame
 from pynumaflow.mapper import Message
 
+
 from numalogic.registry import ArtifactManager, ArtifactData
 from numalogic.tools.exceptions import RedisRegistryError
 from numalogic.tools.types import KEYS, redis_client_t
@@ -252,12 +253,11 @@ def __construct_train_key(keys: KEYS) -> str:
         return f"TRAIN::{_key}"
 
     def __fetch_ts(self, uuid: str, key: str) -> _DedupMetadata:
+        logger = _struct_log.bind(uuid=uuid, key=key)
         try:
             data = self.client.hgetall(key)
         except Exception:
-            _struct_log.exception(
-                "Problem fetching ts information for the key", uuid=uuid, key=key
-            )
+            logger.exception("Problem fetching ts information for the key")
             return _DedupMetadata(msg_read_ts=None, msg_train_ts=None, msg_train_records=None)
         else:
             # decode the key:value pair and update the values
@@ -273,7 +273,7 @@ def __fetch_ts(self, uuid: str, key: str) -> _DedupMetadata:
                 msg_train_records=_msg_train_records,
             )
 
-    def ack_insufficient_data(self, logger, key: KEYS, uuid: str, train_records: int) -> bool:
+    def ack_insufficient_data(self, key: KEYS, uuid: str, train_records: int) -> bool:
         """
         Acknowledge the insufficient data message. Retry training after certain period of a time.
         Args:
@@ -286,21 +286,17 @@ def ack_insufficient_data(self, logger, key: KEYS, uuid: str, train_records: int
             bool.
         """
         _key = self.__construct_train_key(key)
+        logger = _struct_log.bind(uuid=uuid, key=key)
         try:
             self.client.hset(name=_key, key="_msg_train_records", value=str(train_records))
         except Exception:
-            logger.exception(
-                "Problem while updating _msg_train_records information for the key",
-                uuid=uuid,
-                key=key,
-            )
+            logger.exception("Problem while updating _msg_train_records information for the key")
             return False
-        logger.debug("Acknowledging insufficient data for the key", uuid=uuid, key=key)
+        logger.debug("Acknowledging insufficient data for the key")
         return True
 
     def ack_read(
         self,
-        logger,
         key: KEYS,
         uuid: str,
         retrain_freq: int = 24,
@@ -325,6 +321,7 @@ def ack_read(
         """
         _key = self.__construct_train_key(key)
         metadata = self.__fetch_ts(uuid=uuid, key=_key)
+        logger = _struct_log.bind(uuid=uuid, key=key)
         _msg_read_ts, _msg_train_ts, _msg_train_records = (
             metadata.msg_read_ts,
             metadata.msg_train_ts,
@@ -341,8 +338,6 @@ def ack_read(
             logger.debug(
                 "There was insufficient data for the key in the past. Retrying fetching"
                 " and training after secs",
-                uuid=uuid,
-                key=key,
                 secs=((min_train_records - int(_msg_train_records)) * data_freq)
                 - _curr_time
                 + float(_msg_read_ts),
@@ -352,31 +347,25 @@
 
         # Check if the model is being trained by another process
         if _msg_read_ts and time.time() - float(_msg_read_ts) < retry:
-            logger.debug("Model with key is being trained by another process", uuid=uuid, key=key)
+            logger.debug("Model with key is being trained by another process")
             return False
 
         # This check is needed if there is backpressure in the pipeline
         if _msg_train_ts and time.time() - float(_msg_train_ts) < retrain_freq * 60 * 60:
             logger.debug(
                 "Model was saved for the key in less than retrain_freq hrs, skipping training",
-                uuid=uuid,
-                key=key,
                 retrain_freq=retrain_freq,
             )
             return False
         try:
             self.client.hset(name=_key, key="_msg_read_ts", value=str(time.time()))
         except Exception:
-            logger.exception(
-                "Problem while updating msg_read_ts information for the key",
-                uuid=uuid,
-                key=key,
-            )
+            logger.exception("Problem while updating msg_read_ts information for the key")
             return False
-        logger.debug("Acknowledging request for Training for key", uuid=uuid, key=key)
+        logger.debug("Acknowledging request for Training for key")
         return True
 
-    def ack_train(self, logger, key: KEYS, uuid: str) -> bool:
+    def ack_train(self, key: KEYS, uuid: str) -> bool:
         """
         Acknowledge the train message is trained and saved. Return True when
         _msg_train_ts is updated.
@@ -389,16 +378,13 @@ def ack_train(self, logger, key: KEYS, uuid: str) -> bool:
             bool
         """
         _key = self.__construct_train_key(key)
+        logger = _struct_log.bind(uuid=uuid, key=key)
         try:
             self.client.hset(name=_key, key="_msg_train_ts", value=str(time.time()))
         except Exception:
-            logger.exception(
-                "Problem while updating msg_train_ts information for the key",
-                uuid=uuid,
-                key=key,
-            )
+            logger.exception("Problem while updating msg_train_ts information for the key")
             return False
-        logger.debug("Acknowledging model saving complete for the key", uuid=uuid, key=key)
+        logger.debug("Acknowledging model saving complete for the key")
         return True
 
 
(Codecov: added line numalogic/udfs/tools.py#L355 was not covered by tests.)
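
With the tools.py changes above, callers no longer construct or pass a logger. A short sketch of the updated call shapes, with method signatures taken from this diff; the Redis connection below is illustrative (the tests in this commit pass their own REDIS_CLIENT fixture).

import redis

from numalogic.udfs.tools import TrainMsgDeduplicator

# Illustrative connection; any redis.Redis-compatible client works here.
redis_client = redis.Redis(host="localhost", port=6379)
dedup = TrainMsgDeduplicator(redis_client)

# None of the ack_* methods take a logger anymore; uuid/key are bound internally.
dedup.ack_read(key=["ck1", "pipeline1"], uuid="some-uuid")
dedup.ack_insufficient_data(key=["ck1", "pipeline1"], uuid="some-uuid", train_records=180)
dedup.ack_train(key=["ck1", "pipeline1"], uuid="some-uuid")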
8 changes: 3 additions & 5 deletions numalogic/udfs/trainer/_base.py
@@ -184,7 +184,6 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
         retrain_freq_ts = _conf.numalogic_conf.trainer.retrain_freq_hr
         retry_ts = _conf.numalogic_conf.trainer.retry_sec
         if not self.train_msg_deduplicator.ack_read(
-            logger=logger,
             key=[*payload.composite_keys, payload.pipeline_id],
             uuid=payload.uuid,
             retrain_freq=retrain_freq_ts,
@@ -217,7 +216,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
             return Messages(Message.to_drop())
 
         # Check if data is sufficient
-        if not self._is_data_sufficient(payload, df, logger):
+        if not self._is_data_sufficient(payload, df):
             logger.warning(
                 "Insufficient data found",
                 uuid=payload.uuid,
@@ -276,7 +275,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
             logger=logger,
         )
         if self.train_msg_deduplicator.ack_train(
-            logger=logger, key=[*payload.composite_keys, payload.pipeline_id], uuid=payload.uuid
+            key=[*payload.composite_keys, payload.pipeline_id], uuid=payload.uuid
         ):
             logger.info("Model trained and saved successfully", uuid=payload.uuid)
 
@@ -353,13 +352,12 @@ def artifacts_to_save(
             "Artifact saved with with versions", uuid=payload.uuid, version_dict=ver_dict
         )
 
-    def _is_data_sufficient(self, payload: TrainerPayload, df: pd.DataFrame, logger) -> bool:
+    def _is_data_sufficient(self, payload: TrainerPayload, df: pd.DataFrame) -> bool:
         _conf = self.get_ml_pipeline_conf(
             config_id=payload.config_id, pipeline_id=payload.pipeline_id
         )
         if len(df) < _conf.numalogic_conf.trainer.min_train_size:
             _ = self.train_msg_deduplicator.ack_insufficient_data(
-                logger=logger,
                 key=[*payload.composite_keys, payload.pipeline_id],
                 uuid=payload.uuid,
                 train_records=len(df),
19 changes: 5 additions & 14 deletions tests/udfs/test_rds_trainer.py
@@ -5,7 +5,6 @@
 from numalogic.connectors import RDSFetcher
 from numalogic.connectors.utils.aws.config import RDSConnectionConfig
 from numalogic.tools.exceptions import RDSFetcherError, ConfigNotFoundError
-from numalogic.udfs._logger import configure_logger
 from numalogic.udfs.tools import TrainMsgDeduplicator
 from numalogic.udfs.trainer._rds import RDSTrainerUDF, build_query
 import re
@@ -311,11 +310,7 @@ def test_trainer_do_not_train_3(mocker, udf1, datum_mock, payload):
 
     keys = payload["composite_keys"]
 
-    logger = configure_logger()
-
-    TrainMsgDeduplicator(REDIS_CLIENT).ack_read(
-        logger=logger, key=[*keys, "pipeline1"], uuid="some-uuid"
-    )
+    TrainMsgDeduplicator(REDIS_CLIENT).ack_read(key=[*keys, "pipeline1"], uuid="some-uuid")
     ts = datetime.strptime("2022-05-24 10:00:00", "%Y-%m-%d %H:%M:%S")
     with freeze_time(ts + timedelta(minutes=15)):
         udf1(keys, datum_mock)
@@ -463,27 +458,23 @@ def test_trainer_datafetcher_err_and_train(mocker, udf1, payload, datum_mock):
 
 def test_TrainMsgDeduplicator_exception_1(mocker, caplog, payload):
     caplog.set_level(logging.INFO)
-    logger = configure_logger()
     mocker.patch("redis.Redis.hset", Mock(side_effect=RedisError))
     train_dedup = TrainMsgDeduplicator(REDIS_CLIENT)
     keys = payload["composite_keys"]
-    train_dedup.ack_read(logger=logger, key=[*keys, "pipeline1"], uuid="some-uuid")
+    train_dedup.ack_read(key=[*keys, "pipeline1"], uuid="some-uuid")
     assert "RedisError" in caplog.text
-    train_dedup.ack_train(logger=logger, key=[*keys, "pipeline1"], uuid="some-uuid")
+    train_dedup.ack_train(key=[*keys, "pipeline1"], uuid="some-uuid")
     assert "RedisError" in caplog.text
-    train_dedup.ack_insufficient_data(
-        logger=logger, key=[*keys, "pipeline1"], uuid="some-uuid", train_records=180
-    )
+    train_dedup.ack_insufficient_data(key=[*keys, "pipeline1"], uuid="some-uuid", train_records=180)
     assert "RedisError" in caplog.text
 
 
 def test_TrainMsgDeduplicator_exception_2(mocker, caplog, payload):
     caplog.set_level(logging.INFO)
-    logger = configure_logger()
     mocker.patch("redis.Redis.hset", Mock(side_effect=RedisError))
     train_dedup = TrainMsgDeduplicator(REDIS_CLIENT)
     keys = payload["composite_keys"]
-    train_dedup.ack_read(logger=logger, key=[*keys, "pipeline1"], uuid="some-uuid")
+    train_dedup.ack_read(key=[*keys, "pipeline1"], uuid="some-uuid")
     assert "RedisError" in caplog.text
 
 
42 changes: 26 additions & 16 deletions tests/udfs/test_trainer.py
@@ -24,7 +24,6 @@
     PrometheusFetcherError,
 )
 from numalogic.udfs import StreamConf, PipelineConf, MLPipelineConf
-from numalogic.udfs._logger import configure_logger
 from numalogic.udfs.tools import TrainMsgDeduplicator
 from numalogic.udfs.trainer import DruidTrainerUDF, PromTrainerUDF
 
@@ -243,10 +242,9 @@ def test_trainer_do_not_train_2(self):
         )
         self.udf1(self.keys, self.datum)
         ts = datetime.strptime("2022-05-24 10:00:00", "%Y-%m-%d %H:%M:%S")
-        logger = configure_logger()
         with freeze_time(ts + timedelta(hours=25)):
             TrainMsgDeduplicator(REDIS_CLIENT).ack_read(
-                logger=logger, key=[*self.keys, "pipeline1"], uuid="some-uuid"
+                key=[*self.keys, "pipeline1"], uuid="some-uuid"
             )
         with freeze_time(ts + timedelta(hours=25) + timedelta(minutes=15)):
             self.udf1(self.keys, self.datum)
@@ -261,7 +259,6 @@ def test_trainer_do_not_train_2(self):
 
     @patch.object(DruidFetcher, "fetch", Mock(return_value=mock_druid_fetch_data()))
     def test_trainer_do_not_train_3(self):
-        logger = configure_logger()
         self.udf1.register_conf(
             "druid-config",
             StreamConf(
@@ -282,9 +279,7 @@ def test_trainer_do_not_train_3(self):
                 }
             ),
         )
-        TrainMsgDeduplicator(REDIS_CLIENT).ack_read(
-            logger=logger, key=[*self.keys, "pipeline1"], uuid="some-uuid"
-        )
+        TrainMsgDeduplicator(REDIS_CLIENT).ack_read(key=[*self.keys, "pipeline1"], uuid="some-uuid")
         ts = datetime.strptime("2022-05-24 10:00:00", "%Y-%m-%d %H:%M:%S")
         with freeze_time(ts + timedelta(minutes=15)):
             self.udf1(self.keys, self.datum)
@@ -425,33 +420,48 @@ def test_trainer_datafetcher_err_and_train(self):
     @patch("redis.Redis.hset", Mock(side_effect=RedisError))
     def test_TrainMsgDeduplicator_exception_1(self):
         train_dedup = TrainMsgDeduplicator(REDIS_CLIENT)
-        logger = configure_logger()
-        train_dedup.ack_read(logger=logger, key=[*self.keys, "pipeline1"], uuid="some-uuid")
+        train_dedup.ack_read(key=[*self.keys, "pipeline1"], uuid="some-uuid")
         self.assertLogs("RedisError")
-        train_dedup.ack_train(logger=logger, key=[*self.keys, "pipeline1"], uuid="some-uuid")
+        train_dedup.ack_train(key=[*self.keys, "pipeline1"], uuid="some-uuid")
         self.assertLogs("RedisError")
         train_dedup.ack_insufficient_data(
-            logger=logger, key=[*self.keys, "pipeline1"], uuid="some-uuid", train_records=180
+            key=[*self.keys, "pipeline1"], uuid="some-uuid", train_records=180
         )
         self.assertLogs("RedisError")
 
     @patch("redis.Redis.hset", Mock(side_effect=mock_druid_fetch_data()))
     def test_TrainMsgDeduplicator_insufficent_data(self):
         with self.assertLogs(level="DEBUG") as log:
-            logger = configure_logger()
             train_dedup = TrainMsgDeduplicator(REDIS_CLIENT)
             train_dedup.ack_insufficient_data(
-                logger=logger, key=[*self.keys, "pipeline1"], uuid="some-uuid", train_records=180
+                key=[*self.keys, "pipeline1"], uuid="some-uuid", train_records=180
             )
-        self.assertLogs("Acknowledging insufficient data for the key", log.output[-1])
+        self.assertLogs(
+            "DEBUG:numalogic.udfs._logger:uuid='some-uuid' "
+            "event='Acknowledging insufficient data for the key' "
+            "key=['5984175597303660107', 'pipeline1'] "
+            "level='debug' timestamp='2024-06-14T19:08:23.610016Z'",
+            log.output[-1],
+        )
 
     @patch("redis.Redis.hgetall", Mock(side_effect=RedisError))
     def test_TrainMsgDeduplicator_exception_2(self):
         train_dedup = TrainMsgDeduplicator(REDIS_CLIENT)
-        logger = configure_logger()
-        train_dedup.ack_read(logger=logger, key=[*self.keys, "pipeline1"], uuid="some-uuid")
+        train_dedup.ack_read(key=[*self.keys, "pipeline1"], uuid="some-uuid")
         self.assertLogs("RedisError")
 
+    @patch("redis.Redis.hgetall", Mock(side_effect=RedisError))
+    def test_TrainMsgDeduplicator_test_debug_logs(self):
+        train_dedup = TrainMsgDeduplicator(REDIS_CLIENT)
+        with self.assertLogs(level="DEBUG") as log:
+            train_dedup.ack_read(key=[*self.keys, "pipeline1"], uuid="some-uuid")
+        self.assertLogs(
+            "DEBUG:numalogic.udfs._logger:uuid='some-uuid' "
+            "event='Acknowledging request for Training for key' key=['5984175597303660107', "
+            "'pipeline1'] level='debug' timestamp='2024-06-14T19:02:55.050604Z'",
+            log.output[-1],
+        )
+
     def test_druid_from_config_1(self):
         with self.assertLogs(level="WARN") as log:
             self.udf1(self.keys, self.datum)
