Skip to content

Commit

Permalink
ruff format
Browse files Browse the repository at this point in the history
  • Loading branch information
aramoto99 committed Jan 7, 2025
1 parent 9ae2168 commit b04b600
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 26 deletions.
10 changes: 4 additions & 6 deletions tests/hpo/apps/main/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# Define type aliases for the database utility functions
TrialCountFunc: TypeAlias = Callable[[Path, str], int]
TrialValuesFunc: TypeAlias = Callable[[Path, str], list[float]]
DBUtils: TypeAlias = dict[Literal["get_trial_count", "get_trial_values"],
TrialCountFunc | TrialValuesFunc]
DBUtils: TypeAlias = dict[Literal["get_trial_count", "get_trial_values"], TrialCountFunc | TrialValuesFunc]
ConfigModFunc: TypeAlias = Callable[[Path, str, int, str], Path]


Expand Down Expand Up @@ -48,6 +47,7 @@ def temp_dir() -> Generator[Path, None, None]:
@pytest.fixture
def db_utils() -> DBUtils:
"""Fixture providing database utility functions"""

def get_trial_count(db_path: Path, study_name: str) -> int:
"""Get the number of trials from the SQLite database for a specific study"""
if not db_path.exists():
Expand Down Expand Up @@ -87,15 +87,13 @@ def get_trial_values(db_path: Path, study_name: str) -> list[float]:
conn.close()
return values

return {
"get_trial_count": get_trial_count,
"get_trial_values": get_trial_values
}
return {"get_trial_count": get_trial_count, "get_trial_values": get_trial_values}


@pytest.fixture
def config_modifier() -> ConfigModFunc:
"""Fixture providing configuration file modification functionality"""

def modify_config(config_path: Path, study_name: str, n_trials: int, db_name: str) -> Path:
"""Modify config file with new study name and number of trials"""
with open(config_path) as f:
Expand Down
6 changes: 1 addition & 5 deletions tests/hpo/apps/main/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
from conftest import ConfigModFunc, DBUtils


def test_normal_execution(
temp_dir: Path,
db_utils: DBUtils,
config_modifier: ConfigModFunc
) -> None:
def test_normal_execution(temp_dir: Path, db_utils: DBUtils, config_modifier: ConfigModFunc) -> None:
"""Test normal execution without resume functionality"""
from aiaccel.hpo.apps.optimize import main

Expand Down
18 changes: 3 additions & 15 deletions tests/hpo/apps/main/test_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
from conftest import ConfigModFunc, DBUtils, TrialValuesFunc


def test_optimization_consistency(
temp_dir: Path,
db_utils: DBUtils,
config_modifier: ConfigModFunc
) -> None:
def test_optimization_consistency(temp_dir: Path, db_utils: DBUtils, config_modifier: ConfigModFunc) -> None:
"""Test that split execution (resumable + resume) gives same results as normal execution."""
from aiaccel.hpo.apps.optimize import main

Expand Down Expand Up @@ -59,11 +55,7 @@ def test_optimization_consistency(
assert abs(normal_best - split_best) < 1e-6, f"Best values differ: normal={normal_best}, split={split_best}"


def test_resumable_execution(
temp_dir: Path,
db_utils: DBUtils,
config_modifier: ConfigModFunc
) -> None:
def test_resumable_execution(temp_dir: Path, db_utils: DBUtils, config_modifier: ConfigModFunc) -> None:
"""Test execution with `--resumable`"""
from aiaccel.hpo.apps.optimize import main

Expand All @@ -80,11 +72,7 @@ def test_resumable_execution(
assert trial_count == 15


def test_resume_execution(
temp_dir: Path,
db_utils: DBUtils,
config_modifier: ConfigModFunc
) -> None:
def test_resume_execution(temp_dir: Path, db_utils: DBUtils, config_modifier: ConfigModFunc) -> None:
"""Test the resume functionality of the optimization process."""
from aiaccel.hpo.apps.optimize import main

Expand Down

0 comments on commit b04b600

Please sign in to comment.