Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
code assumed that hyperparameter optimisation should select the trial that "maximize" its loss (e.g. correlation). made this to be interchangeable.
  • Loading branch information
fmunzlin authored Apr 24, 2024
1 parent 698cbce commit 605cda7
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion deeprvat/deeprvat/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,12 +1172,17 @@ def best_training_run(
:returns: None
"""

study = optuna.load_study(
study_name=Path(hpopt_db).stem, storage=f"sqlite:///{hpopt_db}"
)

trials = study.trials_dataframe().query('state == "COMPLETE"')
best_trial = trials.sort_values("value", ascending=False).iloc[0]
with open("config.yaml") as f:
config = yaml.safe_load(f)
ascending = False if config["hyperparameter_optimization"]["direction"] == "maximize" else True
f.close()
best_trial = trials.sort_values("value", ascending=ascending).iloc[0]
best_trial_id = best_trial["user_attrs_user_id"]

logger.info(f"Best trial:\n{best_trial}")
Expand Down

0 comments on commit 605cda7

Please sign in to comment.