Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add suggestion bundle support #3041

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions plugins/flytekit-optuna/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import math

import flytekit as fl

from optimizer import Optimizer, suggest
from flytekitplugins.optuna import Optimizer, suggest

image = fl.ImageSpec(builder="union", packages=["flytekit==1.15.0b0", "optuna>=4.0.0"])
image = fl.ImageSpec(builder="union", packages=["flytekitplugins.optuna"])

@fl.task(container_image=image)
async def objective(x: float, y: int, z: int, power: int) -> float:
Expand Down Expand Up @@ -42,4 +42,3 @@ This plugin provides full feature parity to Optuna, including:

- This would synergize really well with Union Actors.
- This should also support workflows, but it currently does not.
- Add unit tests, of course.
86 changes: 63 additions & 23 deletions plugins/flytekit-optuna/flytekitplugins/optuna/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import inspect
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Optional, Union
from typing import Optional, Union, Any
import inspect

import optuna

from flytekit import PythonFunctionTask
from flytekit.core.workflow import PythonFunctionWorkflow
from flytekit.exceptions.eager import EagerException
Expand Down Expand Up @@ -57,10 +58,10 @@ def __post_init__(self):
if self.study is None:
self.study = optuna.create_study()

if (not isinstance(self.concurrency, int)) or (self.concurrency < 0):
if (not isinstance(self.concurrency, int)) and (self.concurrency < 0):
raise ValueError("concurrency must be an integer greater than 0")

if (not isinstance(self.n_trials, int)) or (self.n_trials < 0):
if (not isinstance(self.n_trials, int)) and (self.n_trials < 0):
granthamtaylor marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("n_trials must be an integer greater than 0")

if not isinstance(self.study, optuna.Study):
Expand All @@ -72,56 +73,80 @@ def __post_init__(self):
elif isinstance(self.objective, PythonFunctionWorkflow):
func = self.objective._workflow_function
else:
raise ValueError("objective must be a PythonFunctionTask or PythonFunctionWorkflow")
raise ValueError(
"objective must be a PythonFunctionTask or PythonFunctionWorkflow"
)

signature = inspect.signature(func)

if signature.return_annotation is float:
if len(self.study.directions) != 1:
raise ValueError("the study must have a single objective if objective returns a single float")
raise ValueError(
"the study must have a single objective if objective returns a single float"
)

elif isinstance(args := signature.return_annotation.__args__, tuple):
if len(args) != len(self.study.directions):
raise ValueError("objective must return the same number of directions in the study")
raise ValueError(
"objective must return the same number of directions in the study"
)

if not all(arg is float for arg in args):
raise ValueError("objective function must return a float or tuple of floats")
raise ValueError(
"objective function must return a float or tuple of floats"
)

else:
raise ValueError("objective function must return a float or tuple of floats")
raise ValueError(
"objective function must return a float or tuple of floats"
)

async def __call__(
    self, suggestions: Optional[dict[str, Suggestion]] = None, /, **inputs: Any
):
    """
    Asynchronously executes the objective function remotely.

    Parameters:
        suggestions (Optional[dict[str, Suggestion]]): bundled suggestions to objective function
            (positional-only, so it can never collide with an objective input named
            "suggestions" supplied via **inputs)
        **inputs: other inputs to objective function

    Raises:
        ValueError: if suggestions is not a dict keyed by str with Suggestion values
    """

    # validate the bundle eagerly, before any trials are spawned
    if suggestions is not None:
        if not isinstance(suggestions, dict):
            raise ValueError("suggestions must be a dict[str, Suggestion]")

        for key, value in suggestions.items():
            if not isinstance(key, str):
                raise ValueError(f"suggestion key must be a string, got {type(key)}")
            if not isinstance(value, Suggestion):
                # report the offending type; the requirement is Suggestion
                raise ValueError(f"suggestion must be a Suggestion, got {type(value)}")

    # create semaphore to manage concurrency
    semaphore = asyncio.Semaphore(self.concurrency)

    # create list of async trials
    trials = [
        self.spawn(semaphore, suggestions, **inputs) for _ in range(self.n_trials)
    ]

    # await all trials to complete
    await asyncio.gather(*trials)

