Skip to content

Commit

Permalink
clean up callback method
Browse files Browse the repository at this point in the history
  • Loading branch information
granthamtaylor committed Jan 10, 2025
1 parent 8bc3adb commit 5af7c60
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 42 deletions.
74 changes: 48 additions & 26 deletions plugins/flytekit-optuna/flytekitplugins/optuna/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from copy import copy, deepcopy
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Callable, Optional, Union
from typing import Any, Awaitable, Callable, Concatenate, Optional, ParamSpec, Union

import optuna
from flytekit import PythonFunctionTask
Expand Down Expand Up @@ -46,14 +46,16 @@ class Category(Suggestion):

suggest = SimpleNamespace(float=Float, integer=Integer, category=Category)

P = ParamSpec("P")


@dataclass
class Optimizer:
objective: Union[PythonFunctionTask, PythonFunctionWorkflow]
concurrency: int
n_trials: int
study: Optional[optuna.Study] = None
callback: Optional[Callable[[optuna.Trial, dict[str, Any]], dict[str, Any]]] = None
objective: Optional[Union[PythonFunctionTask, PythonFunctionWorkflow]] = None
callback: Optional[Callable[Concatenate[optuna.Trial, P], Awaitable[Union[float, tuple[float, ...]]]]] = None

def __post_init__(self):
if self.study is None:
Expand All @@ -68,29 +70,45 @@ def __post_init__(self):
if not isinstance(self.study, optuna.Study):
raise ValueError("study must be an optuna.Study")

# check if the objective function returns the correct number of outputs
if isinstance(self.objective, PythonFunctionTask):
func = self.objective.task_function
elif isinstance(self.objective, PythonFunctionWorkflow):
func = self.objective._workflow_function
else:
raise ValueError("objective must be a PythonFunctionTask or PythonFunctionWorkflow")
if self.objective is not None:
if self.callback is not None:
raise ValueError("either objective or callback must be provided, not both")

signature = inspect.signature(func)
# check if the objective function returns the correct number of outputs
if isinstance(self.objective, PythonFunctionTask):
func = self.objective.task_function
elif isinstance(self.objective, PythonFunctionWorkflow):
func = self.objective._workflow_function
else:
raise ValueError("objective must be a PythonFunctionTask or PythonFunctionWorkflow")

if signature.return_annotation is float:
if len(self.study.directions) != 1:
raise ValueError("the study must have a single objective if objective returns a single float")
signature = inspect.signature(func)

elif isinstance(args := signature.return_annotation.__args__, tuple):
if len(args) != len(self.study.directions):
raise ValueError("objective must return the same number of directions in the study")
if signature.return_annotation is float:
if len(self.study.directions) != 1:
raise ValueError("the study must have a single objective if objective returns a single float")

if not all(arg is float for arg in args):
elif isinstance(args := signature.return_annotation.__args__, tuple):
if len(args) != len(self.study.directions):
raise ValueError("objective must return the same number of directions in the study")

if not all(arg is float for arg in args):
raise ValueError("objective function must return a float or tuple of floats")

else:
raise ValueError("objective function must return a float or tuple of floats")

else:
raise ValueError("objective function must return a float or tuple of floats")
if self.callback is None:
raise ValueError("either objective or callback must be provided")

if not callable(self.callback):
raise ValueError("callback must be a callable")

signature = inspect.signature(self.callback)

if "trial" not in signature.parameters:
raise ValueError("callback function must have a parameter called 'trial'")

async def __call__(self, **inputs: Any):
"""
Expand All @@ -113,14 +131,14 @@ async def spawn(self, semaphore: asyncio.Semaphore, inputs: dict[str, Any]):
# ask for a new trial
trial: optuna.Trial = self.study.ask()

if self.callback is not None:
inputs = self.callback(trial, inputs)
else:
inputs = process(trial, inputs)

try:
if self.callback is not None:
promise = self.callback(trial=trial, **inputs)
else:
promise = self.objective(**process(trial, inputs))

# schedule the trial
result: Union[float, tuple[float, ...]] = await self.objective(**inputs)
result: Union[float, tuple[float, ...]] = await promise

# tell the study the result
self.study.tell(trial, result, state=optuna.trial.TrialState.COMPLETE)
Expand All @@ -134,7 +152,11 @@ def process(trial: optuna.Trial, inputs: dict[str, Any], root: Optional[list[str
if root is None:
root = []

suggesters = {Float: trial.suggest_float, Integer: trial.suggest_int, Category: trial.suggest_categorical}
suggesters = {
Float: trial.suggest_float,
Integer: trial.suggest_int,
Category: trial.suggest_categorical,
}

for key, value in inputs.items():
path = copy(root) + [key]
Expand Down
32 changes: 16 additions & 16 deletions plugins/flytekit-optuna/tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def objective(x: float, y: int, z: int, power: int) -> float:

@fl.eager
async def train(concurrency: int, n_trials: int) -> float:
optimizer = Optimizer(objective, concurrency, n_trials)
optimizer = Optimizer(objective=objective, concurrency=concurrency, n_trials=n_trials)

await optimizer(
x=suggest.float(low=-10, high=10),
Expand Down Expand Up @@ -53,7 +53,7 @@ async def objective(suggestions: dict[str, Union[int, float]], z: int, power: in

@fl.eager
async def train(concurrency: int, n_trials: int) -> float:
optimizer = Optimizer(objective, concurrency, n_trials)
optimizer = Optimizer(objective=objective, concurrency=concurrency, n_trials=n_trials)

suggestions = {
"x": suggest.float(low=-10, high=10),
Expand Down Expand Up @@ -82,28 +82,28 @@ async def objective(letter: str, number: Union[float, int], other: str, fixed: s

return float(loss)

def callback(trial: optuna.Trial, fixed: str):

@fl.eager
async def train(concurrency: int, n_trials: int) -> float:
letter = trial.suggest_categorical("booster", ["A", "B", "BLAH"])

study = optuna.create_study(direction="maximize")
if letter == "A":
number = trial.suggest_int("number_A", 1, 10)
elif letter == "B":
number = trial.suggest_float("number_B", 10., 20.)
else:
number = 10

def callback(trial: optuna.Trial, inputs: dict[str, Any]) -> dict[str, Any]:
other = trial.suggest_categorical("other", ["Something", "another word", "a phrase"])

letter = trial.suggest_categorical("booster", ["A", "B", "BLAH"])
return objective(letter, number, other, fixed)

if letter == "A":
number = trial.suggest_int("number_A", 1, 10)
elif letter == "B":
number = trial.suggest_float("number_B", 10., 20.)
else:
number = 10

other = trial.suggest_categorical("other", ["Something", "another word", "a phrase"])
@fl.eager
async def train(concurrency: int, n_trials: int) -> float:

return dict(other=other, number=number, letter=letter, fixed=inputs["fixed"])
study = optuna.create_study(direction="maximize")

optimizer = Optimizer(objective, concurrency, n_trials, study=study, callback=callback)
optimizer = Optimizer(concurrency=concurrency, n_trials=n_trials, study=study, callback=callback)

await optimizer(fixed="hello!")

Expand Down

0 comments on commit 5af7c60

Please sign in to comment.