diff --git a/tests/hpo/apps/main/conftest.py b/tests/hpo/apps/main/conftest.py index 34d5a7c2..d961650a 100644 --- a/tests/hpo/apps/main/conftest.py +++ b/tests/hpo/apps/main/conftest.py @@ -12,8 +12,7 @@ # Define type aliases for the database utility functions TrialCountFunc: TypeAlias = Callable[[Path, str], int] TrialValuesFunc: TypeAlias = Callable[[Path, str], list[float]] -DBUtils: TypeAlias = dict[Literal["get_trial_count", "get_trial_values"], - TrialCountFunc | TrialValuesFunc] +DBUtils: TypeAlias = dict[Literal["get_trial_count", "get_trial_values"], TrialCountFunc | TrialValuesFunc] ConfigModFunc: TypeAlias = Callable[[Path, str, int, str], Path] @@ -48,6 +47,7 @@ def temp_dir() -> Generator[Path, None, None]: @pytest.fixture def db_utils() -> DBUtils: """Fixture providing database utility functions""" + def get_trial_count(db_path: Path, study_name: str) -> int: """Get the number of trials from the SQLite database for a specific study""" if not db_path.exists(): @@ -87,15 +87,13 @@ def get_trial_values(db_path: Path, study_name: str) -> list[float]: conn.close() return values - return { - "get_trial_count": get_trial_count, - "get_trial_values": get_trial_values - } + return {"get_trial_count": get_trial_count, "get_trial_values": get_trial_values} @pytest.fixture def config_modifier() -> ConfigModFunc: """Fixture providing configuration file modification functionality""" + def modify_config(config_path: Path, study_name: str, n_trials: int, db_name: str) -> Path: """Modify config file with new study name and number of trials""" with open(config_path) as f: diff --git a/tests/hpo/apps/main/test_optimize.py b/tests/hpo/apps/main/test_optimize.py index 5b19a299..12b32ba6 100644 --- a/tests/hpo/apps/main/test_optimize.py +++ b/tests/hpo/apps/main/test_optimize.py @@ -5,11 +5,7 @@ from conftest import ConfigModFunc, DBUtils -def test_normal_execution( - temp_dir: Path, - db_utils: DBUtils, - config_modifier: ConfigModFunc -) -> None: +def test_normal_execution(temp_dir: Path, db_utils: DBUtils, config_modifier: ConfigModFunc) -> None: """Test normal execution without resume functionality""" from aiaccel.hpo.apps.optimize import main diff --git a/tests/hpo/apps/main/test_resume.py b/tests/hpo/apps/main/test_resume.py index 7551bc7b..44bfa82e 100644 --- a/tests/hpo/apps/main/test_resume.py +++ b/tests/hpo/apps/main/test_resume.py @@ -7,11 +7,7 @@ from conftest import ConfigModFunc, DBUtils, TrialValuesFunc -def test_optimization_consistency( - temp_dir: Path, - db_utils: DBUtils, - config_modifier: ConfigModFunc -) -> None: +def test_optimization_consistency(temp_dir: Path, db_utils: DBUtils, config_modifier: ConfigModFunc) -> None: """Test that split execution (resumable + resume) gives same results as normal execution.""" from aiaccel.hpo.apps.optimize import main @@ -59,11 +55,7 @@ def test_optimization_consistency( assert abs(normal_best - split_best) < 1e-6, f"Best values differ: normal={normal_best}, split={split_best}" -def test_resumable_execution( - temp_dir: Path, - db_utils: DBUtils, - config_modifier: ConfigModFunc -) -> None: +def test_resumable_execution(temp_dir: Path, db_utils: DBUtils, config_modifier: ConfigModFunc) -> None: """Test execution with `--resumable`""" from aiaccel.hpo.apps.optimize import main @@ -80,11 +72,7 @@ def test_resumable_execution( assert trial_count == 15 -def test_resume_execution( - temp_dir: Path, - db_utils: DBUtils, - config_modifier: ConfigModFunc -) -> None: +def test_resume_execution(temp_dir: Path, db_utils: DBUtils, config_modifier: ConfigModFunc) -> None: """Test the resume functionality of the optimization process.""" from aiaccel.hpo.apps.optimize import main