async def spawn(self, semaphore: asyncio.Semaphore, **inputs: Any):
async def spawn(
self,
semaphore: asyncio.Semaphore,
suggestions: Optional[dict[str, Suggestion]] = None,
/,
**inputs: Any,
):
async with semaphore:
# ask for a new trial
trial: optuna.Trial = self.study.ask()

suggesters = {
Float: trial.suggest_float,
Integer: trial.suggest_int,
Category: trial.suggest_categorical,
}
if suggestions is not None:
inputs["suggestions"] = self.suggest(trial, suggestions)

# suggest inputs for the trial
for key, value in inputs.items():
if isinstance(value, Suggestion):
suggester = suggesters[type(value)]
inputs[key] = suggester(name=key, **vars(value))
inputs = self.suggest(trial, inputs)

try:
# schedule the trial
Expand All @@ -133,3 +158,18 @@ async def spawn(self, semaphore: asyncio.Semaphore, **inputs: Any):
# if the trial fails, tell the study
except EagerException:
self.study.tell(trial, state=optuna.trial.TrialState.FAIL)

@staticmethod
def suggest(trial: optuna.Trial, inputs: dict[str, Any]) -> dict[str, Any]:
    """Resolve every Suggestion in *inputs* to a concrete value drawn from *trial*.

    Parameters:
        trial: the active optuna trial to sample from
        inputs: mapping of input names to plain values or Suggestion specs

    Returns:
        A new dict where each Suggestion is replaced by a sampled value;
        non-Suggestion values are passed through unchanged.
    """

    suggesters = {
        Float: trial.suggest_float,
        Integer: trial.suggest_int,
        Category: trial.suggest_categorical,
    }

    # Build a NEW dict instead of mutating `inputs` in place: the caller may
    # share one suggestions dict across all trials, and overwriting its
    # Suggestion entries would freeze every later trial to the first trial's
    # sampled values.
    resolved: dict[str, Any] = {}
    for key, value in inputs.items():
        if isinstance(value, Suggestion):
            suggester = suggesters[type(value)]
            resolved[key] = suggester(name=key, **vars(value))
        else:
            resolved[key] = value

    return resolved
33 changes: 32 additions & 1 deletion plugins/flytekit-optuna/tests/test_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
import math
from typing import Union

import flytekit as fl
from flytekitplugins.optuna import Optimizer, suggest


image = fl.ImageSpec(builder="union", packages=["flytekit==1.15.0b0", "optuna>=4.0.0"])
image = fl.ImageSpec(builder="union", packages=["flytekit>=1.15.0", "optuna>=4.0.0"])

@fl.task(container_image=image)
async def objective(x: float, y: int, z: int, power: int) -> float:
Expand All @@ -30,3 +31,33 @@ def test_local_exec():
loss = asyncio.run(train(concurrency=2, n_trials=10))

assert isinstance(loss, float)


@fl.task(container_image=image)
async def bundled_objective(suggestions: dict[str, Union[int, float]], z: int, power: int) -> float:
    """Toy objective whose tunable inputs arrive bundled in a single dict."""

    # building out a large set of typed inputs is exhausting, so we can just use a dict
    x = suggestions["x"]
    y = suggestions["y"]

    residual = (x - 5) ** 2 + (y + 4) ** 4 + (3 * z - 3) ** 2
    return math.log(residual) ** power


@fl.eager(container_image=image)
async def bundled_train(concurrency: int, n_trials: int) -> float:
    """Optimize the bundled-suggestion objective and return the best observed loss.

    Parameters:
        concurrency: maximum number of trials evaluated at once
        n_trials: total number of trials to run

    Returns:
        The study's best objective value.
    """

    # NOTE(review): the original passed `objective` here, but only
    # `bundled_objective` accepts the bundled `suggestions` input; it also
    # reused the name `train`, shadowing the earlier non-bundled eager
    # workflow and breaking `test_local_exec` — hence the rename.
    optimizer = Optimizer(bundled_objective, concurrency, n_trials)

    suggestions = {
        "x": suggest.float(low=-10, high=10),
        "y": suggest.integer(low=-10, high=10),
    }

    # the bundle is passed positionally; remaining inputs may still mix
    # individual suggestions with plain values
    await optimizer(suggestions, z=suggest.category([-5, 0, 3, 6, 9]), power=2)

    return optimizer.study.best_value


def test_bundled_local_exec():
    loss = asyncio.run(bundled_train(concurrency=2, n_trials=10))

    assert isinstance(loss, float)
Loading