diff --git a/src/predictions/profiles_mlcorelib/trainers/ClassificationTrainer.py b/src/predictions/profiles_mlcorelib/trainers/ClassificationTrainer.py index 2a2dc6204..f81ff4b21 100644 --- a/src/predictions/profiles_mlcorelib/trainers/ClassificationTrainer.py +++ b/src/predictions/profiles_mlcorelib/trainers/ClassificationTrainer.py @@ -222,6 +222,7 @@ def prepare_training_summary( training_summary = { "timestamp": model_timestamp, "data": { + "model": model_results["model_class_name"], "metrics": model_results["metrics"], "threshold": model_results["prob_th"], }, diff --git a/src/predictions/profiles_mlcorelib/trainers/RegressionTrainer.py b/src/predictions/profiles_mlcorelib/trainers/RegressionTrainer.py index 35daaba8c..9d7dab9cd 100644 --- a/src/predictions/profiles_mlcorelib/trainers/RegressionTrainer.py +++ b/src/predictions/profiles_mlcorelib/trainers/RegressionTrainer.py @@ -161,7 +161,10 @@ def prepare_training_summary( ) -> dict: training_summary = { "timestamp": model_timestamp, - "data": {"metrics": model_results["metrics"]}, + "data": { + "model": model_results["model_class_name"], + "metrics": model_results["metrics"], + }, } return training_summary diff --git a/tests/unit/MLTrainer.py b/tests/unit/MLTrainer.py index ea9f5ce2d..c7dd30aeb 100644 --- a/tests/unit/MLTrainer.py +++ b/tests/unit/MLTrainer.py @@ -43,16 +43,26 @@ class TestClassificationTrainer(unittest.TestCase): def test_prepare_training_summary(self): config = build_trainer_config() trainer = TrainerFactory.create(config) + model_class_name = "MODEL_NAME" metrics = {"test": {}, "train": {}, "val": {}} timestamp = "2023-11-08" threshold = 0.62 result = trainer.prepare_training_summary( - {"metrics": metrics, "prob_th": threshold}, timestamp + { + "model_class_name": model_class_name, + "metrics": metrics, + "prob_th": threshold, + }, + timestamp, ) self.assertEqual( result, { - "data": {"metrics": metrics, "threshold": threshold}, + "data": { + "model": model_class_name, + "metrics": metrics, + "threshold": threshold, + }, "timestamp": timestamp, }, ) @@ -177,10 +187,19 @@ class TestRegressionTrainer(unittest.TestCase): def test_prepare_training_summary(self): config = build_trainer_config(task="regression") trainer = TrainerFactory.create(config) + model_class_name = "MODEL_NAME" metrics = {"test": {}, "train": {}, "val": {}} timestamp = "2023-11-08" - result = trainer.prepare_training_summary({"metrics": metrics}, timestamp) - self.assertEqual(result, {"data": {"metrics": metrics}, "timestamp": timestamp}) + result = trainer.prepare_training_summary( + {"model_class_name": model_class_name, "metrics": metrics}, timestamp + ) + self.assertEqual( + result, + { + "data": {"model": model_class_name, "metrics": metrics}, + "timestamp": timestamp, + }, + ) def test_validate_data(self): config = build_trainer_config(task="regression")