diff --git a/plugins/flytekit-optuna/flytekitplugins/optuna/optimizer.py b/plugins/flytekit-optuna/flytekitplugins/optuna/optimizer.py index 4786998e46..0b7c6f49c2 100644 --- a/plugins/flytekit-optuna/flytekitplugins/optuna/optimizer.py +++ b/plugins/flytekit-optuna/flytekitplugins/optuna/optimizer.py @@ -3,7 +3,7 @@ from copy import copy, deepcopy from dataclasses import dataclass from types import SimpleNamespace -from typing import Any, Callable, Optional, Union +from typing import Any, Awaitable, Callable, Concatenate, Optional, ParamSpec, Union import optuna from flytekit import PythonFunctionTask @@ -46,14 +46,16 @@ class Category(Suggestion): suggest = SimpleNamespace(float=Float, integer=Integer, category=Category) +P = ParamSpec("P") + @dataclass class Optimizer: - objective: Union[PythonFunctionTask, PythonFunctionWorkflow] concurrency: int n_trials: int study: Optional[optuna.Study] = None - callback: Optional[Callable[[optuna.Trial, dict[str, Any]], dict[str, Any]]] = None + objective: Optional[Union[PythonFunctionTask, PythonFunctionWorkflow]] = None + callback: Optional[Callable[Concatenate[optuna.Trial, P], Awaitable[Union[float, tuple[float, ...]]]]] = None def __post_init__(self): if self.study is None: @@ -68,29 +70,45 @@ def __post_init__(self): if not isinstance(self.study, optuna.Study): raise ValueError("study must be an optuna.Study") - # check if the objective function returns the correct number of outputs - if isinstance(self.objective, PythonFunctionTask): - func = self.objective.task_function - elif isinstance(self.objective, PythonFunctionWorkflow): - func = self.objective._workflow_function - else: - raise ValueError("objective must be a PythonFunctionTask or PythonFunctionWorkflow") + if self.objective is not None: + if self.callback is not None: + raise ValueError("either objective or callback must be provided, not both") - signature = inspect.signature(func) + # check if the objective function returns the correct number of outputs + if isinstance(self.objective, PythonFunctionTask): + func = self.objective.task_function + elif isinstance(self.objective, PythonFunctionWorkflow): + func = self.objective._workflow_function + else: + raise ValueError("objective must be a PythonFunctionTask or PythonFunctionWorkflow") - if signature.return_annotation is float: - if len(self.study.directions) != 1: - raise ValueError("the study must have a single objective if objective returns a single float") + signature = inspect.signature(func) - elif isinstance(args := signature.return_annotation.__args__, tuple): - if len(args) != len(self.study.directions): - raise ValueError("objective must return the same number of directions in the study") + if signature.return_annotation is float: + if len(self.study.directions) != 1: + raise ValueError("the study must have a single objective if objective returns a single float") - if not all(arg is float for arg in args): + elif isinstance(args := signature.return_annotation.__args__, tuple): + if len(args) != len(self.study.directions): + raise ValueError("objective must return the same number of directions in the study") + + if not all(arg is float for arg in args): + raise ValueError("objective function must return a float or tuple of floats") + + else: raise ValueError("objective function must return a float or tuple of floats") else: - raise ValueError("objective function must return a float or tuple of floats") + if self.callback is None: + raise ValueError("either objective or callback must be provided") + + if not callable(self.callback): + raise ValueError("callback must be a callable") + + signature = inspect.signature(self.callback) + + if "trial" not in signature.parameters: + raise ValueError("callback function must have a parameter called 'trial'") async def __call__(self, **inputs: Any): """ @@ -113,14 +131,14 @@ async def spawn(self, semaphore: asyncio.Semaphore, inputs: dict[str, Any]): # ask for a new trial trial: optuna.Trial = self.study.ask() - if self.callback is not None: - inputs = self.callback(trial, inputs) - else: - inputs = process(trial, inputs) - try: + if self.callback is not None: + promise = self.callback(trial=trial, **inputs) + else: + promise = self.objective(**process(trial, inputs)) + # schedule the trial - result: Union[float, tuple[float, ...]] = await self.objective(**inputs) + result: Union[float, tuple[float, ...]] = await promise # tell the study the result self.study.tell(trial, result, state=optuna.trial.TrialState.COMPLETE) @@ -134,7 +152,11 @@ def process(trial: optuna.Trial, inputs: dict[str, Any], root: Optional[list[str if root is None: root = [] - suggesters = {Float: trial.suggest_float, Integer: trial.suggest_int, Category: trial.suggest_categorical} + suggesters = { + Float: trial.suggest_float, + Integer: trial.suggest_int, + Category: trial.suggest_categorical, + } for key, value in inputs.items(): path = copy(root) + [key] diff --git a/plugins/flytekit-optuna/tests/test_optimizer.py b/plugins/flytekit-optuna/tests/test_optimizer.py index a7c101a558..7175a5e4d7 100644 --- a/plugins/flytekit-optuna/tests/test_optimizer.py +++ b/plugins/flytekit-optuna/tests/test_optimizer.py @@ -22,7 +22,7 @@ async def objective(x: float, y: int, z: int, power: int) -> float: @fl.eager async def train(concurrency: int, n_trials: int) -> float: - optimizer = Optimizer(objective, concurrency, n_trials) + optimizer = Optimizer(objective=objective, concurrency=concurrency, n_trials=n_trials) await optimizer( x=suggest.float(low=-10, high=10), @@ -53,7 +53,7 @@ async def objective(suggestions: dict[str, Union[int, float]], z: int, power: in @fl.eager async def train(concurrency: int, n_trials: int) -> float: - optimizer = Optimizer(objective, concurrency, n_trials) + optimizer = Optimizer(objective=objective, concurrency=concurrency, n_trials=n_trials) suggestions = { "x": suggest.float(low=-10, high=10), @@ -82,28 +82,28 @@ async def objective(letter: str, number: Union[float, int], other: str, fixed: s return float(loss) + def callback(trial: optuna.Trial, fixed: str): - @fl.eager - async def train(concurrency: int, n_trials: int) -> float: + letter = trial.suggest_categorical("booster", ["A", "B", "BLAH"]) - study = optuna.create_study(direction="maximize") + if letter == "A": + number = trial.suggest_int("number_A", 1, 10) + elif letter == "B": + number = trial.suggest_float("number_B", 10., 20.) + else: + number = 10 - def callback(trial: optuna.Trial, inputs: dict[str, Any]) -> dict[str, Any]: + other = trial.suggest_categorical("other", ["Something", "another word", "a phrase"]) - letter = trial.suggest_categorical("booster", ["A", "B", "BLAH"]) + return objective(letter, number, other, fixed) - if letter == "A": - number = trial.suggest_int("number_A", 1, 10) - elif letter == "B": - number = trial.suggest_float("number_B", 10., 20.) - else: - number = 10 - other = trial.suggest_categorical("other", ["Something", "another word", "a phrase"]) + @fl.eager + async def train(concurrency: int, n_trials: int) -> float: - return dict(other=other, number=number, letter=letter, fixed=inputs["fixed"]) + study = optuna.create_study(direction="maximize") - optimizer = Optimizer(objective, concurrency, n_trials, study=study, callback=callback) + optimizer = Optimizer(concurrency=concurrency, n_trials=n_trials, study=study, callback=callback) await optimizer(fixed="hello!")