Skip to content

Commit

Permalink
passing model name in training_summary json (#414)
Browse files Browse the repository at this point in the history
  • Loading branch information
joker2411 authored Aug 13, 2024
1 parent fbcb38a commit a71de41
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def prepare_training_summary(
training_summary = {
"timestamp": model_timestamp,
"data": {
"model": model_results["model_class_name"],
"metrics": model_results["metrics"],
"threshold": model_results["prob_th"],
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ def prepare_training_summary(
) -> dict:
training_summary = {
"timestamp": model_timestamp,
"data": {"metrics": model_results["metrics"]},
"data": {
"model": model_results["model_class_name"],
"metrics": model_results["metrics"],
},
}
return training_summary

Expand Down
27 changes: 23 additions & 4 deletions tests/unit/MLTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,26 @@ class TestClassificationTrainer(unittest.TestCase):
def test_prepare_training_summary(self):
config = build_trainer_config()
trainer = TrainerFactory.create(config)
model_class_name = "MODEL_NAME"
metrics = {"test": {}, "train": {}, "val": {}}
timestamp = "2023-11-08"
threshold = 0.62
result = trainer.prepare_training_summary(
{"metrics": metrics, "prob_th": threshold}, timestamp
{
"model_class_name": model_class_name,
"metrics": metrics,
"prob_th": threshold,
},
timestamp,
)
self.assertEqual(
result,
{
"data": {"metrics": metrics, "threshold": threshold},
"data": {
"model": model_class_name,
"metrics": metrics,
"threshold": threshold,
},
"timestamp": timestamp,
},
)
Expand Down Expand Up @@ -177,10 +187,19 @@ class TestRegressionTrainer(unittest.TestCase):
def test_prepare_training_summary(self):
config = build_trainer_config(task="regression")
trainer = TrainerFactory.create(config)
model_class_name = "MODEL_NAME"
metrics = {"test": {}, "train": {}, "val": {}}
timestamp = "2023-11-08"
result = trainer.prepare_training_summary({"metrics": metrics}, timestamp)
self.assertEqual(result, {"data": {"metrics": metrics}, "timestamp": timestamp})
result = trainer.prepare_training_summary(
{"model_class_name": model_class_name, "metrics": metrics}, timestamp
)
self.assertEqual(
result,
{
"data": {"model": model_class_name, "metrics": metrics},
"timestamp": timestamp,
},
)

def test_validate_data(self):
config = build_trainer_config(task="regression")
Expand Down

0 comments on commit a71de41

Please sign in to comment.