diff --git a/deeprootgen/pipeline/experiment.py b/deeprootgen/pipeline/experiment.py index f1b8a99..f82817e 100644 --- a/deeprootgen/pipeline/experiment.py +++ b/deeprootgen/pipeline/experiment.py @@ -21,6 +21,7 @@ from ..data_model import RootSimulationModel from ..model import RootSystemSimulation +from ..statistics import get_summary_statistic_func, get_summary_statistics OUT_DIR = osp.join("/app", "outputs") @@ -196,6 +197,28 @@ def log_simulation( metric_df.to_csv(outfile, index=False) mlflow.log_artifact(outfile) + statistic_names = [] + statistic_values = [] + summary_statistics = get_summary_statistics() + for summary_statistic in summary_statistics: + statistic_name = summary_statistic["value"] + statistic_func = get_summary_statistic_func(statistic_name) + statistic_instance = statistic_func() + statistic_value = statistic_instance.calculate(node_df) + + if isinstance(statistic_value, tuple): + continue + statistic_names.append(statistic_name) + statistic_values.append(statistic_value) + mlflow.log_metric(statistic_name, statistic_value) + + statistic_df = pd.DataFrame( + {"statistic_name": statistic_names, "statistic_value": statistic_values} + ) + outfile = osp.join(OUT_DIR, f"{time_now}-{task}_summary_statistics.csv") + statistic_df.to_csv(outfile, index=False) + mlflow.log_artifact(outfile) + def load_form_parameters( list_of_contents: list, list_of_names: list, form_name: str, task: str = ""